156 lines
6.3 KiB
Python
156 lines
6.3 KiB
Python
|
|
"""LLM Configuration Pydantic schemas."""
|
|||
|
|
|
|||
|
|
from typing import Optional, Dict, Any
|
|||
|
|
from pydantic import BaseModel, Field, field_validator, computed_field
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigBase(BaseModel):
    # Base schema for an LLM configuration; LLMConfigCreate extends it.
    # The docstring below is kept verbatim because pydantic uses a model's
    # class docstring as its JSON-schema description ("LLM config base schema").
    """大模型配置基础模式."""

    # --- identity and connection settings (name/provider/model/key are required) ---
    name: str = Field(default=..., min_length=1, max_length=100, description="配置名称")
    provider: str = Field(default=..., min_length=1, max_length=50, description="服务商")
    model_name: str = Field(default=..., min_length=1, max_length=100, description="模型名称")
    api_key: str = Field(default=..., min_length=1, description="API密钥")
    base_url: Optional[str] = Field(default=None, description="API基础URL")

    # --- generation parameters; ranges are enforced by the ge/le bounds ---
    max_tokens: Optional[int] = Field(default=4096, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")

    # Free-text description of the configuration (at most 500 characters).
    description: Optional[str] = Field(default=None, max_length=500, description="配置描述")

    # --- status flags: active / default config / embedding model ---
    is_active: bool = Field(default=True, description="是否激活")
    is_default: bool = Field(default=False, description="是否为默认配置")
    is_embedding: bool = Field(default=False, description="是否为嵌入模型")

    # Arbitrary provider-specific extras; the schema does not constrain the keys.
    extra_config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigCreate(LLMConfigBase):
    # Schema for creating an LLM configuration: the base fields plus
    # normalization/validation of ``provider`` and ``api_key``.
    # Docstring kept verbatim because pydantic uses it as the JSON-schema
    # description ("create LLM config schema").
    """创建大模型配置模式."""

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Normalize ``provider`` and check it against the supported list.

        Surrounding whitespace is stripped and the value is lower-cased, so
        e.g. ``" OpenAI "`` is stored as ``"openai"``.

        Args:
            v: Raw provider name (already length-checked by the field).

        Returns:
            The normalized provider name.

        Raises:
            ValueError: If the provider is not one of the supported names.
        """
        # Ordered tuple so the error message lists providers in a stable
        # order; each provider appears exactly once (the old list repeated
        # 'ollama', which showed up twice in the error message).
        allowed_providers = (
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',
        )
        normalized = v.strip().lower()
        if normalized not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return normalized

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Strip surrounding whitespace from ``api_key`` and enforce a minimum length.

        Args:
            v: Raw API key.

        Returns:
            The key with leading/trailing whitespace removed.

        Raises:
            ValueError: If the stripped key is shorter than 10 characters.
        """
        stripped = v.strip()
        if len(stripped) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return stripped
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigUpdate(BaseModel):
    # Partial-update schema: every field is optional and defaults to None,
    # so callers can send only the fields they want to change.
    # Docstring kept verbatim because pydantic uses it as the JSON-schema
    # description ("update LLM config schema").
    """更新大模型配置模式."""

    name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称")
    provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商")
    model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称")
    api_key: Optional[str] = Field(None, min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")

    is_active: Optional[bool] = Field(None, description="是否激活")
    is_default: Optional[bool] = Field(None, description="是否为默认配置")
    is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: Optional[str]) -> Optional[str]:
        """Normalize ``provider`` and check it against the supported list.

        ``None`` (field explicitly cleared / not updated) is passed through.
        Otherwise surrounding whitespace is stripped and the value is
        lower-cased.

        Args:
            v: Raw provider name or None.

        Returns:
            The normalized provider name, or None.

        Raises:
            ValueError: If the provider is not one of the supported names.
        """
        if v is None:
            return v
        # Ordered tuple so the error message lists providers in a stable
        # order; each provider appears exactly once (the old list repeated
        # 'ollama', which showed up twice in the error message).
        allowed_providers = (
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',
        )
        normalized = v.strip().lower()
        if normalized not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return normalized

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
        """Strip whitespace from ``api_key`` and enforce a minimum length.

        ``None`` is passed through unchanged.

        Args:
            v: Raw API key or None.

        Returns:
            The stripped key, or None.

        Raises:
            ValueError: If the stripped key is shorter than 10 characters.
        """
        if v is None:
            return v
        stripped = v.strip()
        if len(stripped) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return stripped
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigResponse(BaseModel):
    # Response schema for an LLM configuration. ``from_attributes`` lets it be
    # validated directly from attribute-bearing objects (e.g. ORM rows).
    # Docstring kept verbatim because pydantic uses it as the JSON-schema
    # description ("LLM config response schema").
    """大模型配置响应模式."""

    # --- identity ---
    id: int
    name: str
    provider: str
    model_name: str
    # Full API key (original note: only returned when include_sensitive=True).
    # NOTE(review): this schema itself serializes api_key whenever it is set;
    # the include_sensitive gating presumably happens in the caller — confirm.
    api_key: Optional[str] = None  # 完整的API密钥(仅在include_sensitive=True时返回)
    base_url: Optional[str] = None

    # --- generation parameters ---
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    description: Optional[str] = None

    # --- flags and free-form extras ---
    is_active: bool
    is_default: bool
    is_embedding: bool
    extra_config: Optional[Dict[str, Any]] = None

    # --- audit metadata ---
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    model_config = dict(from_attributes=True)

    @computed_field
    @property
    def api_key_masked(self) -> Optional[str]:
        # Masked rendering of api_key: keys longer than 8 characters keep
        # their first 4 and last 4 characters with '*' in between; shorter
        # keys are fully starred; a missing/empty key yields None.
        # (Deliberately no docstring: computed_field would publish it as the
        # schema description.)
        secret = self.api_key
        if not secret:
            return None
        length = len(secret)
        if length <= 8:
            return '*' * length
        hidden = '*' * (length - 8)
        return secret[:4] + hidden + secret[-4:]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigTest(BaseModel):
    # Request schema for testing an LLM configuration with a sample message.
    # Docstring kept verbatim because pydantic uses it as the JSON-schema
    # description ("LLM config test schema").
    """大模型配置测试模式."""

    # Optional test message; defaults to a short English greeting and is
    # limited to 1000 characters.
    message: Optional[str] = Field(
        default="Hello, this is a test message.",
        max_length=1000,
        description="测试消息",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMConfigClientResponse(BaseModel):
    # Trimmed response schema for the frontend. It exposes no API key, base
    # URL, penalties or audit fields, only what a client-side picker shows.
    # Docstring kept verbatim because pydantic uses it as the JSON-schema
    # description ("LLM config client response schema (for the frontend)").
    """大模型配置客户端响应模式(用于前端)."""

    # --- identity ---
    id: int
    name: str
    provider: str
    model_name: str

    # --- generation parameters visible to the client ---
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    # --- status and free-text description ---
    is_active: bool
    description: Optional[str] = None

    # Allow validation directly from attribute-bearing objects (e.g. ORM rows).
    model_config = dict(from_attributes=True)
|