"""Conversation service."""

from datetime import datetime, timezone
from typing import List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from loguru import logger
from sqlalchemy import desc, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from th_agenter.db.database import AsyncSessionFactory
from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
from utils.util_schemas import ConversationCreate, ConversationUpdate
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
from ..core.context import UserContext


class ConversationService:
    """Service for managing conversations and messages.

    Conversation lookups are scoped to the user returned by
    ``UserContext.get_current_user_id()``. A missing user context yields
    ``None`` / an empty result / ``0`` rather than an exception.
    """

    def __init__(self, session: AsyncSession):
        # Every database call on this session is awaited, so an async session
        # is expected here.
        self.session = session

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    async def _commit_or_rollback(self) -> None:
        """Commit ``self.session``; on failure roll back and re-raise the error."""
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    @staticmethod
    def _apply_list_filters(
        query,
        user_id: int,
        search_query: Optional[str],
        include_archived: bool,
    ):
        """Restrict ``query`` to the user's conversations, archive state and search text.

        Shared by the list and count queries so both always agree.
        """
        query = query.where(Conversation.user_id == user_id)

        if not include_archived:
            # ``== False`` builds a SQL expression; ``not`` / ``is`` would not.
            query = query.where(Conversation.is_archived == False)  # noqa: E712

        if search_query and search_query.strip():
            search_term = f"%{search_query.strip()}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term),
                    Conversation.system_prompt.ilike(search_term),
                )
            )
        return query

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set ``is_archived`` on an accessible conversation; False if not found."""
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------ #
    # Conversations
    # ------------------------------------------------------------------ #

    async def create_conversation(
        self, user_id: int, conversation_data: ConversationCreate
    ) -> Conversation:
        """Create and persist a new conversation owned by ``user_id``.

        Args:
            user_id: Owner of the new conversation (also used for audit fields).
            conversation_data: Field values for the conversation.

        Returns:
            The committed and refreshed ``Conversation``.

        Raises:
            DatabaseError: If creation fails; the session is rolled back first.
        """
        # ``session.desc`` is a free-form status string set on the session;
        # presumably consumed by session/logging middleware elsewhere — TODO confirm.
        self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"
        try:
            conversation = Conversation(
                **conversation_data.model_dump(), user_id=user_id
            )
            conversation.set_audit_fields(user_id=user_id, is_update=False)

            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)
            self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
            await self.session.rollback()
            raise DatabaseError(f"创建会话失败: {str(e)}") from e

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Return the conversation if it belongs to the current user.

        Returns:
            The ``Conversation``, or ``None`` when there is no user context or
            no matching conversation for this user.

        Raises:
            DatabaseError: If the lookup fails.
        """
        # Bound before ``try`` so the error handler can always format it, even
        # if reading the user context itself raises.
        user_id = None
        try:
            user_id = UserContext.get_current_user_id()
            self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}"
            if user_id is None:
                logger.error(f"Failed to get conversation {conversation_id}: No user context available")
                return None

            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id,
                )
            )
            if not conversation:
                self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}"
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True,
    ) -> List[Conversation]:
        """List the current user's conversations with search, filtering and paging.

        Args:
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            search_query: Case-insensitive substring matched against title and
                system prompt; blank/whitespace-only means "no search".
            include_archived: Whether archived conversations are included.
            order_by: Name of a ``Conversation`` attribute to sort by; unknown
                names fall back to ``updated_at``.
            order_desc: Sort descending when True.

        Returns:
            The matching conversations, or ``[]`` when there is no user context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to get user conversations: No user context available")
            return []

        query = self._apply_list_filters(
            select(Conversation), user_id, search_query, include_archived
        )

        # NOTE(review): any attribute name is accepted here; a non-column
        # attribute (e.g. a method) would fail at query time. Consider an
        # explicit allow-list if ``order_by`` comes from client input.
        order_column = getattr(Conversation, order_by, Conversation.updated_at)
        query = query.order_by(desc(order_column) if order_desc else order_column)

        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def get_user_conversations_count(
        self, search_query: Optional[str] = None, include_archived: bool = False
    ) -> int:
        """Count the current user's conversations using the same filters as the list query.

        Returns:
            The count, or ``0`` when there is no user context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            # Mirror ``get_user_conversations``; without this guard the query
            # would count conversations whose ``user_id`` IS NULL.
            logger.error("Failed to get user conversations: No user context available")
            return 0

        query = self._apply_list_filters(
            select(func.count(Conversation.id)), user_id, search_query, include_archived
        )
        return (await self.session.scalar(query)) or 0

    async def update_conversation(
        self, conversation_id: int, conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the explicitly-set fields of ``conversation_update``.

        Returns:
            The refreshed conversation, or ``None`` if it is not accessible.

        Raises:
            DatabaseError: If the commit fails; the session is rolled back first.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None

        # Only fields the client actually sent are applied (partial update).
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)

        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            # loguru has no ``exc_info``; ``exception`` logs the traceback.
            logger.exception(f"Failed to update conversation {conversation_id}: {str(e)}")
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete an accessible conversation; return False if it was not found."""
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        await self.session.delete(conversation)
        await self._commit_or_rollback()
        return True

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set the conversation's ``updated_at`` to now (UTC), if accessible."""
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit_or_rollback()

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive a conversation; return False if it was not found."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive a conversation; return False if it was not found."""
        return await self._set_archived(conversation_id, False)

    # ------------------------------------------------------------------ #
    # Messages
    # ------------------------------------------------------------------ #

    async def get_conversation_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        NOTE(review): unlike ``get_conversation``, this does not check that the
        conversation belongs to the current user — confirm callers do.
        """
        return (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
            .offset(skip)
            .limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> Message:
        """Persist a message and bump the owning conversation's ``updated_at``.

        The write uses its own short-lived session. The message insert and the
        timestamp bump are committed together. Failures are logged and rolled
        back but not raised, so the returned ``Message`` may be unsaved in that
        case.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        message.set_audit_fields()

        # ``async with`` guarantees the session is closed; it autobegins its
        # transaction on first use, so no explicit ``begin()`` is needed.
        async with AsyncSessionFactory() as session:
            try:
                session.add(message)

                # Load the conversation through *this* session so the timestamp
                # change is part of the commit below. A conversation loaded via
                # ``self.session`` would not be flushed by ``session.commit()``.
                user_id = UserContext.get_current_user_id()
                if user_id is not None:
                    conversation = await session.scalar(
                        select(Conversation).where(
                            Conversation.id == conversation_id,
                            Conversation.user_id == user_id,
                        )
                    )
                    if conversation:
                        conversation.updated_at = datetime.now(timezone.utc)
                        conversation.set_audit_fields(
                            user_id=conversation.user_id, is_update=True
                        )

                await session.commit()
                # Refresh last so the message's attributes are loaded before the
                # session closes and the instance becomes detached.
                await session.refresh(message)
            except Exception as e:
                logger.exception(f"Failed to add message to conversation {conversation_id}: {str(e)}")
                await session.rollback()
        return message

    async def get_conversation_history(
        self, conversation_id: int, limit: int = 20
    ) -> List[Message]:
        """Return the latest ``limit`` messages of a conversation, oldest first."""
        newest_first = (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(desc(Message.created_at))
            .limit(limit)
        )).all()
        # Query newest-first so LIMIT keeps the most recent messages, then
        # reverse into chronological order.
        return newest_first[::-1]

    async def get_conversation_history_messages(
        self, conversation_id: int, limit: int = 20
    ) -> List[BaseMessage]:
        """Return recent history as langchain messages for use as LLM context.

        USER messages become ``HumanMessage`` and ASSISTANT messages become
        ``AIMessage``; messages with any other role are skipped.
        """
        history = await self.get_conversation_history(conversation_id, limit)
        history_messages: List[BaseMessage] = []
        for message in history:
            if message.role == MessageRole.USER:
                history_messages.append(HumanMessage(content=message.content))
            elif message.role == MessageRole.ASSISTANT:
                history_messages.append(AIMessage(content=message.content))
        return history_messages