hyf-backend/th_agenter/services/langchain_chat.py

"""LangChain-based chat service."""
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any
from sqlalchemy.orm import Session
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, OpenAIError, AuthenticationError, RateLimitError
from loguru import logger
from .conversation import ConversationService
class StreamingCallbackHandler(BaseCallbackHandler):
"""Custom callback handler for streaming responses."""
def __init__(self):
self.tokens = []
def on_llm_new_token(self, token: str, **kwargs) -> None:
"""Handle new token from LLM."""
self.tokens.append(token)
def get_response(self) -> str:
"""Get the complete response."""
return "".join(self.tokens)
def clear(self):
"""Clear the tokens."""
self.tokens = []
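# Illustrative sketch (not wired up elsewhere in this module): the handler is meant to be
# attached through the standard `callbacks` argument so tokens accumulate as they arrive.
# The model name below is an assumption chosen only for the example.
#
#     handler = StreamingCallbackHandler()
#     llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, callbacks=[handler])
#     await llm.ainvoke([HumanMessage(content="Hello")])
#     print(handler.get_response())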
class LangChainChatService:
"""LangChain-based chat service for AI model integration."""
def __init__(self, session: Session):
self.session = session
self.conversation_service = ConversationService(session)
async def initialize(self):
from ..core.new_agent import new_agent
# Initialize LangChain ChatOpenAI
self.llm = await new_agent(self.session, streaming=False)
self.session.desc = "LangChainChatService初始化 - llm 实例化完毕"
# Streaming LLM for stream responses
self.streaming_llm = await new_agent(self.session, streaming=True)
self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕"
self.streaming_handler = StreamingCallbackHandler()
self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕"
def _prepare_langchain_messages(self, conversation, history: List) -> List:
"""Prepare messages for LangChain format."""
messages = []
# Add system message if conversation has system prompt
if hasattr(conversation, 'system_prompt') and conversation.system_prompt:
messages.append(SystemMessage(content=conversation.system_prompt))
else:
# Default system message
messages.append(SystemMessage(
content="You are a helpful AI assistant. Please provide accurate and helpful responses."
))
# Add conversation history
for msg in history[:-1]: # Exclude the last message (current user message)
if msg.role == MessageRole.USER:
messages.append(HumanMessage(content=msg.content))
elif msg.role == MessageRole.ASSISTANT:
messages.append(AIMessage(content=msg.content))
# Add current user message
if history:
last_msg = history[-1]
if last_msg.role == MessageRole.USER:
messages.append(HumanMessage(content=last_msg.content))
return messages
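# Shape of the returned list, sketched for a conversation with one prior exchange
# (the contents are placeholders, not real data):
#
#     [SystemMessage(content="You are a helpful AI assistant. ..."),
#      HumanMessage(content="earlier question"),
#      AIMessage(content="earlier answer"),
#      HumanMessage(content="current question")]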
async def chat(
self,
conversation_id: int,
message: str,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> ChatResponse:
"""Send a message and get AI response using LangChain."""
logger.info(f"Processing LangChain chat request for conversation {conversation_id}")
try:
# Get conversation details
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError("Conversation not found")
# Add user message to database
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history for context
history = await self.conversation_service.get_conversation_history(
conversation_id, limit=20
)
# Prepare messages for LangChain
langchain_messages = self._prepare_langchain_messages(conversation, history)
# Update LLM parameters if provided
llm_to_use = self.llm
if temperature is not None or max_tokens is not None:
llm_config = await settings.llm.get_current_config()
llm_to_use = ChatOpenAI(
model=llm_config["model"],
openai_api_key=llm_config["api_key"],
openai_api_base=llm_config["base_url"],
temperature=temperature if temperature is not None else float(conversation.temperature),
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
streaming=False
)
# Call LangChain LLM
response = await llm_to_use.ainvoke(langchain_messages)
# Extract response content
assistant_content = response.content
# Add assistant message to database
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=assistant_content,
role=MessageRole.ASSISTANT,
message_metadata={
"model": llm_to_use.model_name,
"langchain_version": "0.1.0",
"provider": "langchain_openai"
}
)
# Update conversation timestamp
await self.conversation_service.update_conversation_timestamp(conversation_id)
logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
return ChatResponse(
user_message=MessageResponse.from_orm(user_message),
assistant_message=MessageResponse.from_orm(assistant_message),
total_tokens=None,  # token usage is not extracted from the LangChain response here
model_used=llm_to_use.model_name
)
except ChatServiceError:
# Re-raise before the fallback below: when the conversation does not exist,
# no user message has been stored and no error reply can be attached to it.
raise
except Exception as e:
logger.error(f"Failed to process LangChain chat request for conversation {conversation_id}: {str(e)}", exc_info=True)
# Classify error types for better handling
error_type = type(e).__name__
error_message = self._format_error_message(e)
# Add error message to database
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT,
message_metadata={
"error": True,
"error_type": error_type,
"original_error": str(e),
"langchain_error": True
}
)
# Re-raise specific exceptions for proper error handling
if "rate limit" in str(e).lower():
raise RateLimitError(str(e))
elif "api key" in str(e).lower() or "authentication" in str(e).lower():
raise AuthenticationError(str(e))
elif "openai" in str(e).lower():
raise OpenAIError(str(e))
return ChatResponse(
user_message=MessageResponse.from_orm(user_message),
assistant_message=MessageResponse.from_orm(assistant_message),
total_tokens=0,
model_used=self.llm.model_name
)
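# Illustrative usage sketch for the non-streaming path; the database session variable is an
# assumption, and initialize() must run first so the LLM instances exist:
#
#     service = LangChainChatService(db_session)
#     await service.initialize()
#     resp = await service.chat(conversation_id=1, message="Hello")
#     print(resp.assistant_message.content, resp.model_used)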
async def chat_stream(
self,
conversation_id: int,
message: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> AsyncGenerator[str, None]:
"""Send a message and get streaming AI response using LangChain."""
logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}")
try:
# Get conversation details
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError("Conversation not found")
# Add user message to database
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history for context
history = await self.conversation_service.get_conversation_history(
conversation_id, limit=20
)
# Prepare messages for LangChain
langchain_messages = self._prepare_langchain_messages(conversation, history)
# Update streaming LLM parameters if provided
streaming_llm_to_use = self.streaming_llm
if temperature is not None or max_tokens is not None:
llm_config = await settings.llm.get_current_config()
streaming_llm_to_use = ChatOpenAI(
model=llm_config["model"],
openai_api_key=llm_config["api_key"],
openai_api_base=llm_config["base_url"],
temperature=temperature if temperature is not None else float(conversation.temperature),
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
streaming=True
)
# Clear previous streaming handler state
self.streaming_handler.clear()
# Stream response
full_response = ""
try:
# Use the public astream() API; it yields AIMessageChunk objects that expose .content
async for chunk in streaming_llm_to_use.astream(langchain_messages):
# Handle different chunk types to avoid KeyError
chunk_content = None
if hasattr(chunk, 'content'):
# For object-like chunks with content attribute
chunk_content = chunk.content
elif isinstance(chunk, dict) and 'content' in chunk:
# For dict-like chunks with content key
chunk_content = chunk['content']
elif isinstance(chunk, dict) and 'error' in chunk:
# Handle error chunks explicitly
logger.error(f"Error in LLM response: {chunk['error']}")
yield self._format_error_message(Exception(chunk['error']))
continue
if chunk_content:
full_response += chunk_content
yield chunk_content
except Exception as e:
logger.error(f"Error in LLM streaming: {e}")
yield f"{self._format_error_message(e)} >>> {e}"
# Add complete assistant message to database
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
message_metadata={
"model": streaming_llm_to_use.model_name,
"langchain_version": "0.1.0",
"provider": "langchain_openai",
"streaming": True
}
)
# Update conversation timestamp
await self.conversation_service.update_conversation_timestamp(conversation_id)
logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}")
except ChatServiceError:
# A missing conversation cannot accept an error reply, so propagate directly.
raise
except Exception as e:
# Format the exception info safely to avoid raising another KeyError
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
logger.error(error_info, exc_info=True)
# Format error message for user
error_message = self._format_error_message(e)
yield error_message
# Add error message to database
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT,
message_metadata={
"error": True,
"error_type": type(e).__name__,
"original_error": str(e),
"langchain_error": True,
"streaming": True
}
)
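# Illustrative sketch of consuming the generator, e.g. to relay chunks to a client; the
# surrounding web framework is an assumption and not part of this module:
#
#     async for chunk in service.chat_stream(conversation_id=1, message="Hello"):
#         print(chunk, end="", flush=True)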
async def get_available_models(self) -> List[str]:
"""Get list of available models from LangChain."""
try:
# LangChain doesn't have a direct method to list models
# Return commonly available OpenAI models
return [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-turbo-preview",
"gpt-4o",
"gpt-4o-mini"
]
except Exception as e:
logger.error(f"Failed to get available models: {str(e)}")
return ["gpt-3.5-turbo"]
async def update_model_config(
self,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
):
"""Update LLM configuration."""
from ..core.new_agent import new_agent
# Recreate the LLM instances, passing the session as in initialize()
# (note: max_tokens is accepted for interface compatibility but is not forwarded here)
self.llm = await new_agent(
self.session,
model=model,
temperature=temperature,
streaming=False
)
self.streaming_llm = await new_agent(
self.session,
model=model,
temperature=temperature,
streaming=True
)
logger.info(f"Updated LLM configuration: model={model}, temperature={temperature}, max_tokens={max_tokens}")
def _format_error_message(self, error: Exception) -> str:
"""Format error message for user display."""
error_type = type(error).__name__
error_str = str(error)
# Provide user-friendly error messages
if "rate limit" in error_str.lower():
return "服务器繁忙,请稍后再试。"
elif "api key" in error_str.lower() or "authentication" in error_str.lower():
return f"API认证失败请检查配置文件。"
elif "timeout" in error_str.lower():
return "请求超时,请重试。"
elif "connection" in error_str.lower():
return "网络连接错误,请检查网络连接。"
elif "model" in error_str.lower() and "not found" in error_str.lower():
return "指定的模型不可用,请选择其他模型。"
else:
return f"处理请求时发生错误:{error_str}"
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
"""Retry function with exponential backoff."""
for attempt in range(max_retries):
try:
return await func()
except Exception as e:
if attempt == max_retries - 1:
raise e
# Check if error is retryable
if not self._is_retryable_error(e):
raise e
delay = base_delay * (2 ** attempt)
logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
await asyncio.sleep(delay)
def _is_retryable_error(self, error: Exception) -> bool:
"""Check if an error is retryable."""
error_str = str(error).lower()
retryable_errors = [
"timeout",
"connection",
"server error",
"internal error",
"rate limit"
]
return any(err in error_str for err in retryable_errors)
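# Illustrative sketch of wrapping an LLM call with the retry helper; `_call` is a
# hypothetical zero-argument coroutine defined only for this example:
#
#     async def _call():
#         return await self.llm.ainvoke(langchain_messages)
#     response = await self._retry_with_backoff(_call, max_retries=3, base_delay=1.0)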