"""LLM配置服务 - 从数据库读取默认配置"""
from typing import Optional, Dict, Any, List

from loguru import logger
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from ..db.database import get_session
from ..models.llm_config import LLMConfig


class LLMConfigService:
    """LLM配置管理服务.

    Resolves chat/embedding model configurations from ``LLMConfig`` rows,
    with environment-variable fallbacks for when the database has none.

    NOTE(review): the async methods ``await session.execute(...)``, so the
    sessions passed in are presumably ``AsyncSession`` instances despite the
    ``Session`` annotations — confirm against callers.
    """

    def __init__(self, db: Optional[Session] = None) -> None:
        """Create the service.

        Args:
            db: Optional database session used as a fallback by methods that
                are not handed one explicitly (``get_config_by_id``,
                ``get_active_configs``). The original code referenced
                ``self.db`` without ever defining it, which always raised
                ``AttributeError``; the commented-out factory at the bottom of
                this file shows the constructor used to accept a session.
        """
        self.db = db

    async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
        """获取默认对话模型配置.

        Looks up the single config row flagged default + active + non-embedding.

        Args:
            session: Async database session to query with.

        Returns:
            The matching ``LLMConfig`` row, or ``None`` if absent or on error.
        """
        # Guard added for consistency with get_default_embedding_config.
        if session is None:
            logger.error("get_default_chat_config: session 为 None,无法查询配置")
            return None
        try:
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,   # noqa: E712 — SQLAlchemy needs `== True`
                    LLMConfig.is_embedding == False,  # noqa: E712
                    LLMConfig.is_active == True,    # noqa: E712
                )
            )
            config = (await session.execute(stmt)).scalar_one_or_none()

            if not config:
                logger.warning("未找到默认对话模型配置")
                return None

            return config

        except Exception as e:
            logger.error(f"获取默认对话模型配置失败: {str(e)}")
            return None

    async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
        """获取默认嵌入模型配置,如果没有默认配置则尝试使用任何激活的嵌入模型配置.

        Resolution order:
          1. The config flagged default + active + embedding.
          2. Otherwise, the oldest active embedding config (by ``created_at``).
          3. Otherwise ``None`` (with diagnostics about inactive rows).

        Args:
            session: Async database session to query with.

        Returns:
            The resolved ``LLMConfig`` row, or ``None`` if none is usable.
        """
        if session is None:
            logger.error("get_default_embedding_config: session 为 None,无法查询配置")
            return None

        try:
            # 首先尝试获取默认嵌入模型配置
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,   # noqa: E712
                    LLMConfig.is_embedding == True,  # noqa: E712
                    LLMConfig.is_active == True,    # noqa: E712
                )
            )
            config = (await session.execute(stmt)).scalar_one_or_none()

            if config:
                # Was `session.desc = ...` — a debug string scribbled onto the
                # Session object; log it instead of mutating the session.
                logger.info(f"找到默认嵌入模型配置: {config.name} (ID: {config.id})")
                return config

            # 如果没有默认配置,尝试获取任何激活的嵌入模型配置作为后备
            logger.info("未找到默认嵌入模型配置,尝试查找任何激活的嵌入模型配置")

            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_embedding == True,  # noqa: E712
                    LLMConfig.is_active == True,    # noqa: E712
                )
            ).order_by(LLMConfig.created_at)  # 按创建时间排序,取第一个

            # BUGFIX: scalar_one_or_none() raises MultipleResultsFound when
            # several active embedding configs exist; the intent ("take the
            # first") requires .scalars().first().
            config = (await session.execute(stmt)).scalars().first()

            if config:
                logger.info(f"使用激活的嵌入模型配置(非默认): {config.name} (ID: {config.id})")
                return config

            # 如果还是没找到,记录详细信息
            logger.error("未找到任何激活的嵌入模型配置")

            # 尝试查询所有嵌入模型配置(包括未激活的),用于调试
            all_embedding_stmt = select(LLMConfig).where(LLMConfig.is_embedding == True)  # noqa: E712
            all_embedding = (await session.execute(all_embedding_stmt)).scalars().all()
            if all_embedding:
                logger.warning(f"找到 {len(all_embedding)} 个嵌入模型配置,但都不是激活状态:")
                for cfg in all_embedding:
                    logger.warning(f"  - {cfg.name} (ID: {cfg.id}, is_active={cfg.is_active}, is_default={cfg.is_default})")
            else:
                logger.warning("数据库中没有任何嵌入模型配置")

            return None

        except Exception as e:
            # BUGFIX: loguru does not honor stdlib's `exc_info=True`; use
            # logger.exception() to capture the traceback.
            logger.exception(f"获取嵌入模型配置失败: {str(e)}")
            return None

    async def get_config_by_id(
        self, config_id: int, session: Optional[Session] = None
    ) -> Optional[LLMConfig]:
        """根据ID获取配置.

        Args:
            config_id: Primary key of the ``LLMConfig`` row.
            session: Optional session; falls back to ``self.db`` (new optional
                parameter — backward-compatible). The original body used
                ``self.db``, which was never defined and always raised.

        Returns:
            The row, or ``None`` if missing, no session is available, or on error.
        """
        db = session if session is not None else self.db
        if db is None:
            logger.error("get_config_by_id: 无可用数据库会话")
            return None
        try:
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            return (await db.execute(stmt)).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(
        self, is_embedding: Optional[bool] = None, session: Optional[Session] = None
    ) -> List[LLMConfig]:
        """获取所有激活的配置.

        Args:
            is_embedding: When given, additionally filter on the embedding flag.
            session: Optional *synchronous* session; falls back to ``self.db``
                (new optional parameter — backward-compatible; the original
                referenced an undefined ``self.db``).

        Returns:
            Active configs ordered by ``created_at``; empty list on error.
        """
        # NOTE(review): this method executes synchronously, unlike its async
        # siblings — it only works with a sync Session. Confirm callers.
        db = session if session is not None else self.db
        if db is None:
            logger.error("get_active_configs: 无可用数据库会话")
            return []
        try:
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)  # noqa: E712

            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)

            stmt = stmt.order_by(LLMConfig.created_at)
            return db.execute(stmt).scalars().all()

        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    async def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """获取fallback对话模型配置(从环境变量)."""
        # Local import avoids a module-level cycle with core.config.
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.llm.get_current_config()

    async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """获取fallback嵌入模型配置(从环境变量)."""
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.embedding.get_current_config()

    async def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]:
        """测试配置连接.

        BUGFIX: now ``async``. The original sync version called the async
        ``get_config_by_id`` without awaiting it, so the existence check saw a
        (truthy) coroutine object and reported success for *any* config_id.

        Args:
            config_id: Config to test.
            test_message: Probe message (currently unused; kept for the
                eventual real connectivity check).

        Returns:
            ``{"success": bool, ...}`` with an ``error`` or ``message`` field.
        """
        try:
            config = await self.get_config_by_id(config_id)
            if not config:
                return {"success": False, "error": "配置不存在"}

            # 这里可以添加实际的连接测试逻辑
            # 例如发送一个简单的请求来验证配置是否有效

            return {"success": True, "message": "配置测试成功"}

        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}


# # 全局实例 (global singleton — kept commented out, see __init__)
# _llm_config_service = None

# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
#     """获取LLM配置服务实例"""
#     global _llm_config_service
#     if _llm_config_service is None or db_session is not None:
#         _llm_config_service = LLMConfigService(db_session)
#     return _llm_config_service