hyf-backend/th_agenter/services/tools/search.py

75 lines
2.9 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""基于TavilySearch的搜索工具"""
from th_agenter.core.config import get_settings
from loguru import logger
from langchain.tools import BaseTool
from langchain_community.tools.tavily_search import TavilySearchResults
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, Type, ClassVar
from langchain_tavily import TavilySearch
# Input-argument model for the search tool (replaces the former get_parameters()).
class SearchInput(BaseModel):
    """Arguments accepted by the Tavily search tool.

    Attributes:
        query: Free-text search query (required).
        max_results: Maximum number of hits to return; defaults to 5.
        topic: Tavily search topic; defaults to "general".
    """

    query: str = Field(description="搜索查询内容")
    max_results: Optional[int] = Field(
        default=5,
        description="返回结果的最大数量默认5",
    )
    # The description is shown to the LLM. It must only list topics Tavily's
    # search API accepts (general / news / finance). The previous text
    # advertised "academic" and "places", which Tavily rejects.
    topic: Optional[str] = Field(
        default="general",
        description="搜索主题可选值general, news, finance",
    )
class TavilySearchTool(BaseTool):
    """LangChain tool that performs a web search through Tavily.

    The Tavily API key is read from ``get_settings().tool.tavily_api_key`` when
    the tool is constructed. Each invocation returns a plain dict:
    ``{"status": "success", "results": [{"title", "url", "content"}, ...]}`` on
    success, or ``{"status": "error", "message": str}`` on failure. Search
    failures are logged and reported this way instead of being raised.
    """

    name: ClassVar[str] = "tavily_search_tool"
    # Replaces the former get_description(). Built by implicit concatenation so
    # the runtime text (including the embedded newline) is unchanged.
    description: ClassVar[str] = (
        "使用Tavily搜索引擎进行网络搜索可以获取最新信息。\n"
        "输入应该包含搜索查询(query)可选参数包括max_results和topic"
    )
    # Argument schema expressed as a Pydantic model.
    args_schema: Type[BaseModel] = SearchInput

    _tavily_api_key: str = PrivateAttr()
    # Annotated as TavilySearch because that is what __init__ constructs. The
    # older TavilySearchResults tool returned a list; TavilySearch returns a dict.
    _search_client: TavilySearch = PrivateAttr()

    def __init__(self, **kwargs):
        """Create the tool and its Tavily client.

        Raises:
            ValueError: if no Tavily API key is configured in settings.
        """
        super().__init__(**kwargs)
        self._tavily_api_key = get_settings().tool.tavily_api_key
        if not self._tavily_api_key:
            raise ValueError("Tavily API key not found in settings")
        # Initialise the Tavily client.
        self._search_client = TavilySearch(
            tavily_api_key=self._tavily_api_key
        )

    def _run(self, query: str, max_results: int = 5, topic: str = "general"):
        """Run a Tavily search and return a simplified result dict.

        Args:
            query: Search query text.
            max_results: Upper bound on the number of hits returned. It is
                forwarded to the client as well. NOTE(review): the
                TavilySearch invocation schema may ignore it, so the hit list
                is also truncated locally.
            topic: Search topic forwarded to Tavily.

        Returns:
            ``{"status": "success", "results": [...]}`` or
            ``{"status": "error", "message": str}``.
        """
        try:
            logger.info(f"执行搜索:{query}")
            raw = self._search_client.run({
                "query": query,
                "max_results": max_results,
                "topic": topic,
            })
            return self._format_results(raw, max_results)
        except Exception as e:
            # Tool boundary: report the failure to the caller/agent instead of
            # raising. Log with the traceback so the cause is diagnosable.
            logger.exception(f"搜索失败: {str(e)}")
            return {"status": "error", "message": str(e)}

    @staticmethod
    def _format_results(raw: object, max_results: Optional[int], preview_chars: int = 200) -> dict:
        """Convert a raw Tavily response into the tool's result dict.

        Accepts two shapes:
        - the dict returned by ``TavilySearch``, with hits under ``"results"``;
        - a plain list of hit dicts, as returned by the older
          ``TavilySearchResults`` tool.

        Each hit's content is cut to ``preview_chars`` characters. An ellipsis
        is appended only when text was actually removed.
        """
        if isinstance(raw, dict):
            # NOTE(review): TavilySearch appears to report client-side failures
            # under an "error" key rather than raising — confirm per version.
            if raw.get("error"):
                return {"status": "error", "message": str(raw["error"])}
            hits = raw.get("results")
        else:
            hits = raw
        if not isinstance(hits, list):
            return {"status": "error", "message": "Unexpected result format"}
        if max_results is not None:
            hits = hits[:max(max_results, 0)]

        formatted = []
        for hit in hits:
            # `or ""` guards against keys that are present but set to None.
            content = hit.get("content") or ""
            if len(content) > preview_chars:
                content = content[:preview_chars] + "..."
            formatted.append({
                "title": hit.get("title") or "",
                "url": hit.get("url") or "",
                "content": content,
            })
        return {"status": "success", "results": formatted}

    async def _arun(self, **kwargs):
        """Async variant: run the blocking ``_run`` in a worker thread.

        Calling ``_run`` directly here would block the event loop for the
        whole HTTP round-trip to Tavily.
        """
        import asyncio

        return await asyncio.to_thread(self._run, **kwargs)