diff --git a/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin b/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin index d352f28..eded1b7 100644 Binary files a/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin and b/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin differ diff --git a/data/chroma/kb_2/chroma.sqlite3 b/data/chroma/kb_2/chroma.sqlite3 index 504c313..78f49c3 100644 Binary files a/data/chroma/kb_2/chroma.sqlite3 and b/data/chroma/kb_2/chroma.sqlite3 differ diff --git a/th_agenter/services/workflow_engine.py b/th_agenter/services/workflow_engine.py index 5ad3b66..337d346 100644 --- a/th_agenter/services/workflow_engine.py +++ b/th_agenter/services/workflow_engine.py @@ -1030,14 +1030,35 @@ class WorkflowEngine: # 如果还是没有,尝试从 workflow_input 中获取 if not query: workflow_input = input_data.get('workflow_input', {}) + # 首先尝试获取 'query' 字段 query = workflow_input.get('query', '') + # 如果没有 'query' 字段,尝试获取第一个非空的字符串值作为查询文本 + if not query and isinstance(workflow_input, dict): + for key, value in workflow_input.items(): + if isinstance(value, str) and value.strip(): + query = value.strip() + logger.info(f"从工作流输入的 '{key}' 字段获取查询文本: {query}") + break # 如果还是没有,尝试从 previous_outputs 中获取(可能是上一个节点的输出) if not query: previous_outputs = input_data.get('previous_outputs', {}) # 尝试从上一个节点的输出中获取查询文本 for node_id, output in previous_outputs.items(): if isinstance(output, dict): - query = output.get('query') or output.get('data', {}).get('query', '') + # 首先尝试从根级别获取 + query = output.get('query', '') + if not query: + # 尝试从 data 字段中获取 + data = output.get('data', {}) + if isinstance(data, dict): + query = data.get('query', '') + # 如果 data 中没有 query,尝试获取第一个非空字符串值 + if not query: + for key, value in data.items(): + if isinstance(value, str) and value.strip(): + query = value.strip() + logger.info(f"从节点 {node_id} 输出的 data.{key} 字段获取查询文本: {query}") + break if query: break # 如果还是没有,使用默认查询或抛出错误 @@ -1057,30 +1078,47 @@ class WorkflowEngine: try: # 直接使用 document_processor 进行搜索 - # 注意:get_document_processor 期望同步 Session,但这里传入 None 以避免类型不匹配 + # 注意:get_document_processor 需要 session 来初始化嵌入模型 from ..services.document_processor import get_document_processor - document_processor = await get_document_processor(None) + # 传入 self.session 以便初始化嵌入模型(虽然类型不匹配,但 get_document_processor 会处理) + document_processor = await get_document_processor(self.session) results = document_processor.search_similar_documents( knowledge_base_id=knowledge_base_id, query=query, k=top_k ) + logger.info(f"知识库 {knowledge_base_id} 搜索查询 '{query}' 返回 {len(results)} 个原始结果") + # 过滤相似度阈值 filtered_results = [] + all_results = [] for result in results: score = result.get('normalized_score', result.get('similarity_score', 0)) + all_results.append({ + **result, + 'score': score + }) if score >= similarity_threshold: filtered_results.append(result) + logger.info(f"应用相似度阈值 {similarity_threshold} 后,剩余 {len(filtered_results)} 个结果") + + # 如果过滤后结果为空,但原始结果不为空,返回所有结果并添加警告 + if not filtered_results and results: + logger.warning(f"所有搜索结果都被相似度阈值 {similarity_threshold} 过滤,返回前 {min(len(results), top_k)} 个结果") + filtered_results = results[:top_k] + return { 'success': True, 'query': query, 'knowledge_base_id': knowledge_base_id, 'results': filtered_results, 'total_results': len(filtered_results), + 'raw_results_count': len(results), 'top_k': top_k, - 'similarity_threshold': similarity_threshold + 'similarity_threshold': similarity_threshold, + 'all_results_scores': [r.get('score', 0) for r in all_results[:5]] if all_results else [] } except Exception as e: