hyf-backend/th_agenter/services/knowledge_base.py

245 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Knowledge base service."""
# Standard library imports
from typing import List, Optional, Dict, Any
# Third-party imports
from loguru import logger
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import Session
# Local imports
from ..core.config import get_settings
from ..core.context import UserContext
from ..models.knowledge_base import KnowledgeBase
from .document_processor import get_document_processor
from utils.util_schemas import KnowledgeBaseCreate, KnowledgeBaseUpdate
# Settings object obtained once, when this module is imported; the service
# below reads ``settings.vector_db.type`` from it.
settings = get_settings()
class KnowledgeBaseService:
    """Service layer for managing knowledge bases.

    Provides create, lookup, list, update, delete and search operations for
    ``KnowledgeBase`` records through the database session it is given.
    """

    # Escape character used when turning user input into a LIKE/ILIKE pattern.
    _LIKE_ESCAPE = "/"

    def __init__(self, session: Session):
        """Store the database session used by every operation.

        Args:
            session (Session): Database session used for ORM operations.
                NOTE(review): ``execute``/``commit``/``refresh``/``rollback``/
                ``delete`` are awaited throughout this class, so the object
                passed in must expose awaitable versions of them even though
                the annotation says ``Session`` -- confirm and tighten.
        """
        if session is None:
            # Deliberately only logged (not raised): construction stays
            # best-effort, the failure surfaces on first use of the session.
            logger.error("session为空session must be an instance of Session")
        self.session = session

    @staticmethod
    def _current_user_id():
        """Return the ``'id'`` entry of ``UserContext.get_current_user()``."""
        return UserContext.get_current_user()['id']

    @classmethod
    def _like_pattern(cls, term: str) -> str:
        """Build a ``%term%`` pattern with LIKE wildcards in *term* escaped.

        ``%`` and ``_`` typed by the user are matched literally instead of
        acting as wildcards. Must be used together with
        ``escape=cls._LIKE_ESCAPE`` on the ``like``/``ilike`` call.
        """
        esc = cls._LIKE_ESCAPE
        # Escape the escape character first so it is not doubled twice.
        escaped = (
            str(term)
            .replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
        """Create and persist a new knowledge base.

        Args:
            kb_data (KnowledgeBaseCreate): Data for the new knowledge base.

        Returns:
            KnowledgeBase: The persisted and refreshed knowledge base.

        Raises:
            Exception: Any error raised while creating the record; the
                session is rolled back and the error is re-raised.
        """
        try:
            # Derive the vector-database collection name from the KB name:
            # lower-cased, with spaces and hyphens replaced by underscores.
            collection_name = f"kb_{kb_data.name.lower().replace(' ', '_').replace('-', '_')}"
            kb = KnowledgeBase(
                name=kb_data.name,
                description=kb_data.description,
                embedding_model=kb_data.embedding_model,
                chunk_size=kb_data.chunk_size,
                chunk_overlap=kb_data.chunk_overlap,
                vector_db_type=settings.vector_db.type,
                collection_name=collection_name,
            )
            # Fill in the audit fields (created_by / updated_by).
            kb.set_audit_fields()
            self.session.add(kb)
            await self.session.commit()
            await self.session.refresh(kb)
            self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}"
            return kb
        except Exception as e:
            await self.session.rollback()
            logger.error(f"Failed to create knowledge base: {str(e)}")
            raise

    async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
        """Search the current user's active knowledge bases by name or description.

        The match is a case-insensitive substring match. ``%`` and ``_`` in
        *query* are treated as literal characters, not as wildcards.

        Args:
            query (str): Search text.
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.

        Returns:
            List[KnowledgeBase]: Matching knowledge bases, ordered by ID.
        """
        pattern = self._like_pattern(query)
        stmt = (
            select(KnowledgeBase)
            .where(
                KnowledgeBase.created_by == self._current_user_id(),
                KnowledgeBase.is_active == True,  # noqa: E712 (SQL expression)
                or_(
                    KnowledgeBase.name.ilike(pattern, escape=self._LIKE_ESCAPE),
                    KnowledgeBase.description.ilike(pattern, escape=self._LIKE_ESCAPE),
                ),
            )
            # OFFSET/LIMIT without ORDER BY gives unstable pages.
            .order_by(KnowledgeBase.id)
            .offset(skip)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """Search inside a knowledge base using vector similarity.

        Args:
            kb_id (int): ID of the knowledge base to search in.
            query (str): Search query.
            top_k (int, optional): Maximum number of results requested from
                the document processor. Defaults to 5.
            similarity_threshold (float, optional): Minimum
                ``normalized_score`` a result needs to be kept. Defaults to 0.7.

        Returns:
            List[Dict[str, Any]]: One dict per kept result with the keys
            ``content``, ``source``, ``score``, ``metadata``, ``document_id``
            and ``chunk_id``. An empty list is returned when the search fails
            (the error is logged, not raised).
        """
        try:
            logger.info(f"Searching in knowledge base {kb_id} for: {query}")
            # Delegate the vector search to the document processor.
            processor = await get_document_processor(self.session)
            search_results = processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=query,
                k=top_k,
            )
            # Keep only results whose normalized score reaches the threshold.
            filtered_results = []
            for result in search_results:
                normalized_score = result.get('normalized_score', 0)
                if normalized_score < similarity_threshold:
                    continue
                filtered_results.append({
                    "content": result.get('content', ''),
                    "source": result.get('source', 'unknown'),
                    "score": normalized_score,
                    "metadata": result.get('metadata', {}),
                    "document_id": result.get('document_id', 'unknown'),
                    "chunk_id": result.get('chunk_id', 'unknown'),
                })
            logger.info(f"Found {len(filtered_results)} relevant documents (threshold: {similarity_threshold})")
            return filtered_results
        except Exception as e:
            # Best-effort: a failed search is reported as "no results".
            logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
            return []

    # ----------------------------------------------------------------------------------
    async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
        """Fetch one of the current user's knowledge bases by exact name.

        Args:
            name (str): Name of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The knowledge base, or None if not found.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.name == name,
            KnowledgeBase.created_by == self._current_user_id(),
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
        """List the knowledge bases created by the current user.

        Args:
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.
            active_only (bool, optional): Return only active knowledge bases.
                Defaults to True.

        Returns:
            List[KnowledgeBase]: The user's knowledge bases, ordered by ID.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == self._current_user_id()
        )
        if active_only:
            stmt = stmt.where(KnowledgeBase.is_active == True)  # noqa: E712
        # OFFSET/LIMIT without ORDER BY gives unstable pages.
        stmt = stmt.order_by(KnowledgeBase.id).offset(skip).limit(limit)
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
        """Fetch a knowledge base by ID.

        NOTE(review): unlike the name lookup and the listing, this query has
        no ``created_by`` filter, so ``update_knowledge_base`` and
        ``delete_knowledge_base`` (which use it) are not restricted to the
        current user here -- confirm that callers enforce ownership.

        Args:
            kb_id (int): ID of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The knowledge base, or None if not found.
        """
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
        """Update a knowledge base with the fields set on *kb_update*.

        Args:
            kb_id (int): ID of the knowledge base to update.
            kb_update (KnowledgeBaseUpdate): New values; only fields that were
                explicitly set are applied.

        Returns:
            Optional[KnowledgeBase]: The updated knowledge base, or None if no
            knowledge base with that ID exists.

        Raises:
            Exception: Any error raised while updating; the session is rolled
                back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return None
        try:
            # Apply only the fields the caller actually provided.
            for field, value in kb_update.model_dump(exclude_unset=True).items():
                setattr(kb, field, value)
            # Refresh the audit fields for an update.
            kb.set_audit_fields(is_update=True)
            await self.session.commit()
            await self.session.refresh(kb)
        except Exception as e:
            # Same failure handling as create_knowledge_base.
            await self.session.rollback()
            logger.error(f"Failed to update knowledge base {kb_id}: {str(e)}")
            raise
        self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
        return kb

    async def delete_knowledge_base(self, kb_id: int) -> bool:
        """Delete a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to delete.

        Returns:
            bool: True if the knowledge base was deleted, False if no
            knowledge base with that ID exists.

        Raises:
            Exception: Any error raised while deleting; the session is rolled
                back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return False
        # TODO: Clean up vector database collection
        # This should be implemented when vector database service is ready
        try:
            await self.session.delete(kb)
            await self.session.commit()
        except Exception as e:
            # Same failure handling as create_knowledge_base.
            await self.session.rollback()
            logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
            raise
        return True