hyf-backend/th_agenter/services/llm_config_service.py

123 lines
4.7 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""LLM配置服务 - 从数据库读取默认配置"""
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from sqlalchemy import and_, select
from ..models.llm_config import LLMConfig
from ..db.database import get_session
from loguru import logger
class LLMConfigService:
    """Service for reading LLM / embedding model configurations from the database.

    The ``get_default_*`` lookups take an explicit session on every call.
    ``get_config_by_id``, ``get_active_configs`` and ``test_config`` accept an
    optional ``session`` argument and fall back to ``self.db`` (set through the
    optional constructor argument) when none is passed.
    """

    def __init__(self, db: Optional[Session] = None) -> None:
        """Create the service.

        Args:
            db: Optional session that is used by methods that are not handed
                one explicitly. Previously the class had no ``__init__``, so
                ``self.db`` was never set and every method reading it failed.
        """
        self.db = db

    def _resolve_session(self, session: Optional[Session]) -> Optional[Session]:
        """Return *session* when given, otherwise ``self.db`` (or None)."""
        if session is not None:
            return session
        # getattr: tolerate subclasses / instances that bypassed __init__.
        return getattr(self, "db", None)

    @staticmethod
    def _default_config_stmt(is_embedding: bool):
        """Build the query for the active default config of one model kind."""
        return select(LLMConfig).where(
            and_(
                # `== True` is intentional: it builds a SQL comparison, not a
                # Python truth test.
                LLMConfig.is_default == True,  # noqa: E712
                LLMConfig.is_embedding == is_embedding,
                LLMConfig.is_active == True,  # noqa: E712
            )
        )

    @staticmethod
    def _config_by_id_stmt(config_id: int):
        """Build the query selecting the config whose primary key is *config_id*."""
        return select(LLMConfig).where(LLMConfig.id == config_id)

    async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the active default chat (non-embedding) config, or None.

        Args:
            session: Session used for the query. Its ``execute`` result is
                awaited, so an async session is expected despite the
                ``Session`` annotation.

        Returns:
            The matching ``LLMConfig``. Returns None when no row matches (logged
            as a warning) or when the query raises (logged as an error). That
            includes ``scalar_one_or_none`` raising because more than one row
            is marked default.
        """
        try:
            stmt = self._default_config_stmt(is_embedding=False)
            config = (await session.execute(stmt)).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取默认对话模型配置失败: {str(e)}")
            return None
        if config is None:
            logger.warning("未找到默认对话模型配置")
            return None
        return config

    async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the active default embedding config, or None.

        Unlike the chat variant, this method reports its outcome by writing a
        human-readable string to ``session.desc``. Failures are additionally
        sent to the logger.

        Args:
            session: Async session used for the query. May be None, in which
                case no query is run and None is returned.

        Returns:
            The matching ``LLMConfig``, or None when there is no session, no
            matching row, or the query fails.
        """
        if session is None:
            return None
        try:
            stmt = self._default_config_stmt(is_embedding=True)
            config = (await session.execute(stmt)).scalar_one_or_none()
            if config is None:
                session.desc = "ERROR: 未找到默认嵌入模型配置"
                return None
            session.desc = f"获取默认嵌入模型配置 > 结果:{config}"
            return config
        except Exception as e:
            # Previously this error was only stored on session.desc and never
            # reached the logs.
            logger.error(f"获取默认嵌入模型配置失败: {str(e)}")
            session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}"
            return None

    async def get_config_by_id(
        self, config_id: int, session: Optional[Session] = None
    ) -> Optional[LLMConfig]:
        """Return the config with primary key *config_id*, or None.

        Args:
            config_id: Primary key of the ``LLMConfig`` row.
            session: Async session to query with. Defaults to ``self.db``.

        Returns:
            The row, or None when it does not exist, no session is available,
            or the query fails (the last two cases are logged).
        """
        db = self._resolve_session(session)
        if db is None:
            logger.error("获取配置失败: 没有可用的数据库会话")
            return None
        try:
            return (await db.execute(self._config_by_id_stmt(config_id))).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(
        self,
        is_embedding: Optional[bool] = None,
        session: Optional[Session] = None,
    ) -> List[LLMConfig]:
        """Return all active configs ordered by ``created_at``.

        This method is synchronous, so *session* (or ``self.db``) must be a
        synchronous ``Session``.

        Args:
            is_embedding: When not None, keep only configs whose
                ``is_embedding`` flag equals this value.
            session: Session to query with. Defaults to ``self.db``.

        Returns:
            The matching configs. Returns an empty list when no session is
            available or the query fails (both cases are logged).
        """
        db = self._resolve_session(session)
        if db is None:
            logger.error("获取激活配置失败: 没有可用的数据库会话")
            return []
        try:
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)  # noqa: E712
            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
            stmt = stmt.order_by(LLMConfig.created_at)
            return list(db.execute(stmt).scalars().all())
        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    async def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """Return the fallback chat-model config from ``settings.llm``.

        NOTE(review): the original docstring says this comes from environment
        variables; that depends on ``get_settings``, which is not visible here.
        """
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.llm.get_current_config()

    async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """Return the fallback embedding-model config from ``settings.embedding``.

        NOTE(review): presumably populated from environment variables, per the
        original docstring; confirm in ``core.config``.
        """
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.embedding.get_current_config()

    def test_config(
        self,
        config_id: int,
        test_message: str = "Hello",
        session: Optional[Session] = None,
    ) -> Dict[str, Any]:
        """Check that config *config_id* exists and report the result.

        The previous version called the async ``get_config_by_id`` without
        awaiting it. The coroutine object is always truthy, so the existence
        check never ran and the method always reported success. The lookup is
        now done synchronously, keeping this method's sync interface. That
        means *session* (or ``self.db``) must be a synchronous ``Session``.

        Args:
            config_id: Primary key of the config to test.
            test_message: Reserved for a real connectivity check; not used yet.
            session: Session to query with. Defaults to ``self.db``.

        Returns:
            ``{"success": True, "message": ...}`` when the config exists,
            otherwise ``{"success": False, "error": ...}``.
        """
        try:
            db = self._resolve_session(session)
            if db is None:
                return {"success": False, "error": "没有可用的数据库会话"}
            config = db.execute(self._config_by_id_stmt(config_id)).scalar_one_or_none()
            if config is None:
                return {"success": False, "error": "配置不存在"}
            # TODO: add a real connection test here (e.g. send test_message to
            # the configured model and check the reply).
            return {"success": True, "message": "配置测试成功"}
        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}
# # 全局实例
# _llm_config_service = None
# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
# """获取LLM配置服务实例"""
# global _llm_config_service
# if _llm_config_service is None or db_session is not None:
# _llm_config_service = LLMConfigService(db_session)
# return _llm_config_service