# hyf-backend/th_agenter/services/workflow_engine.py
"""Workflow execution engine."""
import asyncio
import json
import time
from datetime import datetime
from typing import Dict, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from ..models.workflow import Workflow, WorkflowExecution, NodeExecution, ExecutionStatus, NodeType
from ..models.llm_config import LLMConfig
from ..services.llm_service import LLMService
from ..db.database import get_session, AsyncSessionFactory
from loguru import logger
class WorkflowEngine:
    """Workflow execution engine.

    Runs a workflow's node-graph definition (start/end/llm/condition/code/
    http/knowledge-base nodes), persisting execution and per-node records
    through an async SQLAlchemy session.
    """

    def __init__(self, session: AsyncSession):
        # Session used for all execution bookkeeping; callers may override it
        # per-invocation via the `session` parameter of execute_workflow*.
        self.session = session
        self.llm_service = LLMService()
async def execute_workflow(self, workflow: Workflow, input_data: Optional[Dict[str, Any]] = None,
user_id: int = None, session: AsyncSession = None):
"""执行工作流"""
from ..schemas.workflow import WorkflowExecutionResponse
id = workflow.id
if session:
self.session = session
session.desc = f"执行工作流数据 - {id} > Enter"
# 创建执行记录
execution = WorkflowExecution(
workflow_id=id,
status=ExecutionStatus.RUNNING,
input_data=input_data or {},
executor_id=user_id,
started_at=datetime.now().isoformat()
)
session.desc = f"执行工作流数据 - {id} > 创建执行记录"
execution.set_audit_fields(user_id)
self.session.add(execution)
await self.session.commit()
await self.session.refresh(execution)
session.desc = f"执行工作流数据 - {id} > 添加执行记录"
try:
# 重新加载 workflow 对象,确保数据是最新的
from sqlalchemy import select
from ..models.workflow import Workflow
result = await session.execute(
select(Workflow).where(Workflow.id == id)
)
workflow = result.scalar_one_or_none()
session.desc = f"执行工作流数据 - {id} > reload workflow"
# 解析工作流定义
definition = workflow.definition
nodes = {node['id']: node for node in definition['nodes']}
connections = definition['connections']
session.desc = f"执行工作流数据 - {id} > definition {id}"
# 构建节点依赖图
node_graph = self._build_node_graph(nodes, connections)
session.desc = f"执行工作流数据 - {id} > _build_node_graph {id}"
# 执行工作流
result = await self._execute_nodes(execution, nodes, node_graph, input_data or {})
session.desc = f"执行工作流数据 - {id} > _execute_nodes {id}"
# 更新执行状态
execution.status = ExecutionStatus.COMPLETED
execution.output_data = result
execution.completed_at = datetime.now().isoformat()
session.desc = f"执行工作流数据 - {id} > execution {id}"
except Exception as e:
logger.error(f"工作流执行失败 - {id}: {str(e)}")
execution.status = ExecutionStatus.FAILED
execution.error_message = str(e)
execution.completed_at = datetime.now().isoformat()
execution.set_audit_fields(user_id, is_update=True)
session.desc = f"执行工作流数据 - {id} > set_audit_fields {id}"
await self.session.commit()
await self.session.refresh(execution)
session.desc = f"执行工作流数据 - {id} > refresh {id}"
from sqlalchemy import select
from ..models.workflow import NodeExecution
result = await session.execute(
select(NodeExecution).where(NodeExecution.workflow_execution_id == execution.id)
)
node_executions = result.scalars().all()
session.desc = f"执行工作流数据 - {id} > load node_executions {id}"
node_executions = [node.to_dict() for node in node_executions]
execution_dict = execution.to_dict()
execution_dict['node_executions'] = node_executions
session.desc = f"执行工作流数据 - {id} > build response {id}"
return WorkflowExecutionResponse(**execution_dict)
    async def execute_workflow_stream(self, workflow: 'Workflow', input_data: Optional[Dict[str, Any]] = None,
                                      user_id: int = None, session: AsyncSession = None):
        """Execute a workflow as an async generator, yielding live status events.

        Yields dicts keyed by 'type': 'workflow_status' (started / completed /
        failed) plus every per-node event forwarded from
        _execute_nodes_stream. The execution record is persisted in both the
        success and failure paths; failures are reported as events, not raised.
        """
        from ..schemas.workflow import WorkflowExecutionResponse
        from typing import AsyncGenerator

        # Allow the caller to supply the session used for bookkeeping.
        if session:
            self.session = session

        # Create and persist the execution record up front so it has an id.
        execution = WorkflowExecution(
            workflow_id=workflow.id,
            status=ExecutionStatus.RUNNING,
            input_data=input_data or {},
            executor_id=user_id,
            started_at=datetime.now().isoformat()
        )
        execution.set_audit_fields(user_id)
        self.session.add(execution)
        await self.session.commit()
        await self.session.refresh(execution)

        # Announce that the workflow has started.
        yield {
            'type': 'workflow_status',
            'execution_id': execution.id,
            'status': 'started',
            'data': {
                "workflow_id": workflow.id,
                "workflow_name": workflow.name,
                "input_data": input_data or {},
                "started_at": execution.started_at
            },
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Parse the workflow definition into a node map + connection list.
            definition = workflow.definition
            nodes = {node['id']: node for node in definition['nodes']}
            connections = definition['connections']

            # Build the node dependency graph.
            node_graph = self._build_node_graph(nodes, connections)

            # Execute the graph, forwarding every streamed event to the caller.
            result = None
            async for step_data in self._execute_nodes_stream(execution, nodes, node_graph, input_data or {}):
                yield step_data
                # The final 'workflow_result' event carries the workflow output.
                if step_data.get('type') == 'workflow_result':
                    result = step_data.get('data', {})

            # Mark the execution as completed.
            execution.status = ExecutionStatus.COMPLETED
            execution.output_data = result
            execution.completed_at = datetime.now().isoformat()

            # Emit the completion event.
            yield {
                'type': 'workflow_status',
                'execution_id': execution.id,
                'status': 'completed',
                'data': {
                    "output_data": result,
                    "completed_at": execution.completed_at
                },
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            logger.error(f"工作流执行失败: {str(e)}")
            execution.status = ExecutionStatus.FAILED
            execution.error_message = str(e)
            execution.completed_at = datetime.now().isoformat()

            # Surface the failure to the stream consumer (not re-raised).
            yield {
                'type': 'workflow_status',
                'execution_id': execution.id,
                'status': 'failed',
                'data': {
                    "error_message": str(e),
                    "completed_at": execution.completed_at
                },
                'timestamp': datetime.now().isoformat()
            }

        # Persist the final state for both outcomes.
        execution.set_audit_fields(user_id, is_update=True)
        await self.session.commit()
        await self.session.refresh(execution)
def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""构建节点依赖图"""
graph = {}
for node_id, node in nodes.items():
graph[node_id] = {
'node': node,
'inputs': [], # 输入节点
'outputs': [] # 输出节点
}
for connection in connections:
# 支持两种字段名格式from/to 和 from_node/to_node
from_node = connection.get('from') or connection.get('from_node')
to_node = connection.get('to') or connection.get('to_node')
if from_node in graph and to_node in graph:
graph[from_node]['outputs'].append(to_node)
graph[to_node]['inputs'].append(from_node)
return graph
async def _execute_nodes(self, execution: WorkflowExecution, nodes: Dict[str, Any],
node_graph: Dict[str, Dict[str, Any]], workflow_input: Dict[str, Any]) -> Dict[str, Any]:
"""执行节点"""
# 找到开始节点
start_nodes = [node_id for node_id, info in node_graph.items()
if info['node']['type'] == 'start']
if not start_nodes:
raise ValueError("未找到开始节点")
if len(start_nodes) > 1:
raise ValueError("存在多个开始节点")
start_node_id = start_nodes[0]
# 执行上下文
context = {
'workflow_input': workflow_input,
'node_outputs': {}
}
# 从开始节点开始执行
await self._execute_node_recursive(execution, start_node_id, node_graph, context)
# 找到结束节点的输出作为工作流结果
end_nodes = [node_id for node_id, info in node_graph.items()
if info['node']['type'] == 'end']
if end_nodes:
end_node_id = end_nodes[0]
return context['node_outputs'].get(end_node_id, {})
return {}
async def _execute_nodes_stream(self, execution: WorkflowExecution, nodes: Dict[str, Any],
node_graph: Dict[str, Dict[str, Any]], workflow_input: Dict[str, Any]):
"""流式执行节点,实时推送节点状态"""
# 找到开始节点
start_nodes = [node_id for node_id, info in node_graph.items()
if info['node']['type'] == 'start']
if not start_nodes:
raise ValueError("未找到开始节点")
if len(start_nodes) > 1:
raise ValueError("存在多个开始节点")
start_node_id = start_nodes[0]
# 执行上下文
context = {
'workflow_input': workflow_input,
'node_outputs': {}
}
# 从开始节点开始执行
async for step_data in self._execute_node_recursive_stream(execution, start_node_id, node_graph, context):
yield step_data
# 找到结束节点的输出作为工作流结果
end_nodes = [node_id for node_id, info in node_graph.items()
if info['node']['type'] == 'end']
if end_nodes:
end_node_id = end_nodes[0]
result = context['node_outputs'].get(end_node_id, {})
else:
result = {}
# 发送最终结果
yield {
'type': 'workflow_result',
'execution_id': execution.id,
'data': result,
'timestamp': datetime.now().isoformat()
}
    async def _execute_node_recursive_stream(self, execution: WorkflowExecution, node_id: str,
                                             node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
        """Recursively execute one node (streaming variant), yielding status events.

        Emits 'node_status' events with status started / completed / failed.
        Outputs are memoised in context['node_outputs'] so a node reachable via
        several paths runs only once; on failure the 'failed' event is yielded
        first, then the exception is re-raised to abort the workflow.
        """
        if node_id in context['node_outputs']:
            # Already executed via another path — nothing to do.
            return

        node_info = node_graph[node_id]
        node = node_info['node']

        # Ensure every upstream dependency has completed before this node runs.
        for input_node_id in node_info['inputs']:
            async for step_data in self._execute_node_recursive_stream(execution, input_node_id, node_graph, context):
                yield step_data

        # Announce that this node is starting.
        yield {
            'type': 'node_status',
            'execution_id': execution.id,
            'node_id': node_id,
            'status': 'started',
            'data': {
                'node_name': node.get('name', ''),
                'node_type': node.get('type', ''),
                'started_at': datetime.now().isoformat()
            },
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Run the node and memoise its output for downstream consumers.
            output = await self._execute_single_node(execution, node, context)
            context['node_outputs'][node_id] = output

            # Report successful completion, including the node's output.
            yield {
                'type': 'node_status',
                'execution_id': execution.id,
                'node_id': node_id,
                'status': 'completed',
                'data': {
                    'node_name': node.get('name', ''),
                    'node_type': node.get('type', ''),
                    'output': output,
                    'completed_at': datetime.now().isoformat()
                },
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            # Report the failure to the stream, then propagate to abort execution.
            yield {
                'type': 'node_status',
                'execution_id': execution.id,
                'node_id': node_id,
                'status': 'failed',
                'data': {
                    'node_name': node.get('name', ''),
                    'node_type': node.get('type', ''),
                    'error_message': str(e),
                    'failed_at': datetime.now().isoformat()
                },
                'timestamp': datetime.now().isoformat()
            }
            raise

        # Continue into every downstream node.
        for output_node_id in node_info['outputs']:
            async for step_data in self._execute_node_recursive_stream(execution, output_node_id, node_graph, context):
                yield step_data
async def _execute_node_recursive(self, execution: WorkflowExecution, node_id: str,
node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
"""递归执行节点"""
if node_id in context['node_outputs']:
# 节点已执行过
return
node_info = node_graph[node_id]
node = node_info['node']
# 等待所有输入节点完成
for input_node_id in node_info['inputs']:
await self._execute_node_recursive(execution, input_node_id, node_graph, context)
# 执行当前节点
output = await self._execute_single_node(execution, node, context)
context['node_outputs'][node_id] = output
# 执行所有输出节点
for output_node_id in node_info['outputs']:
await self._execute_node_recursive(execution, output_node_id, node_graph, context)
    async def _execute_single_node(self, execution: WorkflowExecution, node: Dict[str, Any],
                                   context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one node: persist a NodeExecution record, dispatch by node
        type, and record the outcome (status, output/error, duration).

        Returns the node's output dict; re-raises any node error after
        persisting the failure so the caller can abort the workflow.
        """
        node_id = node['id']
        node_type = node['type']
        node_name = node['name']

        # Create the per-node execution record before running the node.
        node_execution = NodeExecution(
            workflow_execution_id=execution.id,
            node_id=node_id,
            node_type=NodeType(node_type),
            node_name=node_name,
            status=ExecutionStatus.RUNNING,
            started_at=datetime.now().isoformat()
        )
        self.session.add(node_execution)
        await self.session.commit()
        await self.session.refresh(node_execution)
        # Refresh the parent execution too — the commit above expires it.
        await self.session.refresh(execution)

        start_time = time.time()
        try:
            # Assemble the node's input payload from config and context.
            input_data = self._prepare_node_input(node, context)

            # Shape a frontend-friendly view of the input for persistence.
            display_input_data = input_data.copy()
            if node_type == 'start':
                # For the start node the displayed input is the raw workflow input.
                display_input_data = input_data['workflow_input']
            elif node_type == 'llm':
                # For LLM nodes pre-render the prompt so the UI can show
                # both the template and the substituted form.
                config = input_data['node_config']
                prompt_template = config.get('prompt', '')
                enable_variable_substitution = config.get('enable_variable_substitution', True)
                if enable_variable_substitution:
                    processed_prompt = self._substitute_variables(prompt_template, input_data)
                else:
                    processed_prompt = prompt_template
                display_input_data = {
                    'original_prompt': prompt_template,
                    'processed_prompt': processed_prompt,
                    'model_config': config,
                    'resolved_inputs': input_data.get('resolved_inputs', {})
                }
            node_execution.input_data = display_input_data
            await self.session.commit()
            await self.session.refresh(execution)

            # Dispatch on node type.
            if node_type == 'start':
                output_data = await self._execute_start_node(node, input_data)
            elif node_type == 'end':
                output_data = await self._execute_end_node(node, input_data)
            elif node_type == 'llm':
                output_data = await self._execute_llm_node(node, input_data)
            elif node_type == 'condition':
                output_data = await self._execute_condition_node(node, input_data)
            elif node_type == 'code':
                output_data = await self._execute_code_node(node, input_data)
            elif node_type == 'http':
                output_data = await self._execute_http_node(node, input_data)
            elif node_type == 'knowledge-base':
                output_data = await self._execute_knowledge_base_node(node, input_data)
            else:
                raise ValueError(f"不支持的节点类型: {node_type}")

            # Record success, output and wall-clock duration.
            end_time = time.time()
            node_execution.status = ExecutionStatus.COMPLETED
            node_execution.output_data = output_data
            node_execution.completed_at = datetime.now().isoformat()
            node_execution.duration_ms = int((end_time - start_time) * 1000)
            await self.session.commit()
            await self.session.refresh(execution)
            return output_data
        except Exception as e:
            logger.error(f"节点 {node_id} 执行失败: {str(e)}")
            # Record the failure and timing, then propagate to the caller.
            end_time = time.time()
            node_execution.status = ExecutionStatus.FAILED
            node_execution.error_message = str(e)
            node_execution.completed_at = datetime.now().isoformat()
            node_execution.duration_ms = int((end_time - start_time) * 1000)
            await self.session.commit()
            await self.session.refresh(execution)
            raise
def _prepare_node_input(self, node: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""准备节点输入数据"""
# 基础输入数据
input_data = {
'workflow_input': context['workflow_input'],
'node_config': node.get('config', {}),
'previous_outputs': context['node_outputs']
}
# 处理节点参数配置
node_parameters = node.get('parameters', {})
if node_parameters and 'inputs' in node_parameters:
resolved_inputs = {}
for param in node_parameters['inputs']:
param_name = param.get('name')
param_source = param.get('source', 'default')
param_default = param.get('default_value')
variable_name = param.get('variable_name', '')
# 优先使用variable_name如果存在的话
if variable_name:
resolved_value = self._resolve_variable_value(variable_name, context)
resolved_inputs[param_name] = resolved_value if resolved_value is not None else param_default
elif param_source == 'workflow':
# 从工作流输入获取
source_param_name = param.get('source_param_name', param_name)
resolved_inputs[param_name] = context['workflow_input'].get(source_param_name, param_default)
elif param_source == 'node':
# 从其他节点输出获取
source_node_id = param.get('source_node_id')
source_param_name = param.get('source_param_name', 'data')
if source_node_id and source_node_id in context['node_outputs']:
source_output = context['node_outputs'][source_node_id]
if isinstance(source_output, dict):
resolved_inputs[param_name] = source_output.get(source_param_name, param_default)
else:
resolved_inputs[param_name] = source_output
else:
resolved_inputs[param_name] = param_default
else:
# 使用默认值
resolved_inputs[param_name] = param_default
# 将解析后的参数添加到输入数据
input_data['resolved_inputs'] = resolved_inputs
return input_data
def _resolve_variable_value(self, variable_name: str, context: Dict[str, Any]) -> Any:
"""解析变量值,支持格式如 "node_id.output.field_name" 或更深层路径"""
try:
# 解析变量名格式node_id.output.field_name 或 node_id.field1.field2.field3
parts = variable_name.split('.')
if len(parts) >= 2:
source_node_id = parts[0]
# 从previous_outputs中获取源节点的输出
if source_node_id in context['node_outputs']:
source_output = context['node_outputs'][source_node_id]
if isinstance(source_output, dict):
# 从第二个部分开始遍历路径
current_value = source_output
for field_name in parts[1:]:
if isinstance(current_value, dict) and field_name in current_value:
current_value = current_value[field_name]
else:
# 如果路径不存在返回None
return None
return current_value
return None
except Exception as e:
logger.warning(f"解析变量值失败: {variable_name}, 错误: {str(e)}")
return None
async def _execute_start_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""执行开始节点"""
# 开始节点的输入和输出应该一致都是workflow_input
workflow_input = input_data['workflow_input']
return {
'success': True,
'message': '工作流开始',
'data': workflow_input,
'user_input': workflow_input # 添加用户输入显示
}
async def _execute_end_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""执行结束节点"""
previous_outputs = input_data.get('previous_outputs', {})
# 处理结束节点的输出参数配置
node_parameters = node.get('parameters') or {}
output_params = node_parameters.get('outputs', []) if isinstance(node_parameters, dict) else []
2026-01-21 13:45:39 +08:00
result_data = {}
# 根据输出参数配置获取对应的值
for param in output_params:
param_name = param.get('name')
variable_name = param.get('variable_name')
if variable_name:
# 解析variable_name格式如: "node_1759022611056.output.response"
try:
parts = variable_name.split('.')
if len(parts) >= 3:
source_node_id = parts[0]
output_type = parts[1] # 通常是"output"
field_name = parts[2] # 具体的字段名,如"response"
# 从前一个节点的输出中获取值
if source_node_id in previous_outputs:
source_output = previous_outputs[source_node_id]
if isinstance(source_output, dict):
# 首先尝试从根级别获取字段如LLM节点的response字段
if field_name in source_output:
result_data[param_name] = source_output[field_name]
# 如果根级别没有再尝试从data字段中获取
elif 'data' in source_output and isinstance(source_output['data'], dict):
if field_name in source_output['data']:
result_data[param_name] = source_output['data'][field_name]
else:
result_data[param_name] = None
else:
result_data[param_name] = None
else:
result_data[param_name] = source_output
else:
result_data[param_name] = None
else:
# 格式不正确,使用默认值
result_data[param_name] = param.get('default_value')
except Exception as e:
logger.warning(f"解析variable_name失败: {variable_name}, 错误: {str(e)}")
result_data[param_name] = param.get('default_value')
else:
# 没有variable_name使用默认值
result_data[param_name] = param.get('default_value')
# 如果没有配置输出参数,返回简化的前一个节点输出(保持向后兼容)
if not output_params:
simplified_outputs = {}
for node_id, output in previous_outputs.items():
if isinstance(output, dict):
simplified_outputs[node_id] = {
'success': output.get('success', False),
'message': output.get('message', ''),
'data': output.get('data', {}) if not isinstance(output.get('data'), dict) or node_id not in str(output.get('data', {})) else {}
}
else:
simplified_outputs[node_id] = output
result_data = simplified_outputs
return {
'success': True,
'message': '工作流结束',
'data': result_data
}
    async def _execute_llm_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
        """LLM node: resolve a model configuration, render the prompt, call the model.

        Model resolution order: runtime node_config (model_id, then
        model_name/model), the node definition's own config block, then the
        system default chat config. Raises ValueError when no configuration
        can be resolved or the model call fails.
        """
        config = input_data.get('node_config', {})

        # 1) Resolve the model from the runtime config.
        model_id = config.get('model_id')
        if not model_id:
            # Frontend compatibility: 'model_name'/'model' may hold an id or a name.
            model_value = config.get('model_name', config.get('model'))
            if model_value:
                if isinstance(model_value, int):
                    # An integer is treated as the config id directly.
                    model_id = model_value
                else:
                    # A string is looked up by model name.
                    from sqlalchemy import select
                    result = await self.session.execute(
                        select(LLMConfig).where(LLMConfig.model_name == model_value)
                    )
                    llm_config = result.scalar_one_or_none()
                    if llm_config:
                        model_id = llm_config.id

        # 2) Fall back to the node definition's own config block.
        if not model_id:
            node_config = node.get('config', {})
            model_id = node_config.get('model_id')
            if not model_id:
                model_value = node_config.get('model_name', node_config.get('model'))
                if model_value:
                    if isinstance(model_value, int):
                        model_id = model_value
                    else:
                        from sqlalchemy import select
                        result = await self.session.execute(
                            select(LLMConfig).where(LLMConfig.model_name == model_value)
                        )
                        llm_config = result.scalar_one_or_none()
                        if llm_config:
                            model_id = llm_config.id

        # 3) Last resort: the system-wide default chat configuration.
        if not model_id:
            from ..services.llm_config_service import LLMConfigService
            llm_config_service = LLMConfigService()
            default_config = await llm_config_service.get_default_chat_config(self.session)
            if default_config:
                model_id = default_config.id
                logger.info(f"LLM节点未指定模型配置使用默认模型: {default_config.model_name} (ID: {model_id})")
            else:
                raise ValueError(
                    "未指定有效的大模型配置,且未找到默认配置。\n"
                    "请在节点配置中添加模型ID或模型名称例如\n"
                    " - config.model_id: 1\n"
                    " - config.model_name: 'gpt-4'\n"
                    " - config.model: 'gpt-4'\n"
                    "或者设置一个默认的LLM配置。"
                )

        # Load the resolved configuration (id may have come from any source above).
        from sqlalchemy import select
        result = await self.session.execute(
            select(LLMConfig).where(LLMConfig.id == model_id)
        )
        llm_config = result.scalar_one_or_none()
        if not llm_config:
            raise ValueError(f"大模型配置 {model_id} 不存在")

        # Render the prompt; variable substitution can be disabled per node.
        prompt_template = config.get('prompt', '')
        enable_variable_substitution = config.get('enable_variable_substitution', True)
        if enable_variable_substitution:
            prompt = self._substitute_variables(prompt_template, input_data)
        else:
            prompt = prompt_template
        # Expose both prompt forms on input_data for frontend display.
        input_data['processed_prompt'] = prompt
        input_data['original_prompt'] = prompt_template

        # Call the LLM service.
        try:
            response = await self.llm_service.chat_completion(
                model_config=llm_config,
                messages=[{"role": "user", "content": prompt}],
                temperature=config.get('temperature', 0.7),
                max_tokens=config.get('max_tokens')
            )
            return {
                'success': True,
                'response': response,
                'prompt': prompt,
                'model': llm_config.model_name,
                # NOTE(review): if chat_completion returns a plain string this
                # usage lookup always yields 0 — confirm against LLMService.
                'tokens_used': getattr(response, 'usage', {}).get('total_tokens', 0) if hasattr(response, 'usage') else 0
            }
        except Exception as e:
            logger.error(f"LLM调用失败: {str(e)}")
            raise ValueError(f"LLM调用失败: {str(e)}")
def _substitute_variables(self, template: str, input_data: Dict[str, Any]) -> str:
"""变量替换函数"""
import re
# 获取解析后的输入参数
resolved_inputs = input_data.get('resolved_inputs', {})
# 获取工作流输入数据
# input_data['workflow_input'] 包含了用户输入的参数
workflow_input = input_data.get('workflow_input', {})
# 构建变量上下文
variable_context = {}
# 首先添加解析后的参数
variable_context.update(resolved_inputs)
# 添加工作流输入的顶层字段
variable_context.update(workflow_input)
# 如果 workflow_input 包含 user_input 字段,将其内容提升到顶层
if 'user_input' in workflow_input and isinstance(workflow_input['user_input'], dict):
variable_context.update(workflow_input['user_input'])
# 添加前一个节点的输出(简化访问)
for node_id, output in input_data.get('previous_outputs', {}).items():
if isinstance(output, dict):
# 添加节点输出的直接访问
variable_context[f'node_{node_id}'] = output.get('data', output)
# 如果输出有response字段也添加直接访问
if 'response' in output:
variable_context[f'node_{node_id}_response'] = output['response']
# 调试日志:打印变量上下文
logger.info(f"变量替换上下文: {variable_context}")
logger.info(f"原始模板: {template}")
# 使用正则表达式替换变量 {{variable_name}} 和 {variable_name}
def replace_variable(match):
var_name = match.group(1)
replacement = variable_context.get(var_name, match.group(0))
logger.info(f"替换变量 {match.group(0)} -> {replacement}")
return str(replacement)
# 首先替换 {{variable_name}} 格式的变量
result = re.sub(r'\{\{([^}]+)\}\}', replace_variable, template)
# 然后替换 {variable_name} 格式的变量
result = re.sub(r'\{([^}]+)\}', replace_variable, result)
logger.info(f"替换后结果: {result}")
return result
async def _execute_condition_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""执行条件节点"""
config = input_data['node_config']
condition = config.get('condition', '')
# 简单的条件评估(生产环境需要更安全的实现)
try:
# 构建评估上下文
eval_context = {
'input': input_data['workflow_input'],
'previous': input_data['previous_outputs']
}
# 评估条件
result = eval(condition, {"__builtins__": {}}, eval_context)
return {
'success': True,
'condition': condition,
'result': bool(result)
}
except Exception as e:
logger.error(f"条件评估失败: {str(e)}")
raise ValueError(f"条件评估失败: {str(e)}")
async def _execute_code_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""执行代码节点"""
config = input_data['node_config']
language = config.get('language', 'python')
code = config.get('code', '')
if language == 'python':
# 执行Python代码
execution_result = await self._execute_python_code(code, input_data)
# 处理输出参数配置
node_parameters = node.get('parameters', {})
if node_parameters and 'outputs' in node_parameters:
output_params = node_parameters['outputs']
code_result = execution_result.get('result', {})
# 根据输出参数配置构建最终输出
final_output = {
'success': execution_result['success'],
'code': execution_result['code'],
'input_parameters': execution_result.get('input_parameters', {})
}
# 如果代码返回的是字典,根据输出参数配置提取对应字段
if isinstance(code_result, dict):
for output_param in output_params:
param_name = output_param.get('name')
if param_name and param_name in code_result:
final_output[param_name] = code_result[param_name]
else:
# 如果代码返回的不是字典,且只有一个输出参数,直接使用返回值
if len(output_params) == 1:
param_name = output_params[0].get('name')
if param_name:
final_output[param_name] = code_result
return final_output
else:
# 如果没有输出参数配置,返回原始结果
return execution_result
else:
raise ValueError(f"不支持的代码语言: {language}")
async def _execute_python_code(self, code: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""执行Python代码"""
try:
# 构建执行上下文
safe_builtins = {
'len': len,
'str': str,
'int': int,
'float': float,
'bool': bool,
'list': list,
'dict': dict,
'tuple': tuple,
'set': set,
'range': range,
'enumerate': enumerate,
'zip': zip,
'sum': sum,
'min': min,
'max': max,
'abs': abs,
'round': round,
'sorted': sorted,
'reversed': reversed,
'print': print,
'__import__': __import__,
}
# 导入常用模块
import json
import datetime
import math
import re
exec_context = {
'__builtins__': safe_builtins,
'json': json, # 允许使用json模块
'datetime': datetime, # 允许使用datetime模块
'math': math, # 允许使用math模块
're': re, # 允许使用re模块
}
# 执行代码以定义函数
exec(code, exec_context)
# 检查是否定义了main函数
if 'main' not in exec_context:
raise ValueError("代码中必须定义一个main函数")
main_function = exec_context['main']
# 获取已解析的输入参数
resolved_inputs = input_data.get('resolved_inputs', {})
# 调用main函数并传递参数
if resolved_inputs:
# 使用解析后的输入参数调用main函数
result = main_function(**resolved_inputs)
else:
# 如果没有输入参数直接调用main函数
result = main_function()
return {
'success': True,
'result': result,
'code': code,
'input_parameters': resolved_inputs
}
except Exception as e:
logger.error(f"Python代码执行失败: {str(e)}")
raise ValueError(f"Python代码执行失败: {str(e)}")
    async def _execute_http_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
        """HTTP node: perform one request described by the node config.

        Config keys: method (default GET), url, headers, body. Returns the
        status code, response text and headers; raises ValueError wrapping
        any transport error.
        """
        import aiohttp
        config = input_data['node_config']
        method = config.get('method', 'GET').upper()
        url = config.get('url', '')
        headers = config.get('headers', {})
        body = config.get('body')
        try:
            # Fresh client session per request; closed by the context manager.
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=method,
                    url=url,
                    headers=headers,
                    data=body
                ) as response:
                    response_text = await response.text()
                    # Any HTTP status is treated as success; only transport
                    # errors reach the except clause below.
                    return {
                        'success': True,
                        'status_code': response.status,
                        'response': response_text,
                        'headers': dict(response.headers)
                    }
        except Exception as e:
            logger.error(f"HTTP请求失败: {str(e)}")
            raise ValueError(f"HTTP请求失败: {str(e)}")
    async def _execute_knowledge_base_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Knowledge-base node: similarity-search a knowledge base for a query.

        The knowledge base id and the query text are each resolved through a
        chain of fallbacks (runtime config, node definition, workflow input,
        upstream outputs). With no query text the node returns an empty
        successful result instead of failing.
        """
        from ..services.document import DocumentService
        config = input_data.get('node_config', {})

        # The kb id may appear under several key names: knowledge_base_id /
        # knowledgeBase / kb_id. Check the runtime config first...
        knowledge_base_id = config.get('knowledge_base_id') or config.get('knowledgeBase') or config.get('kb_id')
        # ...then the node definition itself...
        if not knowledge_base_id:
            knowledge_base_id = node.get('knowledge_base_id') or node.get('knowledgeBase') or node.get('kb_id')
        # ...then the node definition's nested config block.
        if not knowledge_base_id and 'config' in node:
            node_config = node.get('config', {})
            knowledge_base_id = node_config.get('knowledge_base_id') or node_config.get('knowledgeBase') or node_config.get('kb_id')

        query = config.get('query', '')
        top_k = config.get('top_k', config.get('topK', 5))
        similarity_threshold = config.get('similarity_threshold', config.get('similarityThreshold', 0.7))

        # Last resort: extract a number from the node name, e.g.
        # "knowledge-base 2" -> 2. Heuristic only — any digit in the name is
        # taken as a candidate id.
        if not knowledge_base_id:
            node_name = node.get('name', '')
            import re
            match = re.search(r'(\d+)', node_name)
            if match:
                try:
                    potential_id = int(match.group(1))
                    logger.info(f"从节点名称 '{node_name}' 中提取到潜在的知识库ID: {potential_id}")
                    knowledge_base_id = potential_id
                except:
                    pass
        if not knowledge_base_id:
            raise ValueError(
                f"知识库节点配置缺少 knowledge_base_id (或 knowledgeBase/kb_id)。\n"
                f"请在节点配置中添加知识库ID例如\n"
                f" - config.knowledge_base_id: 2\n"
                f" - config.knowledgeBase: 2\n"
                f" - config.kb_id: 2\n"
                f"当前节点配置: {config}"
            )

        # Resolve the query text: resolved inputs first.
        if not query:
            resolved_inputs = input_data.get('resolved_inputs', {})
            query = resolved_inputs.get('query', '')
        # Then the workflow input.
        if not query:
            workflow_input = input_data.get('workflow_input', {})
            # Prefer an explicit 'query' field.
            query = workflow_input.get('query', '')
            # Otherwise use the first non-empty string value found.
            if not query and isinstance(workflow_input, dict):
                for key, value in workflow_input.items():
                    if isinstance(value, str) and value.strip():
                        query = value.strip()
                        logger.info(f"从工作流输入的 '{key}' 字段获取查询文本: {query}")
                        break
        # Then any upstream node's output.
        if not query:
            previous_outputs = input_data.get('previous_outputs', {})
            for node_id, output in previous_outputs.items():
                if isinstance(output, dict):
                    # Root-level 'query' field first.
                    query = output.get('query', '')
                    if not query:
                        # Then inside the 'data' payload.
                        data = output.get('data', {})
                        if isinstance(data, dict):
                            query = data.get('query', '')
                            # Still nothing — first non-empty string in data.
                            if not query:
                                for key, value in data.items():
                                    if isinstance(value, str) and value.strip():
                                        query = value.strip()
                                        logger.info(f"从节点 {node_id} 输出的 data.{key} 字段获取查询文本: {query}")
                                        break
                if query:
                    break
        # With no query at all, return an empty successful result (no error).
        if not query:
            logger.warning(f"知识库节点缺少查询文本,将返回空结果")
            return {
                'success': True,
                'query': '',
                'knowledge_base_id': knowledge_base_id,
                'results': [],
                'total_results': 0,
                'top_k': top_k,
                'similarity_threshold': similarity_threshold,
                'warning': '缺少查询文本,返回空结果'
            }

        try:
            # Search via the shared document processor; it needs the session
            # to initialise the embedding model.
            from ..services.document_processor import get_document_processor
            document_processor = await get_document_processor(self.session)
            results = document_processor.search_similar_documents(
                knowledge_base_id=knowledge_base_id,
                query=query,
                k=top_k
            )
            logger.info(f"知识库 {knowledge_base_id} 搜索查询 '{query}' 返回 {len(results)} 个原始结果")

            # Apply the similarity threshold; keep a scored copy of all hits.
            filtered_results = []
            all_results = []
            for result in results:
                score = result.get('normalized_score', result.get('similarity_score', 0))
                all_results.append({
                    **result,
                    'score': score
                })
                if score >= similarity_threshold:
                    filtered_results.append(result)
            logger.info(f"应用相似度阈值 {similarity_threshold} 后,剩余 {len(filtered_results)} 个结果")

            # If the threshold filtered everything out, fall back to the raw
            # top-k hits with a warning rather than returning nothing.
            if not filtered_results and results:
                logger.warning(f"所有搜索结果都被相似度阈值 {similarity_threshold} 过滤,返回前 {min(len(results), top_k)} 个结果")
                filtered_results = results[:top_k]

            return {
                'success': True,
                'query': query,
                'knowledge_base_id': knowledge_base_id,
                'results': filtered_results,
                'total_results': len(filtered_results),
                'raw_results_count': len(results),
                'top_k': top_k,
                'similarity_threshold': similarity_threshold,
                # Debug aid: scores for the first few raw hits.
                'all_results_scores': [r.get('score', 0) for r in all_results[:5]] if all_results else []
            }
        except Exception as e:
            logger.error(f"知识库搜索失败: {str(e)}")
            raise ValueError(f"知识库搜索失败: {str(e)}")
# Workflow engine factory
async def get_workflow_engine(session: AsyncSession = None) -> WorkflowEngine:
    """Return a WorkflowEngine bound to *session*, acquiring one when omitted.

    NOTE(review): breaking out of the get_session() async generator leaves
    the session's cleanup to the generator's finaliser — confirm the session
    stays open for the engine's lifetime and is eventually closed.
    """
    if session is None:
        async for s in get_session():
            session = s
            break
    return WorkflowEngine(session)