from loguru import logger from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any # 核心:导入 LangChain 的基础语言模型抽象类 from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatResult from langchain_core.callbacks import CallbackManagerForLLMRun from dataclasses import dataclass, field from typing import Optional, Dict, Any, List from datetime import datetime @dataclass class LLMConfig_DataClass: """ 统一的LLM配置基类,覆盖在线/本地/嵌入式模型所有配置,映射数据库完整字段 通过 provider + is_embedding 区分模型类型: - 在线模型:provider in ['openai', 'zhipu', 'baidu'] + is_embedding=False - 本地模型:provider in ['llama', 'qwen', 'yi'] + is_embedding=False - 嵌入式模型:provider in ['bge', 'text2vec'] + is_embedding=True """ # ====================== 数据库核心公共字段(必选/可选) ====================== # 基础标识字段 name: str # 模型自定义名称(如 "gpt-5") model_name: str # 模型官方标识名(如 "gpt-5"、"BAAI/bge-small-zh-v1.5") provider: str # 提供商(openai/llama/bge/zhipu 等) id: Optional[int] = None # 数据库主键ID description: Optional[str] = None # 模型描述 is_active: bool = True # 是否启用 is_default: bool = False # 是否默认模型 is_embedding: bool = False # 是否为嵌入式模型(核心区分标识) # ====================== 通用生成参数(所有推理模型共用) ====================== temperature: float = 0.7 # 生成温度(默认值对齐数据库示例) max_tokens: int = 3000 # 最大生成长度(默认值对齐数据库示例) top_p: float = 0.6 # 采样Top-P frequency_penalty: float = 0.0 # 频率惩罚 presence_penalty: float = 0.0 # 存在惩罚 # ====================== 在线模型专属参数(非必填,仅在线模型生效) ====================== api_key: Optional[str] = None # API密钥(在线模型必填) base_url: Optional[str] = None # API代理地址(如 https://api.openai-proxy.org/v1) # timeout: int = 30 # 请求超时时间(秒) max_retries: int = 3 # 最大重试次数 api_version: Optional[str] = None # API版本(如 OpenAI 的 2024-02-15-preview) # ====================== 本地模型专属参数(非必填,仅本地模型生效) ====================== model_path: Optional[str] = None # 本地模型文件路径(本地模型必填) device: str = "cpu" # 运行设备(cpu/cuda/mps) n_ctx: int = 2048 # 上下文窗口大小 n_threads: int = 8 # 推理线程数 quantization: str = "q4_0" # 量化级别(q4_0/q8_0/f16) load_in_8bit: bool = False # 是否8bit加载 load_in_4bit: bool = False # 是否4bit加载 prompt_template: Optional[str] = None # 自定义Prompt模板 # ====================== 嵌入式模型专属参数(非必填,仅嵌入式模型生效) ====================== normalize_embeddings: bool = True # 是否归一化向量 batch_size: int = 32 # 批量编码大小 encode_kwargs: Dict[str, Any] = field(default_factory=dict) # 编码扩展参数 dimension: Optional[int] = None # 向量维度(如 768) # ====================== 元数据字段(数据库自动维护) ====================== extra_config: Dict[str, Any] = field(default_factory=dict) # 额外扩展配置 usage_count: int = 0 # 使用次数 last_used_at: Optional[datetime] = None # 最后使用时间 created_at: Optional[datetime] = None # 创建时间 updated_at: Optional[datetime] = None # 更新时间 created_by: Optional[int] = None # 创建人ID updated_by: Optional[int] = None # 更新人ID api_key_masked: Optional[str] = "" # 掩码后的API密钥(数据库存储) # ====================== 核心工具方法 ====================== def __post_init__(self): """后置初始化:自动校验和修正配置""" # 1. 嵌入式模型强制清空推理参数(避免误用) if self.is_embedding: self.max_tokens = 0 self.temperature = 0.0 self.top_p = 0.0 # 2. 校验必填参数(按模型类型) self._validate_required_fields() def _validate_required_fields(self): """按模型类型校验必填参数""" # 在线模型校验 if not self.is_embedding and self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']: if not self.api_key: raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key") # 本地模型校验 if not self.is_embedding and self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']: if not self.model_path: raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path") def to_dict(self) -> Dict[str, Any]: """转换为字典(用于存入/更新数据库)""" return { key: value for key, value in self.__dict__.items() if not key.startswith('_') # 排除私有属性 } @classmethod def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass": """从数据库字典初始化配置(核心方法)""" # 1. 时间字段转换:字符串 → datetime time_fields = ['last_used_at', 'created_at', 'updated_at'] for field_name in time_fields: val = db_dict.get(field_name) if val and isinstance(val, str): try: db_dict[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00')) except (ValueError, TypeError): db_dict[field_name] = None # 2. 过滤数据库中无关字段(如 api_key_masked) valid_fields = cls.__dataclass_fields__.keys() filtered_dict = {k: v for k, v in db_dict.items() if k in valid_fields} # 3. 初始化并返回配置实例 return cls(**filtered_dict) def get_model_type(self) -> str: """快速判断模型类型(返回:online/local/embedding)""" if self.is_embedding: return "embedding" if self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']: return "online" if self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']: return "local" return "unknown" class BaseLLM(BaseChatModel): """ 继承 LangChain 的 BaseChatModel(BaseLanguageModel 的子类) 使其能直接用于 create_agent """ # 配置参数(通过 __init__ 初始化) config: Any = None model: Any = None def __init__(self, config): super().__init__() # 必须调用父类构造函数 self.config = config self.model = None self._validate_config() logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}") # ---------------------- 必须实现的核心抽象方法(LangChain 协议) ---------------------- def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: """ 核心同步生成方法(LangChain 要求必须实现) messages: 消息列表(如 [HumanMessage(content="你好")]) 返回 ChatResult 类型(LangChain 标准输出) """ logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法") async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, ** kwargs: Any, ) -> ChatResult: """异步生成方法(LangChain 异步协议)""" logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法") @property def _llm_type(self) -> str: """返回模型类型标识(如 "openai"、"llama"、"bge")""" return self.__class__.__name__ def load_model(self) -> None: """加载模型(自定义逻辑)""" logger.error(f"{self.__class__.__name__} 未实现 load_model 方法") def close(self) -> None: """释放资源(自定义逻辑)""" if self.model: logger.info(f"释放 {self.__class__.__name__} 模型资源") self.model = None def __enter__(self): self.load_model() return self def __exit__(self, exc_type, exc_val, exc_tb): self.close()