feat: 添加 Zhipu 嵌入模型支持并修复 Python 3.9 兼容性问题
- 添加 Zhipu (智谱AI) 嵌入模型支持
- 修复 Python 3.9 兼容性问题(anext -> async for)
- 更新 README.md 添加项目介绍和前端开发指南
- 添加向量数据库配置文档
This commit is contained in:
parent
85d8f49b7a
commit
01070cd44d
|
|
@ -76,3 +76,7 @@
|
|||
# 启动项目命令
|
||||
python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
5. 查看日志
|
||||
```bash
|
||||
tail -f /tmp/uvicorn.log
|
||||
```
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
# 本地向量数据库配置指南
|
||||
|
||||
## 当前配置
|
||||
|
||||
项目使用 **ChromaDB** 作为本地向量数据库,数据存储在 `./data/chroma/` 目录下。
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 1. 环境变量配置 (.env)
|
||||
|
||||
在 `.env` 文件中配置以下参数:
|
||||
|
||||
```env
|
||||
# 向量数据库类型(代码中当前固定使用 Chroma;仍建议显式设置为 chroma,便于后续扩展)
|
||||
VECTOR_DB_TYPE=chroma
|
||||
|
||||
# 向量数据库存储路径(本地文件系统)
|
||||
# 相对路径会基于项目根目录
|
||||
VECTOR_DB_PERSIST_DIRECTORY=./data/chroma
|
||||
|
||||
# 集合名称(默认)
|
||||
VECTOR_DB_COLLECTION_NAME=documents
|
||||
```
|
||||
|
||||
### 2. 目录结构
|
||||
|
||||
向量数据库按知识库 ID 组织,每个知识库有独立的目录:
|
||||
|
||||
```
|
||||
data/chroma/
|
||||
├── kb_1/ # 知识库 1 的向量数据
|
||||
├── kb_2/ # 知识库 2 的向量数据
|
||||
├── kb_13/ # 知识库 13 的向量数据
|
||||
└── ...
|
||||
```
|
||||
|
||||
### 3. 使用方式
|
||||
|
||||
向量数据库会在以下场景自动创建:
|
||||
|
||||
1. **上传文档时**:如果上传时选择立即处理,会自动创建向量数据库
|
||||
2. **处理文档时**:调用 `POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process` 接口
|
||||
|
||||
### 4. 验证安装
|
||||
|
||||
运行以下命令验证向量数据库是否正常工作:
|
||||
|
||||
```bash
|
||||
python3 -c "
|
||||
import chromadb
|
||||
from pathlib import Path
|
||||
|
||||
# 测试创建本地 Chroma 数据库
|
||||
test_path = './data/chroma/test_kb'
|
||||
client = chromadb.PersistentClient(path=test_path)
|
||||
collection = client.get_or_create_collection(name='test')
|
||||
print('✅ ChromaDB 本地数据库创建成功')
|
||||
"
|
||||
```
|
||||
|
||||
### 5. 依赖包
|
||||
|
||||
已安装的依赖:
|
||||
- `chromadb>=1.0.20` - ChromaDB 核心库
|
||||
- `langchain-chroma>=0.1.0` - LangChain Chroma 集成
|
||||
|
||||
### 6. 注意事项
|
||||
|
||||
- 向量数据库数据存储在本地文件系统,无需额外服务
|
||||
- 每个知识库的向量数据独立存储
|
||||
- 删除知识库时,对应的向量数据目录也会被清理
|
||||
- 确保 `data/chroma/` 目录有写入权限
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 问题:向量数据库不存在
|
||||
|
||||
**原因**:文档尚未被处理和向量化
|
||||
|
||||
**解决**:
|
||||
1. 先调用处理文档接口:`POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process`
|
||||
2. 处理完成后,向量数据库会自动创建
|
||||
|
||||
### 问题:权限错误
|
||||
|
||||
**解决**:
|
||||
```bash
|
||||
chmod -R 755 data/chroma/
|
||||
```
|
||||
|
||||
### 问题:磁盘空间不足
|
||||
|
||||
**解决**:清理不需要的知识库向量数据,或扩展存储空间
|
||||
|
|
@ -2,6 +2,7 @@ from typing import Dict, Any, List, Optional
|
|||
import json
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from th_agenter.models.conversation import Conversation
|
||||
from th_agenter.models.message import Message
|
||||
from th_agenter.db.database import get_session
|
||||
|
|
@ -27,30 +28,34 @@ class ConversationContextService:
|
|||
新创建的对话ID
|
||||
"""
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
|
||||
conversation = Conversation(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
session.add(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
|
||||
# 初始化对话上下文
|
||||
self.context_cache[conversation.id] = {
|
||||
'conversation_id': conversation.id,
|
||||
'user_id': user_id,
|
||||
'file_list': [],
|
||||
'selected_files': [],
|
||||
'query_history': [],
|
||||
'created_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return conversation.id
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for session in get_session():
|
||||
try:
|
||||
conversation = Conversation(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
session.add(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
|
||||
# 初始化对话上下文
|
||||
self.context_cache[conversation.id] = {
|
||||
'conversation_id': conversation.id,
|
||||
'user_id': user_id,
|
||||
'file_list': [],
|
||||
'selected_files': [],
|
||||
'query_history': [],
|
||||
'created_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await session.close()
|
||||
return conversation.id
|
||||
finally:
|
||||
break # 只取第一个 session
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建对话失败: {e}")
|
||||
|
|
@ -74,52 +79,58 @@ class ConversationContextService:
|
|||
|
||||
# 从数据库加载
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
|
||||
conversation = session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
# 加载消息历史
|
||||
messages = session.query(Message).filter(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
# 重建上下文
|
||||
context = {
|
||||
'conversation_id': conversation_id,
|
||||
'user_id': conversation.user_id,
|
||||
'file_list': [],
|
||||
'selected_files': [],
|
||||
'query_history': [],
|
||||
'created_at': conversation.created_at.isoformat()
|
||||
}
|
||||
|
||||
# 从消息中提取查询历史
|
||||
for message in messages:
|
||||
if message.role == 'user':
|
||||
context['query_history'].append({
|
||||
'query': message.content,
|
||||
'timestamp': message.created_at.isoformat()
|
||||
})
|
||||
elif message.role == 'assistant' and message.metadata:
|
||||
# 从助手消息的元数据中提取文件信息
|
||||
try:
|
||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||
if 'selected_files' in metadata:
|
||||
context['selected_files'] = metadata['selected_files']
|
||||
if 'file_list' in metadata:
|
||||
context['file_list'] = metadata['file_list']
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 缓存上下文
|
||||
self.context_cache[conversation_id] = context
|
||||
|
||||
return context
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for session in get_session():
|
||||
try:
|
||||
conversation = await session.scalar(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
await session.close()
|
||||
return None
|
||||
|
||||
# 加载消息历史
|
||||
messages = await session.scalars(
|
||||
select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
|
||||
)
|
||||
messages_list = list(messages)
|
||||
|
||||
# 重建上下文
|
||||
context = {
|
||||
'conversation_id': conversation_id,
|
||||
'user_id': conversation.user_id,
|
||||
'file_list': [],
|
||||
'selected_files': [],
|
||||
'query_history': [],
|
||||
'created_at': conversation.created_at.isoformat()
|
||||
}
|
||||
|
||||
# 从消息中提取查询历史
|
||||
for message in messages_list:
|
||||
if message.role == 'user':
|
||||
context['query_history'].append({
|
||||
'query': message.content,
|
||||
'timestamp': message.created_at.isoformat()
|
||||
})
|
||||
elif message.role == 'assistant' and message.metadata:
|
||||
# 从助手消息的元数据中提取文件信息
|
||||
try:
|
||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||
if 'selected_files' in metadata:
|
||||
context['selected_files'] = metadata['selected_files']
|
||||
if 'file_list' in metadata:
|
||||
context['file_list'] = metadata['file_list']
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 缓存上下文
|
||||
self.context_cache[conversation_id] = context
|
||||
|
||||
await session.close()
|
||||
return context
|
||||
finally:
|
||||
break # 只取第一个 session
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取对话上下文失败: {e}")
|
||||
|
|
@ -194,29 +205,30 @@ class ConversationContextService:
|
|||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
|
||||
# 更新对话的最后更新时间
|
||||
conversation = session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
if conversation:
|
||||
conversation.updated_at = datetime.utcnow()
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for session in get_session():
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
# 更新对话的最后更新时间
|
||||
conversation = await session.scalar(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
|
||||
if conversation:
|
||||
conversation.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
|
||||
await session.close()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存消息失败: {e}")
|
||||
|
|
@ -262,31 +274,33 @@ class ConversationContextService:
|
|||
消息历史列表
|
||||
"""
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
|
||||
messages = session.query(Message).filter(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
history = []
|
||||
for message in messages:
|
||||
msg_data = {
|
||||
'id': message.id,
|
||||
'role': message.role,
|
||||
'content': message.content,
|
||||
'timestamp': message.created_at.isoformat()
|
||||
}
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for session in get_session():
|
||||
messages = await session.scalars(
|
||||
select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
|
||||
)
|
||||
messages_list = list(messages)
|
||||
|
||||
if message.metadata:
|
||||
try:
|
||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||
msg_data['metadata'] = metadata
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
history = []
|
||||
for message in messages_list:
|
||||
msg_data = {
|
||||
'id': message.id,
|
||||
'role': message.role,
|
||||
'content': message.content,
|
||||
'timestamp': message.created_at.isoformat()
|
||||
}
|
||||
|
||||
if message.metadata:
|
||||
try:
|
||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||
msg_data['metadata'] = metadata
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
history.append(msg_data)
|
||||
|
||||
history.append(msg_data)
|
||||
|
||||
return history
|
||||
await session.close()
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取对话历史失败: {e}")
|
||||
|
|
|
|||
|
|
@ -84,9 +84,8 @@ class DocumentProcessor:
|
|||
config = None
|
||||
if session:
|
||||
config = await llm_config_service.get_default_embedding_config(session)
|
||||
if config:
|
||||
if(session != None):
|
||||
session.desc = f"获取默认嵌入模型配置: {config}"
|
||||
if config and session:
|
||||
session.desc = f"获取默认嵌入模型配置: {config}"
|
||||
# # 转换配置格式
|
||||
# config = {
|
||||
# "provider": config.provider,
|
||||
|
|
@ -96,39 +95,55 @@ class DocumentProcessor:
|
|||
|
||||
# 如果未找到配置,使用默认配置
|
||||
if not config:
|
||||
session.desc = f"ERROR: 未找到嵌入模型配置"
|
||||
if session:
|
||||
session.desc = f"ERROR: 未找到嵌入模型配置"
|
||||
raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
|
||||
session.desc = f"获取嵌入模型配置 > 结果:{config}"
|
||||
if session:
|
||||
session.desc = f"获取嵌入模型配置 > 结果:{config}"
|
||||
|
||||
# 根据配置创建嵌入模型
|
||||
if config.provider == "openai":
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
self.embeddings = OpenAIEmbeddings(
|
||||
model=config.get("model", "text-embedding-3-small"),
|
||||
api_key=config.get("api_key")
|
||||
model=config.model_name or "text-embedding-3-small",
|
||||
api_key=config.api_key
|
||||
)
|
||||
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.get('model', 'text-embedding-3-small')})"
|
||||
if session:
|
||||
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.model_name or 'text-embedding-3-small'})"
|
||||
elif config.provider == "zhipu":
|
||||
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
|
||||
self.embeddings = ZhipuOpenAIEmbeddings(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url or "https://open.bigmodel.cn/api/paas/v4",
|
||||
model=config.model_name or "embedding-3",
|
||||
dimensions=settings.vector_db.embedding_dimension
|
||||
)
|
||||
if session:
|
||||
session.desc = f"创建嵌入模型 - ZhipuOpenAIEmbeddings(model={config.model_name or 'embedding-3'}, base_url={config.base_url})"
|
||||
elif config.provider == "ollama":
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
self.embeddings = OllamaEmbeddings(
|
||||
model=config.model_name,
|
||||
base_url=config.base_url
|
||||
)
|
||||
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
|
||||
if session:
|
||||
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
|
||||
elif config.provider == "local":
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
self.embeddings = HuggingFaceEmbeddings(
|
||||
model_name=config.get("model", "sentence-transformers/all-MiniLM-L6-v2")
|
||||
model_name=config.model_name or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.get('model', 'sentence-transformers/all-MiniLM-L6-v2')})"
|
||||
if session:
|
||||
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.model_name or 'sentence-transformers/all-MiniLM-L6-v2'})"
|
||||
else:
|
||||
# 默认使用OpenAI
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
self.embeddings = OpenAIEmbeddings(
|
||||
model=config.get("model", "text-embedding-3-small"),
|
||||
api_key=config.get("api_key")
|
||||
model=config.model_name or "text-embedding-3-small",
|
||||
api_key=config.api_key
|
||||
)
|
||||
session.desc = f"ERROR: 未支持的嵌入提供者: {config['provider']},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"
|
||||
if session:
|
||||
session.desc = f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"
|
||||
|
||||
return self.embeddings
|
||||
except Exception as e:
|
||||
|
|
@ -388,17 +403,19 @@ class DocumentProcessor:
|
|||
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
|
||||
|
||||
# 4. 更新文档状态
|
||||
session = await anext(get_session())
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||
|
||||
if document:
|
||||
document.status = "processed"
|
||||
document.chunk_count = len(chunks)
|
||||
await session.commit()
|
||||
finally:
|
||||
await session.close()
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for db_session in get_session():
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||
|
||||
if document:
|
||||
document.is_processed = True
|
||||
document.chunk_count = len(chunks)
|
||||
await db_session.commit()
|
||||
finally:
|
||||
await db_session.close()
|
||||
break # 只取第一个 session
|
||||
|
||||
result = {
|
||||
"document_id": document_id,
|
||||
|
|
@ -416,16 +433,18 @@ class DocumentProcessor:
|
|||
|
||||
# 更新文档状态为失败
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||
if document:
|
||||
document.status = "failed"
|
||||
document.error_message = str(e)
|
||||
await session.commit()
|
||||
finally:
|
||||
await session.close()
|
||||
# Python 3.9 兼容:使用 async for 替代 anext
|
||||
async for db_session in get_session():
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||
if document:
|
||||
document.is_processed = False
|
||||
document.processing_error = str(e)
|
||||
await db_session.commit()
|
||||
finally:
|
||||
await db_session.close()
|
||||
break # 只取第一个 session
|
||||
except Exception as db_error:
|
||||
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue