"""Pydantic schemas for API requests and responses.

NOTE: the class docstrings and ``Field(description=...)`` texts in this module
are intentionally kept verbatim (in Chinese): pydantic publishes them as
JSON-schema descriptions, so they are part of the generated API documentation.
The English ``#`` comments next to them are for maintainers only.
"""

import inspect
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, model_validator

if TYPE_CHECKING:
    from th_agenter.schemas.permission import RoleResponse


# Author role of a chat message.
class MessageRole(str, Enum):
    """消息角色枚举"""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


# Payload kind of a chat message.
class MessageType(str, Enum):
    """消息类型枚举"""
    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    AUDIO = "audio"


# Base schemas
# Fields shared by the *Response models below.
class BaseResponse(BaseModel):
    """基础响应模型"""
    id: int
    created_at: datetime
    updated_at: datetime

    # Allow building responses directly from ORM objects (attribute access).
    model_config = ConfigDict(from_attributes=True)


# User schemas
# Fields common to user creation and user responses.
class UserBase(BaseModel):
    """用户基础模型"""
    username: str = Field(..., min_length=3, max_length=50)
    email: str = Field(..., max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None


# Payload for creating a user (adds the plain-text password).
class UserCreate(UserBase):
    """用户创建模型"""
    password: str = Field(..., min_length=6)


# Partial update payload: every field is optional.
class UserUpdate(BaseModel):
    """用户更新模型"""
    username: Optional[str] = Field(None, min_length=3, max_length=50)
    email: Optional[str] = Field(None, max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    password: Optional[str] = Field(None, min_length=6)
    is_active: Optional[bool] = None
    department_id: Optional[int] = None


# --- Helpers used by UserResponse.from_orm / UserResponse.from_orm_async --------

# Sentinel distinguishing "attribute absent" from an attribute whose value is None.
_MISSING = object()

# Role code that grants every active permission in UserResponse.from_orm_async.
_SUPER_ADMIN_ROLE = 'SUPER_ADMIN'


def _user_scalar_fields(obj: Any) -> Dict[str, Any]:
    """Copy the plain column values of an ORM user into a dict for UserResponse.

    Raises:
        AttributeError: if ``obj`` lacks one of the required columns.
    """
    data = {
        'id': obj.id,
        'username': obj.username,
        'email': obj.email,
        'full_name': obj.full_name,
        'is_active': obj.is_active,
        'department_id': obj.department_id,
        'created_at': obj.created_at,
        'updated_at': obj.updated_at,
    }
    # bio/avatar_url are declared on UserBase and must be forwarded, otherwise the
    # response always reports None for them. created_by/updated_by are passed
    # through as before; UserResponse itself does not declare them, so pydantic's
    # default extra='ignore' discards them. getattr() keeps ORM models without
    # these columns working instead of raising AttributeError.
    for name in ('bio', 'avatar_url', 'created_by', 'updated_by'):
        data[name] = getattr(obj, name, None)
    return data


def _permission_dicts(pairs) -> List[Dict[str, Any]]:
    """Render ``(code, name)`` pairs as dicts, sorted so the output order is stable."""
    ordered = sorted(pairs, key=lambda pair: (str(pair[0]), str(pair[1])))
    return [{'code': code, 'name': name} for code, name in ordered]


def _role_responses(roles) -> List[Any]:
    """Convert the active roles in ``roles`` into RoleResponse objects."""
    # Imported lazily, matching the TYPE_CHECKING-only import at module top.
    from th_agenter.schemas.permission import RoleResponse
    return [RoleResponse.from_orm(role) for role in roles if role.is_active]


def _superuser_member(obj: Any) -> Any:
    """Return ``obj.is_admin`` or else ``obj.is_superuser``; ``_MISSING`` if neither exists."""
    for name in ('is_admin', 'is_superuser'):
        if hasattr(obj, name):
            return getattr(obj, name)
    return _MISSING


def _sync_roles(obj: Any) -> List[Any]:
    """RoleResponse list for the active roles reachable on ``obj``; ``[]`` on any failure."""
    try:
        if not hasattr(obj, 'roles'):
            return []
        return _role_responses(obj.roles)
    except Exception:
        # e.g. DetachedInstanceError or a failed lazy load.
        return []


def _sync_permissions(obj: Any) -> List[Dict[str, Any]]:
    """Collect the active permissions of the active roles reachable on ``obj``.

    Best effort: a role whose permissions cannot be read is skipped, and if the
    ``roles`` relationship fails mid-iteration the permissions gathered so far
    are kept.
    """
    try:
        has_roles = hasattr(obj, 'roles')
    except Exception:
        return []
    if not has_roles:
        return []

    found = set()
    try:
        for role in obj.roles:
            if not role.is_active:
                continue
            try:
                for perm in role.permissions:
                    if perm.is_active:
                        found.add((perm.code, perm.name))
            except Exception:
                # This role's permissions are unreadable; skip the role.
                continue
    except Exception:
        # The roles relationship itself failed; keep what was collected so far.
        pass
    return _permission_dicts(found)


def _sync_superuser_flag(obj: Any) -> Any:
    """Value for ``is_superuser`` from ``is_admin``/``is_superuser``; ``False`` on failure.

    A callable attribute is called; a coroutine function cannot be awaited here
    and therefore yields ``False``.
    """
    try:
        member = _superuser_member(obj)
        if member is _MISSING:
            return False
        if not callable(member):
            return member
        if inspect.iscoroutinefunction(member):
            return False
        return member()
    except Exception:
        return False


async def _async_superuser_flag(obj: Any) -> Any:
    """Async counterpart of ``_sync_superuser_flag`` that awaits coroutine methods."""
    try:
        member = _superuser_member(obj)
        if member is _MISSING:
            return False
        if not callable(member):
            return member
        if inspect.iscoroutinefunction(member):
            return await member()
        return member()
    except Exception:
        return False


def _attached_sessions(obj: Any):
    """Return ``(async_session, sync_session)`` owning ``obj``; either may be None.

    ``sqlalchemy.orm.object_session`` only ever returns the synchronous
    ``Session`` (for objects loaded through an ``AsyncSession`` it is the proxied
    sync session), so an ``isinstance(..., AsyncSession)`` check on its result
    can never succeed. ``async_object_session`` recovers the owning AsyncSession.
    """
    try:
        from sqlalchemy.orm import object_session
        sync_session = object_session(obj)
    except Exception:
        # Not a mapped instance (or SQLAlchemy unavailable): no session to use.
        return None, None
    try:
        from sqlalchemy.ext.asyncio import async_object_session
        async_session = async_object_session(obj)
    except Exception:
        async_session = None
    return async_session, sync_session


async def _async_load_roles(obj: Any, async_session: Any) -> List[Any]:
    """Return ``obj.roles`` as a list, refreshing it through ``async_session`` first if given."""
    try:
        if async_session is not None:
            # Implicit lazy loading is not available under asyncio, so load the
            # relationship explicitly before touching the attribute.
            await async_session.refresh(obj, ['roles'])
        roles = obj.roles
        return list(roles) if roles is not None else []
    except Exception:
        return []


async def _async_role_permissions(role: Any, async_session: Any) -> List[Any]:
    """Return ``role.permissions`` as a list, refreshing through ``async_session`` if given."""
    try:
        if async_session is not None:
            await async_session.refresh(role, ['permissions'])
        perms = role.permissions
        return list(perms) if perms is not None else []
    except Exception:
        return []


async def _async_is_super_admin(obj: Any) -> bool:
    """True when ``obj.has_role(_SUPER_ADMIN_ROLE)`` (sync or async) is truthy."""
    if not hasattr(obj, 'has_role'):
        return False
    has_role = obj.has_role
    if not callable(has_role):
        return False
    if inspect.iscoroutinefunction(has_role):
        return bool(await has_role(_SUPER_ADMIN_ROLE))
    return bool(has_role(_SUPER_ADMIN_ROLE))


async def _async_all_permissions(async_session: Any, sync_session: Any) -> List[Dict[str, Any]]:
    """List every active Permission row; without any session return the ``'*'`` wildcard."""
    if async_session is None and sync_session is None:
        return [{'code': '*', 'name': '所有权限'}]
    from th_agenter.models.permission import Permission
    if async_session is not None:
        from sqlalchemy import select
        result = await async_session.execute(
            select(Permission).filter(Permission.is_active == True)  # noqa: E712 (SQL expression)
        )
        permissions = result.scalars().all()
    else:
        permissions = (
            sync_session.query(Permission)
            .filter(Permission.is_active == True)  # noqa: E712 (SQL expression)
            .all()
        )
    return [{'code': perm.code, 'name': perm.name} for perm in permissions]


async def _async_permissions(obj: Any, roles: List[Any], async_session: Any,
                             sync_session: Any) -> List[Dict[str, Any]]:
    """Permissions for ``from_orm_async``: all active ones for super admins, else from roles.

    Any failure yields ``[]``.
    """
    try:
        if await _async_is_super_admin(obj):
            return await _async_all_permissions(async_session, sync_session)
        found = set()
        for role in roles:
            if not role.is_active:
                continue
            for perm in await _async_role_permissions(role, async_session):
                if perm.is_active:
                    found.add((perm.code, perm.name))
        return _permission_dicts(found)
    except Exception:
        return []


# Public representation of a user, including resolved roles and permissions.
class UserResponse(BaseResponse, UserBase):
    """用户响应模型"""
    is_active: bool
    department_id: Optional[int] = None
    # pydantic copies mutable defaults per instance, so the [] literals are safe.
    roles: Optional[List['RoleResponse']] = Field(default=[], description="用户角色列表")
    permissions: Optional[List[Dict[str, Any]]] = Field(default=[], description="用户权限列表")
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    @classmethod
    def from_orm(cls, obj):
        """Build a UserResponse from an ORM user object (synchronous variant).

        Relationships are read as they currently are on ``obj`` (no explicit
        refresh). Relationship and flag lookups are best effort: on failure
        ``roles``/``permissions`` fall back to ``[]`` and ``is_superuser`` to
        ``False``. Only active roles, and active permissions of active roles,
        are included.

        Args:
            obj: ORM user instance exposing ``id``, ``username``, ``email``,
                ``full_name``, ``is_active``, ``department_id``, ``created_at``
                and ``updated_at``.

        Returns:
            The populated response model.

        Raises:
            AttributeError: if one of the required columns above is missing.
            pydantic.ValidationError: if the collected values fail validation.
        """
        data = _user_scalar_fields(obj)
        data['roles'] = _sync_roles(obj)
        data['permissions'] = _sync_permissions(obj)
        data['is_superuser'] = _sync_superuser_flag(obj)
        return cls(**data)

    @classmethod
    async def from_orm_async(cls, obj):
        """Build a UserResponse from an ORM user object (asynchronous variant).

        When ``obj`` belongs to an ``AsyncSession``, its ``roles`` and each
        active role's ``permissions`` are refreshed through that session before
        being read. Users for whom ``has_role('SUPER_ADMIN')`` is truthy get
        every active ``Permission`` row (or a ``'*'`` wildcard entry when no
        session is attached). Failures degrade to ``[]`` / ``False`` exactly as
        in :meth:`from_orm`, and only active roles are listed, matching it.

        Args:
            obj: ORM user instance (same required columns as :meth:`from_orm`).

        Returns:
            The populated response model.

        Raises:
            AttributeError: if a required column is missing.
            pydantic.ValidationError: if the collected values fail validation.
        """
        data = _user_scalar_fields(obj)
        async_session, sync_session = _attached_sessions(obj)

        roles = await _async_load_roles(obj, async_session)
        try:
            data['roles'] = _role_responses(roles)
        except Exception:
            data['roles'] = []

        data['permissions'] = await _async_permissions(obj, roles, async_session, sync_session)
        data['is_superuser'] = await _async_superuser_flag(obj)
        return cls(**data)


# Authentication schemas
# Login payload; unknown extra keys sent by the frontend are ignored.
class LoginRequest(BaseModel):
    """登录请求模型,兼容前端多余字段(如 selectAccount、captcha、username)"""
    email: str = Field(..., max_length=100)
    password: str = Field(..., min_length=6)

    model_config = ConfigDict(extra="ignore")


# Access-token response.
class Token(BaseModel):
    """访问令牌响应模型"""
    access_token: str
    token_type: str
    expires_in: int


# Conversation schemas
# Fields common to conversation creation and responses.
class ConversationBase(BaseModel):
    """对话基础模型"""
    title: str = Field(..., min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: str = Field(default="gpt-3.5-turbo", max_length=100)
    temperature: str = Field(default="0.7", max_length=10)
    max_tokens: int = Field(default=2048, ge=1, le=8192)
    knowledge_base_id: Optional[int] = None


# Payload for creating a conversation.
class ConversationCreate(ConversationBase):
    """对话创建模型"""
    pass


# Partial update payload for a conversation.
class ConversationUpdate(BaseModel):
    """对话更新模型"""
    title: Optional[str] = Field(None, min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: Optional[str] = Field(None, max_length=100)
    temperature: Optional[str] = Field(None, max_length=10)
    max_tokens: Optional[int] = Field(None, ge=1, le=8192)
    is_archived: Optional[bool] = None


# Conversation as returned by the API.
class ConversationResponse(BaseResponse, ConversationBase):
    """对话响应模型"""
    user_id: int
    is_archived: bool
    message_count: int = 0
    last_message_at: Optional[datetime] = None
    # Forward reference: MessageResponse is defined below (resolved in rebuild_models()).
    messages: Optional[List["MessageResponse"]] = None


# Message schemas
# Fields common to message creation and responses.
class MessageBase(BaseModel):
    """消息基础模型"""
    content: str = Field(..., min_length=1)
    role: MessageRole
    message_type: MessageType = MessageType.TEXT
    # Populated from the ``message_metadata`` key (alias); MessageResponse also
    # accepts ``metadata`` because it enables populate_by_name.
    metadata: Optional[Dict[str, Any]] = Field(None, alias="message_metadata")


# Payload for creating a message in a conversation.
class MessageCreate(MessageBase):
    """消息创建模型"""
    conversation_id: int


# Message as returned by the API, with retrieval context and token usage.
class MessageResponse(BaseResponse, MessageBase):
    """消息响应模型"""
    conversation_id: int
    context_documents: Optional[List[Dict[str, Any]]] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None

    model_config = ConfigDict(from_attributes=True, populate_by_name=True)


# Chat schemas
# Chat request with optional knowledge-base, agent and LangGraph modes.
class ChatRequest(BaseModel):
    """聊天请求模型"""
    message: str = Field(..., min_length=1, max_length=10000)
    stream: bool = Field(default=False)
    use_knowledge_base: bool = Field(default=False)
    knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
    use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
    use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=2048, ge=1, le=8192)


# Result of a non-streaming chat turn.
class ChatResponse(BaseModel):
    """聊天响应模型"""
    user_message: MessageResponse
    assistant_message: MessageResponse
    total_tokens: Optional[int] = None
    model_used: str


# agentChat request: model config id, prompt and optional knowledge bases.
class AgentChatRequest(BaseModel):
    """agentChat 请求:AI大模型、提示词、关联知识库"""
    model_id: int = Field(..., ge=1, description="AI大模型配置ID")
    prompt: Optional[str] = Field(default=None, max_length=20000, description="提示词,与 message 二选一")
    message: Optional[str] = Field(default=None, max_length=20000, description="提示词(与 prompt 等价,二选一)")
    knowledge_base_id: Optional[int] = Field(default=None, ge=1, description="关联知识库ID(单个),与 knowledge_base_ids 二选一")
    knowledge_base_ids: Optional[List[Union[int, str]]] = Field(default=None, description="关联知识库ID列表,如 [1, 2] 或 ['3']")
    top_k: int = Field(default=5, ge=1, le=20, description="知识库检索返回条数")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=32768)

    @model_validator(mode="after")
    def require_prompt_or_message(self):
        """Reject requests where both ``prompt`` and ``message`` are empty or whitespace."""
        if not ((self.prompt or "").strip() or (self.message or "").strip()):
            raise ValueError("prompt 或 message 至少提供一个")
        return self


# agentChat response, including RAG references when a knowledge base was used.
class AgentChatResponse(BaseModel):
    """agentChat 响应"""
    response: str = Field(..., description="模型输出结果")
    model_id: int = Field(..., description="使用的大模型配置ID")
    model_name: str = Field(..., description="使用的大模型名称")
    knowledge_base_id: Optional[int] = Field(default=None, description="关联的知识库ID(若使用)")
    knowledge_base_used: bool = Field(default=False, description="是否使用了知识库RAG")
    references: Optional[List[Dict[str, Any]]] = Field(default=None, description="引用的知识库片段(若使用RAG)")


# One chunk of a streamed chat response.
class StreamChunk(BaseModel):
    """流式响应块模型"""
    content: str
    role: MessageRole = MessageRole.ASSISTANT
    finish_reason: Optional[str] = None
    tokens_used: Optional[int] = None


# Knowledge Base schemas
# Fields common to knowledge-base creation and responses.
class KnowledgeBaseBase(BaseModel):
    """知识库基础模型"""
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: str = Field(default="sentence-transformers/all-MiniLM-L6-v2")
    chunk_size: int = Field(default=1000, ge=100, le=5000)
    chunk_overlap: int = Field(default=200, ge=0, le=1000)


# Payload for creating a knowledge base.
class KnowledgeBaseCreate(KnowledgeBaseBase):
    """知识库创建模型"""
    pass


# Partial update payload for a knowledge base.
class KnowledgeBaseUpdate(BaseModel):
    """知识库更新模型"""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: Optional[str] = None
    chunk_size: Optional[int] = Field(None, ge=100, le=5000)
    chunk_overlap: Optional[int] = Field(None, ge=0, le=1000)
    is_active: Optional[bool] = None


# Knowledge base as returned by the API.
class KnowledgeBaseResponse(BaseResponse, KnowledgeBaseBase):
    """知识库响应模型"""
    is_active: bool
    vector_db_type: str
    collection_name: Optional[str]
    document_count: int = 0
    active_document_count: int = 0


# Document schemas
# File metadata common to document models.
class DocumentBase(BaseModel):
    """文档基础模型"""
    filename: str
    original_filename: str
    file_type: str
    file_size: int


# Upload options for a document.
class DocumentUpload(BaseModel):
    """文档上传模型"""
    knowledge_base_id: int
    process_immediately: bool = Field(default=True)


# Document as returned by the API, including processing state.
class DocumentResponse(BaseResponse, DocumentBase):
    """文档响应模型"""
    knowledge_base_id: int
    file_path: str
    mime_type: Optional[str]
    is_processed: bool
    processing_error: Optional[str]
    chunk_count: int = 0
    embedding_model: Optional[str]
    file_size_mb: float


# Paginated list of documents.
class DocumentListResponse(BaseModel):
    """文档列表响应模型"""
    documents: List[DocumentResponse]
    total: int
    page: int
    page_size: int


# Processing progress of one document.
class DocumentProcessingStatus(BaseModel):
    """文档处理状态模型"""
    document_id: int
    status: str  # 'pending', 'processing', 'completed', 'failed'
    progress: float = Field(default=0.0, ge=0.0, le=100.0)
    error_message: Optional[str] = None
    chunks_created: int = 0
    estimated_time_remaining: Optional[int] = None  # seconds


# Document chunk schemas
# One chunk of a split document.
class DocumentChunk(BaseModel):
    """文档分块模型"""
    id: str
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    page_number: Optional[int] = None
    chunk_index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None
    vector_id: Optional[str] = None


# All chunks of a document.
class DocumentChunksResponse(BaseModel):
    """文档分块响应模型"""
    document_id: int
    document_name: str
    total_chunks: int
    chunks: List[DocumentChunk]


# Error schemas
class ErrorResponse(BaseModel):
    """错误响应模型"""
    error: str
    detail: Optional[str] = None
    code: Optional[str] = None


# Generic response envelope
class NormalResponse(BaseModel):
    """通用返回模型"""
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None


# Request for a paginated preview of an uploaded Excel file.
class ExcelPreviewRequest(BaseModel):
    """Excel预览请求模型"""
    file_id: str
    page: int = 1
    page_size: int = 20


# File-list response envelope.
class FileListResponse(BaseModel):
    """文件列表响应模型"""
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None


# Resolve forward references
def rebuild_models():
    """Resolve the string forward references used by the response models.

    ``UserResponse.roles`` refers to ``RoleResponse`` from
    ``th_agenter.schemas.permission``; if that module cannot be imported (yet),
    the UserResponse rebuild is skipped. ``ConversationResponse.messages``
    refers to ``MessageResponse``, which is defined later in this module.
    """
    try:
        # model_rebuild() resolves names from the calling frame, so this local
        # import is what makes RoleResponse visible to it.
        from th_agenter.schemas.permission import RoleResponse  # noqa: F401
        UserResponse.model_rebuild()
    except ImportError:
        # RoleResponse not importable: leave UserResponse to be rebuilt later.
        pass
    ConversationResponse.model_rebuild()


# Attempt the rebuild once at import time.
rebuild_models()