hyf-backend/th_agenter/models/llm_config.py

162 lines
7.1 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""LLM Configuration model for managing multiple AI models."""
from datetime import datetime
from typing import Dict, Any, Optional
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import BaseModel
class LLMConfig(BaseModel):
    """Persisted settings for one LLM endpoint (chat or embedding model).

    Each row holds the provider, model name, credentials, sampling
    parameters, status flags and simple usage statistics for one model.
    """

    __tablename__ = "llm_configs"

    # --- Identity and connection details ---
    # Configuration name.
    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    # Provider key, e.g. openai, deepseek, doubao, zhipu, moonshot, baidu.
    provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
    # Model name as passed to the provider's API.
    model_name: Mapped[str] = mapped_column(String(100), nullable=False)
    # API key.
    # NOTE(review): the original comment says the key is stored encrypted,
    # but nothing in this class encrypts or decrypts it - confirm where
    # that happens.
    api_key: Mapped[str] = mapped_column(String(500), nullable=False)
    # API base URL (optional).
    base_url: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)

    # --- Model parameters ---
    max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
    temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
    top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
    frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
    presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)

    # --- Configuration info ---
    # Free-text description of this configuration.
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Whether this configuration is enabled.
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    # Whether this is the default configuration.
    is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    # Whether this configuration describes an embedding model.
    is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)

    # --- Extended configuration (JSON) ---
    # Extra parameters merged into the client configuration.
    extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)

    # --- Usage statistics ---
    # Number of times this configuration has been used.
    usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    # Timestamp of the most recent use.
    last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)

    def __repr__(self):
        """Return a debug representation; the API key is deliberately omitted."""
        return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model_name='{self.model_name}', base_url='{self.base_url}')>"

    def to_dict(self, include_sensitive=False):
        """Convert the configuration to a dictionary.

        Starts from whatever ``BaseModel.to_dict()`` returns and adds this
        model's own fields.

        Args:
            include_sensitive: When true, the raw key is returned under
                ``'api_key'``. Otherwise only ``'api_key_masked'`` is set.

        Returns:
            A dict of the configuration's fields.
        """
        data = super().to_dict()
        data.update({
            'name': self.name,
            'provider': self.provider,
            'model_name': self.model_name,
            'base_url': self.base_url,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'description': self.description,
            'is_active': self.is_active,
            'is_default': self.is_default,
            'is_embedding': self.is_embedding,
            'extra_config': self.extra_config,
            'usage_count': self.usage_count,
            'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
        })

        if include_sensitive:
            data['api_key'] = self.api_key
        elif not self.api_key:
            data['api_key_masked'] = None
        elif len(self.api_key) > 8:
            # Show only the first and last four characters of the key.
            data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
        else:
            # Too short to reveal any part of it safely.
            data['api_key_masked'] = "***"
        return data

    def get_client_config(self) -> Dict[str, Any]:
        """Return the keyword configuration used to create a client.

        Entries in ``extra_config`` are merged last, so they override the
        column-backed values when keys collide.
        """
        config = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'model': self.model_name,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
        }
        # Merge in the extra configuration, if any.
        if self.extra_config:
            config.update(self.extra_config)
        return config

    def validate_config(self) -> Dict[str, Any]:
        """Check whether the configuration is valid.

        Returns:
            ``{"valid": True, "error": None}`` on success, otherwise
            ``{"valid": False, "error": <message>}`` for the first failed
            check.
        """
        supported_providers = (
            'openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama',
        )

        if not self.name or not self.name.strip():
            return {"valid": False, "error": "配置名称不能为空"}
        if not self.provider or self.provider not in supported_providers:
            return {"valid": False, "error": f"不支持的服务商 {self.provider}"}
        if not self.model_name or not self.model_name.strip():
            return {"valid": False, "error": "模型名称不能为空"}
        if not self.api_key or not self.api_key.strip():
            return {"valid": False, "error": "API密钥不能为空"}

        # Column defaults are applied at INSERT time, so on a new, unflushed
        # instance these attributes may still be None. None means "use the
        # column default", which is in range, so the check is skipped rather
        # than raising TypeError on the comparison.
        if self.max_tokens is not None and (self.max_tokens <= 0 or self.max_tokens > 32000):
            return {"valid": False, "error": "最大令牌数必须在1-32000之间"}
        if self.temperature is not None and (self.temperature < 0 or self.temperature > 2):
            return {"valid": False, "error": "温度参数必须在0-2之间"}
        return {"valid": True, "error": None}

    def increment_usage(self):
        """Increase the usage counter by one and record the time of use."""
        # usage_count is None until the column default is applied on INSERT;
        # treat that as zero instead of failing on ``None += 1``.
        self.usage_count = (self.usage_count or 0) + 1
        self.last_used_at = datetime.now()

    @classmethod
    def get_default_config(cls, provider: str, is_embedding: bool = False):
        """Return the default configuration template for a provider.

        Args:
            provider: Provider key such as ``'openai'`` or ``'zhipu'``.
            is_embedding: Select the embedding model name instead of the
                chat model name.

        Returns:
            A dict with ``base_url``, ``model_name``, ``max_tokens`` and
            ``temperature``, or an empty dict when the provider has no
            template.
        """
        templates = {
            'openai': {
                'base_url': 'https://api.openai.com/v1',
                # 'gpt-4o-mini' is the real model id; the previous value
                # 'gpt-4.0-mini' is not an OpenAI model.
                'model_name': 'gpt-4o-mini' if not is_embedding else 'text-embedding-ada-002',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'deepseek': {
                'base_url': 'https://api.deepseek.com/v1',
                'model_name': 'deepseek-chat' if not is_embedding else 'deepseek-embedding',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'doubao': {
                'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
                'model_name': 'doubao-lite-4k' if not is_embedding else 'doubao-embedding',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'zhipu': {
                'base_url': 'https://open.bigmodel.cn/api/paas/v4',
                'model_name': 'glm-4' if not is_embedding else 'embedding-3',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'moonshot': {
                'base_url': 'https://api.moonshot.cn/v1',
                'model_name': 'moonshot-v1-8k' if not is_embedding else 'moonshot-embedding',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
        }
        return templates.get(provider, {})