# NOTE: removed duplicated scrape/UI residue ("329 lines / 15 KiB / Python") that preceded the module docstring.
"""Document service."""
|
||
|
||
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import UploadFile
from loguru import logger
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from utils.util_file import FileUtils
from utils.util_schemas import DocumentChunk

from ..core.config import get_settings
from ..models.knowledge_base import Document, KnowledgeBase
from .document_processor import get_document_processor
from .storage import storage_service
|
||
|
||
settings = get_settings()
|
||
|
||
|
||
class DocumentService:
|
||
"""Document service for managing documents in knowledge bases."""
|
||
|
||
def __init__(self, session: Session):
|
||
self.session = session
|
||
self.file_utils = FileUtils()
|
||
|
||
async def upload_document(self, file: UploadFile, kb_id: int) -> Document:
|
||
"""Upload a document to knowledge base."""
|
||
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
|
||
# Validate knowledge base exists
|
||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||
kb = await self.session.scalar(stmt)
|
||
if not kb:
|
||
self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||
raise ValueError(f"知识库 {kb_id} 不存在")
|
||
|
||
# Validate file
|
||
if not file.filename:
|
||
self.session.desc = f"ERROR: 上传文件时未提供文件名"
|
||
raise ValueError("No filename provided")
|
||
|
||
# Validate file extension
|
||
file_extension = Path(file.filename).suffix.lower()
|
||
if file_extension not in settings.file.allowed_extensions:
|
||
self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}"
|
||
raise ValueError(f"非期望的文件类型 {file_extension}")
|
||
|
||
# Upload file using storage service
|
||
storage_info = await storage_service.upload_file(file, kb_id)
|
||
self.session.desc = f"文档 {file.filename} 上传到 {storage_info}"
|
||
|
||
# Create document record
|
||
document = Document(
|
||
knowledge_base_id=kb_id,
|
||
filename=os.path.basename(storage_info["file_path"]),
|
||
original_filename=file.filename,
|
||
file_path=storage_info.get("full_path", storage_info["file_path"]), # Use absolute path if available
|
||
file_size=storage_info["size"],
|
||
file_type=file_extension,
|
||
mime_type=storage_info["mime_type"],
|
||
is_processed=False
|
||
)
|
||
|
||
# Set audit fields
|
||
document.set_audit_fields()
|
||
|
||
self.session.add(document)
|
||
await self.session.commit()
|
||
await self.session.refresh(document)
|
||
|
||
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
|
||
return document
|
||
|
||
async def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
|
||
"""根据文档ID查询文档,可选地根据知识库ID过滤。"""
|
||
self.session.desc = f"根据文档ID查询文档 {doc_id}"
|
||
stmt = select(Document).where(Document.id == doc_id)
|
||
if kb_id is not None:
|
||
stmt = stmt.where(Document.knowledge_base_id == kb_id)
|
||
return await self.session.scalar(stmt)
|
||
|
||
async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
|
||
"""根据知识库ID查询文档,支持分页。"""
|
||
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
|
||
stmt = (
|
||
select(Document)
|
||
.where(Document.knowledge_base_id == kb_id)
|
||
.offset(skip)
|
||
.limit(limit)
|
||
)
|
||
return (await self.session.scalars(stmt)).all()
|
||
|
||
async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
|
||
"""根据知识库ID查询文档,支持分页,并返回总文档数。"""
|
||
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
|
||
# Get total count
|
||
count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
|
||
total = await self.session.scalar(count_stmt)
|
||
|
||
# Get documents with pagination
|
||
documents_stmt = (
|
||
select(Document)
|
||
.where(Document.knowledge_base_id == kb_id)
|
||
.offset(skip)
|
||
.limit(limit)
|
||
)
|
||
documents = (await self.session.scalars(documents_stmt)).all()
|
||
|
||
return documents, total
|
||
|
||
async def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
|
||
"""根据文档ID删除文档,可选地根据知识库ID过滤。"""
|
||
self.session.desc = f"删除文档 {doc_id}"
|
||
document = await self.get_document(doc_id, kb_id)
|
||
if not document:
|
||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||
return False
|
||
|
||
# Delete file from storage
|
||
try:
|
||
await storage_service.delete_file(document.file_path)
|
||
self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
|
||
except Exception as e:
|
||
self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
|
||
|
||
# TODO: Remove from vector database
|
||
# This should be implemented when vector database service is ready
|
||
self.session.desc = f"从向量数据库删除文档 {doc_id}"
|
||
(await get_document_processor(self.session)).delete_document_from_vector_store(kb_id,doc_id)
|
||
# Delete database record
|
||
self.session.desc = f"删除数据库记录 {doc_id}"
|
||
await self.session.delete(document)
|
||
await self.session.commit()
|
||
self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
|
||
return True
|
||
|
||
async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]:
|
||
"""处理文档,提取文本并创建嵌入向量。"""
|
||
try:
|
||
self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量"
|
||
document = await self.get_document(doc_id, kb_id)
|
||
self.session.desc = f"获取文档 {doc_id} >>> {document}"
|
||
if not document:
|
||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||
raise ValueError(f"Document {doc_id} not found")
|
||
|
||
# document.file_path[为('C:\\DrGraph\\TH_Backend\\data\\uploads\\kb_1\\997eccbb-9081-4ddf-879e-bc7d781fab50_答辩.txt',) ,需要取第一个元素
|
||
file_path = document.file_path
|
||
knowledge_base_id=document.knowledge_base_id
|
||
is_processed=document.is_processed
|
||
|
||
if is_processed:
|
||
self.session.desc = f"INFO: 文档 {doc_id} 已处理"
|
||
return {
|
||
"document_id": doc_id,
|
||
"status": "already_processed",
|
||
"message": "文档已处理"
|
||
}
|
||
|
||
self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}"
|
||
# 更新文档状态为处理中
|
||
document.processing_error = None
|
||
await self.session.commit()
|
||
self.session.desc = f"更新文档状态为处理中 {doc_id}"
|
||
|
||
# 调用文档处理器进行处理
|
||
document_processor = await get_document_processor(self.session)
|
||
self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}"
|
||
result = await document_processor.process_document(
|
||
session=self.session,
|
||
document_id=doc_id,
|
||
file_path=file_path,
|
||
knowledge_base_id=knowledge_base_id
|
||
)
|
||
self.session.desc = f"处理文档完毕 {doc_id}"
|
||
|
||
# 如果处理成功,更新文档状态
|
||
if result["status"] == "success":
|
||
document.is_processed = True
|
||
document.chunk_count = result.get("chunks_count", 0)
|
||
await self.session.commit()
|
||
await self.session.refresh(document)
|
||
logger.info(f"Processed document: {document.filename} (ID: {doc_id})")
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
await self.session.rollback()
|
||
self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
|
||
|
||
# Update document with error
|
||
try:
|
||
document = await self.get_document(doc_id, kb_id)
|
||
if document:
|
||
document.processing_error = str(e)
|
||
await self.session.commit()
|
||
except Exception as db_error:
|
||
logger.error(f"Failed to update document error status: {db_error}")
|
||
|
||
return {
|
||
"document_id": doc_id,
|
||
"status": "failed",
|
||
"error": str(e),
|
||
"message": "文档处理失败"
|
||
}
|
||
|
||
async def _extract_text(self, document: Document) -> str:
|
||
"""从文档文件中提取文本内容。"""
|
||
try:
|
||
if document.is_text_file:
|
||
# Read text files directly
|
||
with open(document.file_path, 'r', encoding='utf-8') as f:
|
||
return f.read()
|
||
|
||
elif document.is_pdf_file:
|
||
# TODO: Implement PDF text extraction using PyPDF2 or similar
|
||
# For now, return placeholder
|
||
return f"PDF content from {document.original_filename}"
|
||
|
||
elif document.is_office_file:
|
||
# TODO: Implement Office file text extraction using python-docx, openpyxl, etc.
|
||
# For now, return placeholder
|
||
return f"Office document content from {document.original_filename}"
|
||
|
||
else:
|
||
self.session.desc = f"ERROR: 不支持的文件类型: {document.file_type}"
|
||
raise ValueError(f"不支持的文件类型: {document.file_type}")
|
||
|
||
except Exception as e:
|
||
self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
|
||
raise
|
||
|
||
async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
|
||
"""更新文档处理状态。"""
|
||
self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
|
||
document = await self.get_document(doc_id)
|
||
if not document:
|
||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||
return False
|
||
|
||
document.is_processed = is_processed
|
||
document.processing_error = error
|
||
|
||
await self.session.commit()
|
||
self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
|
||
return True
|
||
|
||
async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||
"""在知识库中搜索文档使用向量相似度。"""
|
||
try:
|
||
# 使用文档处理器进行相似性搜索
|
||
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}条"
|
||
results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit)
|
||
self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
|
||
return results
|
||
except Exception as e:
|
||
self.session.desc = f"EXCEPTION: 搜索知识库 {kb_id} 中的文档使用向量相似度时失败: {e}"
|
||
logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
|
||
return []
|
||
|
||
async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
|
||
"""获取知识库中的文档统计信息。"""
|
||
documents = await self.get_documents(kb_id, limit=1000) # Get all documents
|
||
|
||
total_count = len(documents)
|
||
processed_count = len([doc for doc in documents if doc.is_processed])
|
||
total_size = sum(doc.file_size for doc in documents)
|
||
|
||
file_types = {}
|
||
for doc in documents:
|
||
file_type = doc.file_type
|
||
file_types[file_type] = file_types.get(file_type, 0) + 1
|
||
|
||
return {
|
||
"total_documents": total_count,
|
||
"processed_documents": processed_count,
|
||
"pending_documents": total_count - processed_count,
|
||
"total_size_bytes": total_size,
|
||
"total_size_mb": round(total_size / (1024 * 1024), 2),
|
||
"file_types": file_types
|
||
}
|
||
|
||
async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
|
||
"""获取特定文档的文档块。"""
|
||
try:
|
||
self.session.desc = f"获取文档 {doc_id} 的文档块"
|
||
stmt = select(Document).where(Document.id == doc_id)
|
||
document = await self.session.scalar(stmt)
|
||
if not document:
|
||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||
return []
|
||
|
||
self.session.desc = f"获取文档 {doc_id} 的文档块 > document"
|
||
# Get chunks from document processor
|
||
try:
|
||
doc_processor = await get_document_processor(self.session)
|
||
chunks_data = doc_processor.get_document_chunks(document.knowledge_base_id, doc_id)
|
||
except Exception as e:
|
||
error_msg = f"获取文档处理器失败: {str(e)}"
|
||
self.session.desc = f"ERROR: {error_msg}"
|
||
logger.error(error_msg)
|
||
# 如果是因为嵌入模型未配置,返回空列表而不是抛出异常
|
||
if "未找到嵌入模型配置" in str(e) or "embeddings 未设置" in str(e):
|
||
return []
|
||
raise
|
||
|
||
self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data"
|
||
# Convert to DocumentChunk objects
|
||
chunks = []
|
||
for chunk_data in chunks_data:
|
||
chunk = DocumentChunk(
|
||
id=chunk_data["id"],
|
||
content=chunk_data["content"],
|
||
metadata=chunk_data["metadata"],
|
||
page_number=chunk_data.get("page_number"),
|
||
chunk_index=chunk_data["chunk_index"],
|
||
start_char=chunk_data.get("start_char"),
|
||
end_char=chunk_data.get("end_char"),
|
||
vector_id=chunk_data.get("vector_id")
|
||
)
|
||
chunks.append(chunk)
|
||
|
||
self.session.desc = f"SUCCESS: 获取文档 {doc_id} 的文档块: {len(chunks)} 个"
|
||
return chunks
|
||
|
||
except Exception as e:
|
||
self.session.desc = f"EXCEPTION: 获取文档 {doc_id} 的文档块时失败: {e}"
|
||
return [] |