hyf-backend/th_agenter/services/mcp/postgresql_mcp.py

385 lines
14 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""PostgreSQL MCP (Model Context Protocol) tool for database operations."""
import json
import logging
import re
from datetime import date, datetime, time
from typing import List, Dict, Any, Optional

import psycopg2
from psycopg2 import sql

from th_agenter.services.agent.base import BaseTool, ToolParameter, ToolParameterType, ToolResult
# Module-level logger used by PostgreSQLMCPTool.execute.
logger = logging.getLogger(__name__)


class PostgreSQLMCPTool(BaseTool):
    """PostgreSQL MCP tool for database operations and intelligent querying.

    The tool keeps one psycopg2 connection per ``user_id`` in
    ``self.connections`` and dispatches on the ``operation`` argument of
    :meth:`execute` (``connect``, ``list_tables``, ``describe_table``,
    ``execute_query``, ``test_connection``, ``disconnect``).

    Connections are opened in autocommit mode so that a failed statement does
    not leave the cached connection stuck in an aborted transaction and so
    that INSERT/UPDATE/DELETE statements actually take effect.
    """

    # Error returned whenever an operation needs a cached connection that
    # does not exist (or is no longer usable).
    _NOT_CONNECTED_ERROR = "用户未连接数据库请先执行connect操作"

    # Statement starts with SELECT (ignoring leading whitespace).
    _SELECT_PREFIX = re.compile(r"^\s*SELECT\b", re.IGNORECASE)
    # LIMIT used as a standalone keyword (not as part of e.g. ``rate_limit``).
    _LIMIT_KEYWORD = re.compile(r"\bLIMIT\b", re.IGNORECASE)

    def __init__(self):
        super().__init__()
        # Per-user connection cache:
        # user_id -> {"connection", "config", "connected_at"}
        self.connections = {}

    def get_name(self) -> str:
        """Return the identifier under which this tool is registered."""
        return "postgresql_mcp"

    def get_description(self) -> str:
        """Return the human-readable description of the tool."""
        return "PostgreSQL MCP服务工具提供数据库连接、表结构查询、SQL执行等功能支持智能数据问答。"

    def get_parameters(self) -> List[ToolParameter]:
        """Return the parameter schema accepted by :meth:`execute`."""
        return [
            ToolParameter(
                name="operation",
                type=ToolParameterType.STRING,
                description="操作类型",
                required=True,
                enum=[
                    "connect",
                    "list_tables",
                    "describe_table",
                    "execute_query",
                    "test_connection",
                    "disconnect",
                ],
            ),
            ToolParameter(
                name="connection_config",
                type=ToolParameterType.OBJECT,
                description="数据库连接配置 {host, port, database, username, password}",
                required=False,
            ),
            ToolParameter(
                name="user_id",
                type=ToolParameterType.STRING,
                description="用户ID用于管理连接",
                required=False,
            ),
            ToolParameter(
                name="table_name",
                type=ToolParameterType.STRING,
                description="表名用于describe_table操作",
                required=False,
            ),
            ToolParameter(
                name="sql_query",
                type=ToolParameterType.STRING,
                description="SQL查询语句用于execute_query操作",
                required=False,
            ),
            ToolParameter(
                name="limit",
                type=ToolParameterType.INTEGER,
                description="查询结果限制数量默认100",
                required=False,
                default=100,
            ),
        ]

    def _create_connection(self, config: Dict[str, Any]) -> psycopg2.extensions.connection:
        """Open a PostgreSQL connection described by ``config``.

        Args:
            config: Mapping with ``host``, ``username``, ``password``,
                ``database`` and optionally ``port`` (defaults to 5432).

        Returns:
            An open psycopg2 connection in autocommit mode.

        Raises:
            Exception: If a required key is missing or the connection attempt
                fails; the original error is chained as ``__cause__``.
        """
        try:
            connection = psycopg2.connect(
                host=config['host'],
                # ``or`` also covers an explicit ``"port": None`` in the config.
                port=int(config.get('port') or 5432),
                user=config['username'],
                password=config['password'],
                database=config['database'],
                connect_timeout=10,
            )
            # Without autocommit psycopg2 opens an implicit transaction on the
            # first statement: writes would never be committed and a single
            # failed statement would poison the cached connection.
            connection.autocommit = True
            return connection
        except Exception as e:
            raise Exception(f"PostgreSQL连接失败: {str(e)}") from e

    def _test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Open a throw-away connection and report server information.

        Args:
            config: Connection settings, see :meth:`_create_connection`.

        Returns:
            On success a dict with ``success=True``, the server ``version``
            string and ``has_pgvector`` (whether the ``vector`` extension is
            installed). On failure a dict with ``success=False`` and the
            ``error`` text. This method never raises.
        """
        try:
            conn = self._create_connection(config)
            try:
                with conn.cursor() as cursor:
                    # Server version string.
                    cursor.execute("SELECT version();")
                    version = cursor.fetchone()[0]
                    # Is the pgvector extension installed?
                    cursor.execute("SELECT * FROM pg_extension WHERE extname = 'vector';")
                    has_vector = bool(cursor.fetchall())
            finally:
                # Always release the probe connection, even if a query failed.
                conn.close()
            return {
                "success": True,
                "version": version,
                "has_pgvector": has_vector,
                "message": "连接测试成功",
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "message": "连接测试失败",
            }

    def _get_tables(self, connection) -> List[Dict[str, Any]]:
        """List the tables and views of the ``public`` schema.

        Args:
            connection: Open psycopg2 connection.

        Returns:
            One dict per relation with ``table_name``, ``table_type`` and
            ``table_schema``, ordered by table name.
        """
        with connection.cursor() as cursor:
            cursor.execute("""
                SELECT
                    table_name,
                    table_type,
                    table_schema
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            return [
                {
                    "table_name": name,
                    "table_type": kind,
                    "table_schema": schema,
                }
                for name, kind, schema in cursor.fetchall()
            ]

    def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
        """Describe the columns, primary key and row count of a public table.

        Args:
            connection: Open psycopg2 connection.
            table_name: Exact (case-sensitive) name of a table in ``public``.

        Returns:
            Dict with ``table_name``, ``columns`` (list of column metadata
            dicts), ``primary_keys`` (list of column names) and ``row_count``.

        Raises:
            psycopg2.Error: If a query fails, e.g. the table does not exist.
        """
        with connection.cursor() as cursor:
            # Column metadata.
            cursor.execute("""
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default,
                    character_maximum_length,
                    numeric_precision,
                    numeric_scale
                FROM information_schema.columns
                WHERE table_name = %s AND table_schema = 'public'
                ORDER BY ordinal_position;
            """, (table_name,))
            column_keys = (
                "column_name",
                "data_type",
                "is_nullable",
                "column_default",
                "character_maximum_length",
                "numeric_precision",
                "numeric_scale",
            )
            columns = [dict(zip(column_keys, row)) for row in cursor.fetchall()]

            # Primary-key columns. The subquery is restricted to ``public`` as
            # well so a same-named table in another schema cannot leak in.
            cursor.execute("""
                SELECT column_name
                FROM information_schema.key_column_usage
                WHERE table_name = %s AND table_schema = 'public'
                AND constraint_name IN (
                    SELECT constraint_name
                    FROM information_schema.table_constraints
                    WHERE table_name = %s AND table_schema = 'public'
                    AND constraint_type = 'PRIMARY KEY'
                );
            """, (table_name, table_name))
            primary_keys = [row[0] for row in cursor.fetchall()]

            # Row count. The table name is composed as a quoted identifier
            # instead of being pasted into the SQL text, which prevents SQL
            # injection and handles mixed-case / reserved names.
            count_query = sql.SQL("SELECT COUNT(*) FROM {}.{};").format(
                sql.Identifier("public"),
                sql.Identifier(table_name),
            )
            cursor.execute(count_query)
            row_count = cursor.fetchone()[0]

            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": primary_keys,
                "row_count": row_count,
            }

    def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
        """Execute an SQL statement and return its result as plain data.

        Row limiting:
            * A statement that starts with ``SELECT`` and contains no ``LIMIT``
              keyword gets ``LIMIT <limit>`` appended.
            * Any other row-returning statement without a ``LIMIT`` keyword
              (``WITH ...``, ``INSERT ... RETURNING``, ``SHOW`` ...) is left
              untouched and capped while fetching instead, because appending
              ``LIMIT`` to it could be a syntax error.
            * A statement that already contains ``LIMIT`` is run as written.
            * A falsy or non-positive ``limit`` disables limiting.

        Args:
            connection: Open psycopg2 connection.
            sql_query: SQL text to run.
            limit: Maximum number of rows to return.

        Returns:
            For row-returning statements: ``success``, ``data`` (list of
            column->value dicts), ``columns``, ``row_count`` and the executed
            ``query``. Otherwise: ``success``, ``affected_rows``, ``query``
            and ``message``.

        Raises:
            ValueError: If ``limit`` cannot be converted to an integer.
            psycopg2.Error: If the statement fails.
        """
        # int() guarantees that only a number is ever formatted into the SQL.
        row_cap = int(limit) if limit else 0
        cap_on_fetch = False
        if row_cap > 0 and not self._LIMIT_KEYWORD.search(sql_query):
            if self._SELECT_PREFIX.match(sql_query):
                # Drop trailing semicolons *and* whitespace before appending.
                trimmed = sql_query.rstrip("; \t\r\n")
                sql_query = f"{trimmed} LIMIT {row_cap};"
            else:
                cap_on_fetch = True

        with connection.cursor() as cursor:
            cursor.execute(sql_query)

            if cursor.description:  # statement returned a result set
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchmany(row_cap) if cap_on_fetch else cursor.fetchall()
                data = [
                    {
                        # date/time values (datetime is a date subclass) are
                        # returned as ISO-8601 strings.
                        column: value.isoformat() if isinstance(value, (date, time)) else value
                        for column, value in zip(columns, row)
                    }
                    for row in rows
                ]
                return {
                    "success": True,
                    "data": data,
                    "columns": columns,
                    "row_count": len(data),
                    "query": sql_query,
                }

            # INSERT / UPDATE / DELETE and other statements without a result set.
            affected_rows = cursor.rowcount
            return {
                "success": True,
                "affected_rows": affected_rows,
                "query": sql_query,
                "message": f"查询执行成功，影响 {affected_rows}",
            }

    def _get_user_connection(self, user_id: Optional[str]):
        """Return the usable cached connection of ``user_id`` or ``None``.

        A cached connection that has been closed or broken in the meantime is
        removed from the cache and treated as "not connected".
        """
        if not user_id:
            return None
        entry = self.connections.get(user_id)
        if entry is None:
            return None
        connection = entry["connection"]
        if connection.closed:
            del self.connections[user_id]
            return None
        return connection

    def _close_user_connection(self, user_id: str) -> bool:
        """Remove and close the cached connection of ``user_id``.

        The cache entry is removed before closing so it is gone even if
        ``close()`` raises.

        Returns:
            ``True`` if a connection was cached for the user, else ``False``.
        """
        entry = self.connections.pop(user_id, None)
        if entry is None:
            return False
        entry["connection"].close()
        return True

    async def execute(self, operation: str, connection_config: Optional[Dict[str, Any]] = None,
                      user_id: Optional[str] = None, table_name: Optional[str] = None,
                      sql_query: Optional[str] = None, limit: int = 100) -> ToolResult:
        """Run one PostgreSQL MCP operation.

        NOTE: the psycopg2 calls made here are synchronous, so they block the
        event loop for their duration.

        Args:
            operation: One of ``test_connection``, ``connect``,
                ``list_tables``, ``describe_table``, ``execute_query``,
                ``disconnect``.
            connection_config: Connection settings; required for
                ``test_connection`` and ``connect``.
            user_id: Key of the cached connection; required for every
                operation except ``test_connection``.
            table_name: Table to describe (``describe_table``).
            sql_query: SQL text to run (``execute_query``).
            limit: Row limit for ``execute_query``.

        Returns:
            A ``ToolResult``; failures are reported through ``success=False``
            and ``error`` rather than raised.
        """
        try:
            logger.info(f"执行PostgreSQL MCP操作: {operation}")

            if operation == "test_connection":
                if not connection_config:
                    return ToolResult(success=False, error="缺少连接配置参数")
                result = self._test_connection(connection_config)
                return ToolResult(
                    success=result["success"],
                    result=result,
                    error=result.get("error"),
                )

            if operation == "connect":
                if not connection_config or not user_id:
                    return ToolResult(success=False, error="缺少连接配置或用户ID参数")
                try:
                    connection = self._create_connection(connection_config)
                    try:
                        tables = self._get_tables(connection)
                    except Exception:
                        # Do not keep a connection we report as failed.
                        connection.close()
                        raise
                    # Release any previous connection of this user so it is
                    # not leaked when the cache entry is replaced.
                    self._close_user_connection(user_id)
                except Exception as e:
                    return ToolResult(success=False, error=f"连接失败: {str(e)}")
                self.connections[user_id] = {
                    "connection": connection,
                    "config": connection_config,
                    "connected_at": datetime.now().isoformat(),
                }
                return ToolResult(
                    success=True,
                    result={
                        "message": "数据库连接成功",
                        "database": connection_config["database"],
                        "tables": tables,
                        "table_count": len(tables),
                    },
                )

            if operation == "list_tables":
                connection = self._get_user_connection(user_id)
                if connection is None:
                    return ToolResult(success=False, error=self._NOT_CONNECTED_ERROR)
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={"tables": tables, "table_count": len(tables)},
                )

            if operation == "describe_table":
                connection = self._get_user_connection(user_id)
                if connection is None:
                    return ToolResult(success=False, error=self._NOT_CONNECTED_ERROR)
                if not table_name:
                    return ToolResult(success=False, error="缺少table_name参数")
                table_info = self._describe_table(connection, table_name)
                return ToolResult(success=True, result=table_info)

            if operation == "execute_query":
                connection = self._get_user_connection(user_id)
                if connection is None:
                    return ToolResult(success=False, error=self._NOT_CONNECTED_ERROR)
                if not sql_query:
                    return ToolResult(success=False, error="缺少sql_query参数")
                query_result = self._execute_query(connection, sql_query, limit)
                return ToolResult(success=True, result=query_result)

            if operation == "disconnect":
                try:
                    was_connected = bool(user_id) and self._close_user_connection(user_id)
                except Exception as e:
                    return ToolResult(success=False, error=f"断开连接失败: {str(e)}")
                if was_connected:
                    return ToolResult(success=True, result={"message": "数据库连接已断开"})
                return ToolResult(success=True, result={"message": "用户未连接数据库"})

            return ToolResult(success=False, error=f"不支持的操作类型: {operation}")
        except Exception as e:
            logger.error(f"PostgreSQL MCP工具执行失败: {str(e)}", exc_info=True)
            return ToolResult(success=False, error=f"工具执行失败: {str(e)}")