hyf-backend/th_agenter/services/conversation.py

295 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Conversation service."""
from datetime import datetime, timezone
from typing import List, Optional, Union

from langchain_core.messages import HumanMessage, AIMessage
from loguru import logger
from sqlalchemy import select, desc, func, or_
from sqlalchemy.orm import Session

from th_agenter.db.database import AsyncSessionFactory
from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
from utils.util_schemas import ConversationCreate, ConversationUpdate
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
from ..core.context import UserContext
class ConversationService:
    """Service for managing conversations and messages.

    Conversation lookups (``get_conversation`` and every method built on it,
    plus the list/count queries) are restricted to the user returned by
    ``UserContext.get_current_user_id()``.
    """

    # Escape character used when embedding user search text in a LIKE pattern.
    _LIKE_ESCAPE = "/"

    def __init__(self, session: Session):
        """Bind the service to a database session.

        Args:
            session: Session used by every method except ``add_message``.
                Despite the ``Session`` annotation it is used as an async
                session (``commit``/``scalar``/``scalars``/``delete`` are
                awaited). Methods also write a human-readable ``desc``
                attribute on it.
        """
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _like_pattern(cls, term: str) -> str:
        """Return a ``%term%`` pattern with LIKE wildcards in *term* escaped.

        Without escaping, searching for ``50%`` or ``a_b`` would treat ``%`` and
        ``_`` as wildcards instead of literal characters.
        """
        esc = cls._LIKE_ESCAPE
        escaped = (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    @classmethod
    def _filter_user_conversations(
        cls,
        query,
        user_id: int,
        search_query: Optional[str],
        include_archived: bool,
    ):
        """Apply the owner, archive and search filters shared by list and count queries."""
        query = query.where(Conversation.user_id == user_id)
        if not include_archived:
            # ``== False`` builds a SQL expression; ``is False`` would not.
            query = query.where(Conversation.is_archived == False)  # noqa: E712
        term = search_query.strip() if search_query else ""
        if term:
            pattern = cls._like_pattern(term)
            query = query.where(
                or_(
                    Conversation.title.ilike(pattern, escape=cls._LIKE_ESCAPE),
                    Conversation.system_prompt.ilike(pattern, escape=cls._LIKE_ESCAPE),
                )
            )
        return query

    async def _commit_or_rollback(self) -> None:
        """Commit ``self.session``; on failure roll back and re-raise.

        Rolling back keeps the session usable for later calls; the original
        exception type is preserved for callers.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set ``is_archived`` on the current user's conversation.

        Returns:
            False if the conversation was not found for the current user,
            True after the change is committed.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False
        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create_conversation(
        self,
        user_id: int,
        conversation_data: ConversationCreate
    ) -> Conversation:
        """Create and persist a new conversation owned by *user_id*.

        Raises:
            DatabaseError: if building or committing the conversation fails;
                the session is rolled back first.
        """
        self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"
        try:
            conversation = Conversation(
                **conversation_data.model_dump(),
                user_id=user_id
            )
            conversation.set_audit_fields(user_id=user_id, is_update=False)
            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)
            self.session.desc = f"创建新会话 Conversation ID: {conversation.id}用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
            await self.session.rollback()
            raise DatabaseError(f"创建会话失败: {str(e)}") from e

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Return the conversation with *conversation_id* if the current user owns it.

        Returns:
            The conversation, or ``None`` when there is no user in
            ``UserContext`` or no matching conversation for that user.

        Raises:
            DatabaseError: if resolving the user or running the query fails.
        """
        # Bound before ``try`` so the error branch can always reference it,
        # even when ``get_current_user_id`` itself raises.
        user_id = None
        try:
            user_id = UserContext.get_current_user_id()
            self.session.desc = f"获取会话 - 会话ID: {conversation_id}用户ID: {user_id}"
            if user_id is None:
                logger.error(f"Failed to get conversation {conversation_id}: No user context available")
                return None
            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id
                )
            )
            if not conversation:
                self.session.desc = f"警告: 会话 {conversation_id} 不存在用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id}用户ID: {user_id},错误: {str(e)}"
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True
    ) -> List[Conversation]:
        """List the current user's conversations with optional search and ordering.

        Args:
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            search_query: Case-insensitive substring matched against ``title``
                and ``system_prompt``; blank/whitespace means no search.
            include_archived: Include archived conversations when True.
            order_by: Name of a mapped column on ``Conversation``; any other
                value falls back to ``updated_at``.
            order_desc: Sort descending when True.

        Returns:
            The matching conversations, or ``[]`` without a user context.
        """
        from sqlalchemy import inspect as sa_inspect

        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to get user conversations: No user context available")
            return []
        query = self._filter_user_conversations(
            select(Conversation), user_id, search_query, include_archived
        )
        # Only sort by real mapped columns; methods, relationships or typos
        # fall back to updated_at instead of producing an invalid ORDER BY.
        if order_by not in sa_inspect(Conversation).column_attrs:
            order_by = "updated_at"
        order_column = getattr(Conversation, order_by)
        query = query.order_by(desc(order_column) if order_desc else order_column)
        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self,
        conversation_id: int,
        conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the fields explicitly set in *conversation_update*.

        Returns:
            The refreshed conversation, or ``None`` if the current user has no
            conversation with that ID.

        Raises:
            DatabaseError: if the commit fails; the session is rolled back first.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None
        # ``model_dump`` is the Pydantic v2 API (also used by create_conversation);
        # ``exclude_unset`` leaves fields the caller did not send untouched.
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            # Values are passed as format arguments so braces in the error text
            # are never parsed as placeholders; ``exception`` attaches the traceback.
            logger.exception("Failed to update conversation {}: {}", conversation_id, e)
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete the current user's conversation.

        Returns:
            False if not found for the current user, True once deleted.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False
        await self.session.delete(conversation)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def get_conversation_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        NOTE(review): unlike the conversation lookups, this does not check that
        the conversation belongs to the current user; presumably callers verify
        ownership first — confirm.
        """
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(Message.created_at).offset(skip).limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None
    ) -> Message:
        """Persist a message and bump its conversation's ``updated_at``.

        The write goes through a dedicated session from ``AsyncSessionFactory``
        (not ``self.session``), and the message insert and timestamp update are
        committed together in that session.

        Failures are logged and rolled back rather than raised (best-effort),
        in which case the returned ``Message`` has not been saved.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )
        message.set_audit_fields()

        session = AsyncSessionFactory()
        try:
            session.add(message)
            # Load the conversation through the *same* session so the timestamp
            # change is included in this commit.
            conversation = await session.get(Conversation, conversation_id)
            if conversation is not None:
                conversation.updated_at = datetime.now(timezone.utc)
                conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await session.commit()
            await session.refresh(message)
        except Exception as e:
            logger.exception("Failed to add message to conversation {}: {}", conversation_id, e)
            await session.rollback()
        finally:
            await session.close()
        return message

    async def get_conversation_history_messages(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Union[HumanMessage, AIMessage]]:
        """Return recent history as LangChain messages, oldest first.

        USER messages become ``HumanMessage`` and ASSISTANT messages become
        ``AIMessage``; messages with any other role are skipped.
        """
        history = await self.get_conversation_history(conversation_id, limit)
        history_messages = []
        for message in history:
            if message.role == MessageRole.USER:
                history_messages.append(HumanMessage(content=message.content))
            elif message.role == MessageRole.ASSISTANT:
                history_messages.append(AIMessage(content=message.content))
        return history_messages

    async def get_conversation_history(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Message]:
        """Return the *limit* most recent messages in chronological order."""
        # Query newest-first so LIMIT keeps the latest rows, then reverse.
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(desc(Message.created_at)).limit(limit)
        )).all()[::-1]

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set ``updated_at`` to now (UTC) on the current user's conversation, if found."""
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit_or_rollback()

    async def get_user_conversations_count(
        self,
        search_query: Optional[str] = None,
        include_archived: bool = False
    ) -> int:
        """Count the current user's conversations using the same filters as the list query.

        Returns:
            The count, or 0 without a user context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            # Without this guard the query would count rows with NULL user_id.
            logger.error("Failed to count user conversations: No user context available")
            return 0
        query = self._filter_user_conversations(
            select(func.count(Conversation.id)), user_id, search_query, include_archived
        )
        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive the current user's conversation; False if not found."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive the current user's conversation; False if not found."""
        return await self._set_archived(conversation_id, False)