# hyf-backend/th_agenter/services/chat.py
"""Chat service for AI model integration using LangChain."""
from th_agenter import db
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
from sqlalchemy.orm import Session
from loguru import logger
from th_agenter.core.new_agent import new_agent, new_llm
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
from .conversation import ConversationService
from .langchain_chat import LangChainChatService
try:
from .knowledge_chat import KnowledgeChatService
except ModuleNotFoundError as e:
KnowledgeChatService = None # 需 pip install langchain-chroma
from .agent.agent_service import get_agent_service
from .agent.langgraph_agent_service import get_langgraph_agent_service
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
class AgentState(TypedDict):
    """LangGraph agent state schema: a single channel holding the chat transcript."""
    messages: List[dict]  # conversation messages (core memory)
class ChatService:
    """Service for handling AI chat functionality using LangChain."""
    # Process-wide flags shared by all instances: whether the LangGraph
    # Postgres checkpoint tables were verified/created, and the connection
    # string used for the checkpointer (set in initialize()).
    _checkpointer_initialized = False
    _conn_string = None
async def chat(
self,
conversation_id: int,
message: str,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
use_agent: bool = False,
use_langgraph: bool = False,
use_knowledge_base: bool = False,
knowledge_base_id: Optional[int] = None
) -> ChatResponse:
"""Send a message and get AI response using LangChain, Agent, or Knowledge Base."""
if use_knowledge_base and knowledge_base_id:
if not self.knowledge_chat_service:
raise ChatServiceError("知识库功能需要安装: pip install langchain-chroma")
logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
return await self.knowledge_chat_service.chat_with_knowledge_base(
conversation_id=conversation_id,
message=message,
knowledge_base_id=knowledge_base_id,
stream=stream,
temperature=temperature,
max_tokens=max_tokens
)
elif use_langgraph:
logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
# Get conversation history for LangGraph agent
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = await self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
} for msg in messages]
# Use LangGraph agent service
agent_result = await self.langgraph_agent_service.chat(message, chat_history)
if agent_result["success"]:
# Save user message
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Save assistant response
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
message_metadata={"intermediate_steps": agent_result["intermediate_steps"]}
)
return ChatResponse(
message=MessageResponse(
id=assistant_message.id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
conversation_id=conversation_id,
created_at=assistant_message.created_at,
metadata=assistant_message.metadata
)
)
else:
raise ChatServiceError(f"LangGraph Agent error: {agent_result.get('error', 'Unknown error')}")
elif use_agent:
logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
# Get conversation history for agent
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = await self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
} for msg in messages]
# Use agent service
agent_result = await self.agent_service.chat(message, chat_history)
if agent_result["success"]:
# Save user message
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Save assistant response
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
message_metadata={"tool_calls": agent_result["tool_calls"]}
)
return ChatResponse(
message=MessageResponse(
id=assistant_message.id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
conversation_id=conversation_id,
created_at=assistant_message.created_at,
metadata=assistant_message.metadata
)
)
else:
raise ChatServiceError(f"Agent error: {agent_result.get('error', 'Unknown error')}")
else:
logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
# Delegate to LangChain service
return await self.langchain_chat_service.chat(
conversation_id=conversation_id,
message=message,
stream=stream,
temperature=temperature,
max_tokens=max_tokens
)
async def get_available_models(self) -> List[str]:
"""Get list of available models from LangChain."""
logger.info("Getting available models via LangChain")
# Delegate to LangChain service
return await self.langchain_chat_service.get_available_models()
def update_model_config(
self,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
):
"""Update LLM configuration via LangChain."""
logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
# Delegate to LangChain service
self.langchain_chat_service.update_model_config(
model=model,
temperature=temperature,
max_tokens=max_tokens
)
# -------------------------------------------------------------------------
def __init__(self, session: Session):
self.session = session
self.knowledge_chat_service = KnowledgeChatService(session) if KnowledgeChatService else None
async def initialize(self, conversation_id: int, streaming: bool = False):
self.conversation_service = ConversationService(self.session)
self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
self.conversation = await self.conversation_service.get_conversation(
conversation_id=conversation_id
)
if not self.conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
if not ChatService._checkpointer_initialized:
from langgraph.checkpoint.postgres import PostgresSaver
import psycopg2
import re
# LangGraph 使用的 Postgres 连接串psycopg 格式postgresql://
# 优先级LANGGRAPH_PG_URL > 从 DATABASE_URL 派生 > 默认 localhost
conn_string = os.getenv("LANGGRAPH_PG_URL")
if not conn_string:
db_url = os.getenv("DATABASE_URL", "")
if db_url and "postgresql" in db_url.split("://")[0].lower():
# 将 postgresql+asyncpg:// 转为 postgresql://,供 LangGraph/psycopg 使用
conn_string = re.sub(
r"^postgresql\+[a-zA-Z0-9]+://",
"postgresql://",
db_url,
count=1,
)
else:
conn_string = "postgresql://drgraph:yingping@localhost:5433/th_agenter"
ChatService._conn_string = conn_string
2026-01-21 13:45:39 +08:00
# 检查必要的表是否已存在
tables_need_setup = True
try:
# 连接到数据库并检查表是否存在
conn = psycopg2.connect(conn_string)
2026-01-21 13:45:39 +08:00
cursor = conn.cursor()
# 检查langgraph需要的表是否存在
cursor.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
""")
existing_tables = [row[0] for row in cursor.fetchall()]
# 检查是否所有必要的表都存在
required_tables = ['checkpoints', 'checkpoint_writes', 'checkpoint_blobs']
if all(table in existing_tables for table in required_tables):
tables_need_setup = False
self.session.desc = "ChatService初始化 - 检测到langgraph表已存在跳过setup"
cursor.close()
conn.close()
except Exception as e:
self.session.desc = f"ChatService初始化 - checkpoint失败: {str(e)}将进行setup"
tables_need_setup = True
# 只有在需要时才进行setup
if tables_need_setup:
self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
try:
async with AsyncPostgresSaver.from_conn_string(conn_string) as checkpointer:
2026-01-21 13:45:39 +08:00
await checkpointer.setup()
self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
logger.info("PostgresSaver setup完成")
except Exception as e:
self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
logger.error(f"PostgresSaver setup失败: {e}")
raise
else:
self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
# 存储连接字符串供后续使用
ChatService._checkpointer_initialized = True
self.llm = await new_llm(session=self.session, streaming=streaming)
self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"
def get_config(self):
config = {
"configurable": {
"thread_id": str(self.conversation.id),
"checkpoint_ns": "drgraph"
}
}
return config
    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base.

        Persists the user message first, streams JSON-encoded deltas of the
        assistant reply, then persists the full assistant message at the end.
        Yields strings of the form '{"data": {"v": "<delta>"}}'.
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        # Save the user's message before streaming starts.
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )
        full_assistant_content = ""
        # Fresh checkpointer connection per request; conversation memory is
        # keyed by thread_id through get_config().
        async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            from langgraph.prebuilt import create_react_agent
            from langchain_core.messages import HumanMessage
            # No tools registered ([]): the ReAct agent acts as a plain chat
            # model while still checkpointing conversation state.
            agent = create_react_agent(self.llm, [], checkpointer=checkpointer)
            async for chunk in agent.astream(
                {"messages": [HumanMessage(content=message)]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                # chunk[0] is the streamed message piece; fall back to str()
                # for items without a .content attribute.
                part = chunk[0].content if hasattr(chunk[0], "content") else str(chunk[0])
                full_assistant_content += part
                json_result = {"data": {"v": part}}
                yield json.dumps(
                    json_result,
                    ensure_ascii=True
                )
        # Persist the accumulated assistant reply once streaming completes;
        # nothing is stored if the model produced no content.
        if len(full_assistant_content) > 0:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )
def get_conversation_history_messages(
self, conversation_id: int, skip: int = 0, limit: int = 100
):
"""Get conversation history messages with pagination."""
result = []
with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
checkpoints = checkpointer.list(self.get_config())
for checkpoint in checkpoints:
print(checkpoint)
result.append(checkpoint.messages)
return result