hyf-backend/th_agenter/services/database_config_service.py

375 lines
14 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""数据库配置服务"""
import asyncio
import inspect
import os
import re
from typing import List, Dict, Any, Optional

from cryptography.fernet import Fernet
from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import Session

from ..models.database_config import DatabaseConfig
from utils.util_exceptions import ValidationError, NotFoundError
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
class DatabaseConfigService:
    """Manage per-user database connection configurations.

    Configurations are persisted as ``DatabaseConfig`` rows whose password is
    encrypted with a Fernet key kept on disk (``KEY_FILE``). Connection tests
    and live connections are delegated to the MCP tools returned by
    ``get_postgresql_tool()`` / ``get_mysql_tool()``.

    NOTE: every data-access method awaits ``session.execute/commit/refresh/
    rollback``, so ``db_session`` must be an async session object even though
    the annotation says ``Session``.
    """

    # Location of the Fernet key file, relative to the process working directory.
    KEY_FILE = "db/db_config_key.key"

    # Keys of ``config_data`` that create_or_update_config must never copy onto
    # an existing row (identity / ownership / audit columns).
    _PROTECTED_FIELDS = frozenset({'id', 'created_by', 'created_at', 'updated_at'})

    # A plain SQL identifier, optionally qualified with up to two dotted
    # prefixes (schema / database). Table names are interpolated into SQL text
    # because identifiers cannot be bound parameters, so anything else
    # (quotes, spaces, semicolons, comments, ...) is rejected.
    _TABLE_NAME_RE = re.compile(r"[\w$]+(?:\.[\w$]+){0,2}")

    # Strong references to fire-and-forget disconnect tasks, shared by all
    # instances, so pending tasks are not garbage-collected before finishing.
    _background_tasks: set = set()

    def __init__(self, db_session: Session):
        self.session = db_session
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # Load (or create on first use) the key that encrypts stored passwords.
        self.encryption_key = self._get_or_create_encryption_key()
        self.cipher = Fernet(self.encryption_key)

    # ------------------------------------------------------------------ #
    # Encryption helpers
    # ------------------------------------------------------------------ #
    def _get_or_create_encryption_key(self) -> bytes:
        """Return the Fernet key stored in ``KEY_FILE``, generating it if absent.

        A new key file is created atomically (``O_EXCL``) with owner-only
        permissions (0o600), and its parent directory is created if needed. If
        another process wins the creation race, that process's key is used.
        A freshly generated key cannot decrypt passwords encrypted with an
        earlier key, which is why generation is logged as a warning.
        """
        key_file = self.KEY_FILE
        try:
            with open(key_file, 'rb') as f:
                logger.info(f"find db_config_key: {key_file}")
                return f.read()
        except FileNotFoundError:
            pass

        logger.warning(f"not find db_config_key, generating a new key: {key_file}")
        key = Fernet.generate_key()
        key_dir = os.path.dirname(key_file)
        if key_dir:
            os.makedirs(key_dir, exist_ok=True)
        try:
            fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            # Another process created the file between our read and our create.
            with open(key_file, 'rb') as f:
                return f.read()
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        return key

    def _encrypt_password(self, password: str) -> str:
        """Encrypt a plaintext password into a Fernet token string."""
        return self.cipher.encrypt(password.encode()).decode()

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a Fernet token string back into the plaintext password."""
        return self.cipher.decrypt(encrypted_password.encode()).decode()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _get_tool(self, db_type: Optional[str]):
        """Return the MCP tool for ``db_type`` (case-insensitive), or None if unsupported."""
        normalized = db_type.lower() if isinstance(db_type, str) else ''
        if normalized == 'postgresql':
            return self.postgresql_tool
        if normalized == 'mysql':
            return self.mysql_tool
        return None

    def _connection_config(self, config: DatabaseConfig) -> Dict[str, Any]:
        """Build the tool connection dict for a stored config, decrypting its password."""
        return {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password),
        }

    async def _clear_default_flags(self, user_id: int, keep_id: Optional[int] = None) -> None:
        """Set ``is_default = False`` on the user's default configs, except ``keep_id``.

        Only mutates the loaded rows; the caller is responsible for committing.
        """
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_default == True,  # noqa: E712 - SQL expression, not a bool test
        )
        result = await self.session.execute(stmt)
        for config in result.scalars():
            if keep_id is None or config.id != keep_id:
                config.is_default = False

    @staticmethod
    async def _await(awaitable):
        """Await any awaitable; lets asyncio.run / create_task accept non-coroutines."""
        return await awaitable

    @classmethod
    def _on_background_task_done(cls, task) -> None:
        """Drop the finished task's reference and log its failure, if any."""
        cls._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logger.error(f"断开连接失败: {task.exception()}")

    def _run_tool_call(self, outcome) -> None:
        """Drive a (possibly async) tool call from synchronous code.

        Non-awaitable results are ignored. If an event loop is running in the
        current thread, the awaitable cannot be awaited from a sync method, so
        it is scheduled as a background task. Otherwise it is run to completion
        with ``asyncio.run``. NOTE(review): ``asyncio.run`` uses a new loop;
        confirm the tools' connections can be closed from a loop other than
        the one that opened them.
        """
        if not inspect.isawaitable(outcome):
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(self._await(outcome))
            return
        task = loop.create_task(self._await(outcome))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #
    async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Validate, connection-test and persist a new configuration.

        Args:
            user_id: Owner of the new configuration (stored as ``created_by``).
            config_data: Must contain name, db_type, host, port, database,
                username and password. Optional: is_active (default True),
                is_default (default False), connection_params.

        Returns:
            The committed and refreshed ``DatabaseConfig`` row.

        Raises:
            ValidationError: A required field is missing, or the connection test
                fails (PostgreSQL / MySQL only; other types are saved untested).
            Exception: Any other error, re-raised after rolling back the session.
        """
        try:
            required_fields = ['name', 'db_type', 'host', 'port', 'database', 'username', 'password']
            for field in required_fields:
                if field not in config_data:
                    raise ValidationError(f"缺少必需字段: {field}")

            db_tool = self._get_tool(config_data['db_type'])
            if db_tool is not None:
                test_config = {
                    'host': config_data['host'],
                    'port': config_data['port'],
                    'database': config_data['database'],
                    'username': config_data['username'],
                    'password': config_data['password'],
                }
                test_result = await db_tool.execute(
                    operation="test_connection",
                    connection_config=test_config
                )
                if not test_result.success:
                    raise ValidationError(f"数据库连接测试失败: {test_result.error}")

            # A user has at most one default config: clear the others first.
            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id)

            db_config = DatabaseConfig(
                created_by=user_id,
                name=config_data['name'],
                db_type=config_data['db_type'],
                host=config_data['host'],
                port=config_data['port'],
                database=config_data['database'],
                username=config_data['username'],
                password=self._encrypt_password(config_data['password']),
                is_active=config_data.get('is_active', True),
                is_default=config_data.get('is_default', False),
                connection_params=config_data.get('connection_params')
            )
            self.session.add(db_config)
            await self.session.commit()
            await self.session.refresh(db_config)
            logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
            return db_config
        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建数据库配置失败: {str(e)}")
            raise

    async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
        """Return the user's configs, newest first; only active ones when ``active_only``."""
        stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
        if active_only:
            stmt = stmt.where(DatabaseConfig.is_active == True)  # noqa: E712
        stmt = stmt.order_by(DatabaseConfig.created_at.desc())
        return (await self.session.execute(stmt)).scalars().all()

    async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
        """Return config ``config_id`` if it belongs to ``user_id``, else None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == config_id,
            DatabaseConfig.created_by == user_id
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
        """Return the user's default active configuration, or None.

        Prefers an active config flagged ``is_default``. Otherwise falls back to
        the most recently created active config. Each query takes the first row
        instead of requiring exactly one, so users with several active configs
        do not trigger ``MultipleResultsFound``.
        """
        active = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_active == True,  # noqa: E712
        )
        flagged = await self.session.execute(
            active.where(DatabaseConfig.is_default == True).limit(1)  # noqa: E712
        )
        config = flagged.scalars().first()
        if config is not None:
            return config
        newest = await self.session.execute(
            active.order_by(DatabaseConfig.created_at.desc()).limit(1)
        )
        return newest.scalars().first()

    async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
        """Return one active config of ``db_type`` owned by ``user_id``, or None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.db_type == db_type,
            DatabaseConfig.is_active == True  # noqa: E712
        )
        return await self.session.scalar(stmt)

    async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Create a config, or update the user's existing active config of the same db_type.

        On update, every key of ``config_data`` that is an attribute of the row
        is copied onto it, except identity/ownership/audit columns
        (``_PROTECTED_FIELDS``). ``password`` is stored encrypted. Setting
        ``is_default`` clears the flag on the user's other configs.

        Raises:
            Exception: Any error (including ValidationError from create_config),
                re-raised after rolling back the session.
        """
        try:
            existing_config = await self.get_config_by_type(user_id, config_data['db_type'])
            if not existing_config:
                return await self.create_config(user_id, config_data)

            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id, keep_id=existing_config.id)

            for key, value in config_data.items():
                if key in self._PROTECTED_FIELDS:
                    continue
                if key == 'password':
                    setattr(existing_config, key, self._encrypt_password(value))
                elif hasattr(existing_config, key):
                    setattr(existing_config, key, value)
            await self.session.commit()
            await self.session.refresh(existing_config)
            logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
            return existing_config
        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建或更新数据库配置失败: {str(e)}")
            raise

    # ------------------------------------------------------------------ #
    # Connections
    # ------------------------------------------------------------------ #
    async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Test a stored configuration using the tool that matches its db_type.

        Returns:
            ``{'success', 'message', 'details'}``.

        Raises:
            NotFoundError: The config does not exist or is not owned by ``user_id``.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
                'details': None
            }

        result = await db_tool.execute(
            operation="test_connection",
            connection_config=self._connection_config(config)
        )
        return {
            'success': result.success,
            'message': result.result.get('message') if result.success else result.error,
            'details': result.result if result.success else None
        }

    async def connect_and_get_tables(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Open a connection for ``user_id`` with a stored config and return the tool's result.

        The connection is kept by the tool itself. Later calls look it up via
        ``db_tool.connections`` keyed by ``str(user_id)`` (see get_table_data).

        Raises:
            NotFoundError: The config does not exist or is not owned by ``user_id``.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}'
            }

        connect_result = await db_tool.execute(
            operation="connect",
            connection_config=self._connection_config(config),
            user_id=str(user_id)
        )
        if not connect_result.success:
            return {
                'success': False,
                'message': connect_result.error
            }
        return {
            'success': True,
            'data': connect_result.result,
            'config_name': config.name
        }

    async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
        """Preview up to ``limit`` rows of ``table_name`` over the user's existing connection.

        ``table_name`` must be a plain (optionally dotted) identifier, because
        it is spliced into the SQL text. Never raises: failures are returned as
        ``{'success': False, 'message': ...}``.
        """
        try:
            user_id_str = str(user_id)
            db_tool = self._get_tool(db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {db_type}'
                }
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }
            # Reject anything that is not a bare identifier to prevent SQL injection.
            if not isinstance(table_name, str) or not self._TABLE_NAME_RE.fullmatch(table_name):
                return {
                    'success': False,
                    'message': f'非法的表名: {table_name}'
                }

            sql_query = f"SELECT * FROM {table_name}"
            result = await db_tool.execute(
                operation="execute_query",
                user_id=user_id_str,
                sql_query=sql_query,
                limit=limit
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }
            return {
                'success': True,
                'data': result.result,
                'db_type': db_type
            }
        except Exception as e:
            # logger.exception records the traceback. Passing exc_info=True to
            # loguru would instead str.format() the message, which breaks on braces.
            logger.exception(f"获取表数据失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表数据失败: {str(e)}'
            }

    def disconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Disconnect the user's PostgreSQL and/or MySQL tool connections.

        Kept synchronous for existing callers. Each tool that holds a connection
        for ``str(user_id)`` receives a ``disconnect`` call, driven by
        _run_tool_call. When an event loop is already running, the disconnect
        is scheduled in the background and may complete after this method
        returns.
        """
        user_id_str = str(user_id)
        try:
            for db_tool in (self.postgresql_tool, self.mysql_tool):
                if user_id_str not in db_tool.connections:
                    continue
                self._run_tool_call(
                    db_tool.execute(
                        operation="disconnect",
                        user_id=user_id_str
                    )
                )
            return {
                'success': True,
                'message': '数据库连接已断开'
            }
        except Exception as e:
            logger.exception(f"断开连接失败: {str(e)}")
            return {
                'success': False,
                'message': f'断开连接失败: {str(e)}'
            }

    async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Describe ``table_name`` over the user's existing connection.

        The tool is chosen from the db_type of the user's default config (see
        get_default_config). Never raises: failures are returned as
        ``{'success': False, 'message': ...}``.
        """
        try:
            # Placeholder implementation marker kept from the original author.
            logger.warning(f"未实现的逻辑,暂自编 - describe_table: {table_name}")
            user_id_str = str(user_id)

            default_config = await self.get_default_config(user_id)
            if not default_config:
                return {
                    'success': False,
                    'message': '未找到默认数据库配置'
                }

            db_tool = self._get_tool(default_config.db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {default_config.db_type}'
                }
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            result = await db_tool.execute(
                operation="describe_table",
                user_id=user_id_str,
                table_name=table_name
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }
            return {
                'success': True,
                'data': result.result,
                'db_type': default_config.db_type
            }
        except Exception as e:
            logger.exception(f"获取表结构失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表结构失败: {str(e)}'
            }