feat: 添加流式输出支持至agentChat API

- 在agentChat API中新增流式输出接口,允许实时返回AI模型的响应
- 引入StreamingResponse以支持流式数据传输
- 增强知识库检索逻辑,支持多知识库的相似度检索
- 更新错误处理,确保在流式调用中捕获并返回异常信息
- 更新相关文档以反映新功能的使用方式
This commit is contained in:
eason 2026-01-28 16:08:20 +08:00
parent 67087e0664
commit 643c2f90c4
3 changed files with 174 additions and 0 deletions

Binary file not shown.

View File

@@ -1,6 +1,7 @@
"""agentChat 接口:根据 AI 大模型、提示词、关联知识库输出结果。"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.orm import Session
from loguru import logger
@@ -143,3 +144,176 @@ async def agent_chat(
references=references,
)
)
@router.post(
    "/stream",
    summary="agentChat按大模型、提示词、关联知识库流式输出结果",
)
# NOTE(review): registers a second, schema-hidden route with no leading
# slash ("stream"); under an APIRouter prefix this concatenates to
# "<prefix>stream". Presumably intentional (mirrors /chat/stream) — verify
# the resulting path, since Starlette routes normally start with "/".
@router.post(
    "stream",
    include_in_schema=False,
)
async def agent_chat_stream(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),
    session: Session = Depends(get_session),
) -> StreamingResponse:
    """agentChat streaming endpoint.

    Streams the LLM's output text to the client in real time, based on the
    selected model, its associated knowledge base(s), and the prompt.

    Args:
        body: Request payload carrying model_id, prompt/message, optional
            knowledge-base IDs, top_k, temperature and max_tokens.
        current_user: Authenticated caller (acts as an auth guard only; it
            is not otherwise read in this handler).
        session: Database session. NOTE(review): annotated as a sync
            ``Session`` but used with ``await session.execute(...)`` —
            presumably an async session is injected at runtime; confirm.

    Returns:
        StreamingResponse yielding plain-text chunks. LLM failures are
        pushed into the stream as ``[ERROR] ...`` text rather than raised.

    Raises:
        HTTPException: 400 for malformed knowledge-base IDs, a disabled
            model or knowledge base, or an embedding-only model; 404 when
            the model config or a knowledge base does not exist.
    """
    prompt_text = (body.prompt or body.message or "").strip()
    # Resolve knowledge-base IDs: prefer knowledge_base_ids; otherwise fall
    # back to the single-value form [knowledge_base_id].
    kb_ids: list[int] = []
    if body.knowledge_base_ids:
        try:
            kb_ids = [
                int(x)
                for x in body.knowledge_base_ids
                if x is not None and str(x).strip() != ""
            ]
            # Non-positive IDs are silently dropped (treated as unselected).
            kb_ids = [i for i in kb_ids if i >= 1]
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="knowledge_base_ids 须为数字或数字字符串",
            )
    elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        kb_ids = [body.knowledge_base_id]
    # NOTE(review): .title/.desc are set on the session object — presumably
    # a project-specific audit/logging hook on the session wrapper; confirm.
    session.title = "agentChatStream"
    session.desc = (
        f"START: agentChat/stream model_id={body.model_id}, "
        f"prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"
    )
    # 1. Validate and load the LLM configuration.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在"
        )
    if not llm_config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用"
        )
    if llm_config.is_embedding:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请选择对话型大模型,不能使用嵌入模型",
        )
    # 2. If knowledge bases were specified, validate each one and retrieve
    #    similar documents; results from all bases are merged, then the
    #    global top_k by similarity is kept.
    # NOTE(review): knowledge_base_used and references are assigned but
    # never read later in this handler (the stream carries text only) —
    # presumably kept for parity with the non-streaming endpoint; confirm.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text
    if kb_ids:
        doc_processor = await get_document_processor(session)
        # Each retrieval result appears to be a dict with keys such as
        # "content", "normalized_score", "similarity_score" (see usage
        # below) — TODO confirm against search_similar_documents.
        all_results: list[dict] = []
        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"知识库不存在: id={kb_id}",
                )
            if not kb.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"该知识库未启用: id={kb_id}",
                )
            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        def _score(r: dict) -> float:
            # Prefer the normalized score, fall back to the raw similarity,
            # and default to 0 when neither is present or both are falsy.
            return float(r.get("normalized_score") or r.get("similarity_score") or 0)

        all_results.sort(key=_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _score(results[0]) if results else 0.0
        # Only inject RAG context when the best hit clears the 0.5 relevance
        # threshold; otherwise the bare prompt is sent unchanged.
        if results and max_score >= 0.5:
            knowledge_base_used = True
            refs = []
            # Keep at most 5 references, each truncated to 1000 characters.
            for i, r in enumerate(results[:5], 1):
                content = (r.get("content") or "").strip()
                if content:
                    if len(content) > 1000:
                        content = content[:1000] + "..."
                    refs.append(
                        {
                            "index": i,
                            "content": content,
                            "score": r.get("normalized_score"),
                        }
                    )
            references = refs
            context = "\n\n".join(
                [f"【参考文档{ref['index']}\n{ref['content']}" for ref in refs]
            )
            # Assemble the RAG prompt: instructions + retrieved context +
            # the user's question. The literal is model-facing input and is
            # kept verbatim.
            final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。
{context}
用户问题
{prompt_text}
重要提示
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档提取相关信息来回答用户的问题
- 即使文档没有直接定义也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念政策法规等请基于这些内容进行回答
- 回答要准确详细有条理尽量引用文档中的具体内容"""
            logger.info(
                f"agentChat/stream 使用 RAG知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}"
            )
        else:
            logger.info(
                f"agentChat/stream知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词"
            )
    # 3. Call the LLM (streaming).
    llm_service = LLMService()

    async def generate():
        """Yield text chunks from the LLM; push failures into the stream."""
        try:
            async for chunk in llm_service.chat_completion_stream(
                model_config=llm_config,
                messages=[{"role": "user", "content": final_prompt}],
                # Per-request overrides win; otherwise fall back to the
                # model config's defaults.
                temperature=body.temperature
                if body.temperature is not None
                else llm_config.temperature,
                max_tokens=body.max_tokens
                if body.max_tokens is not None
                else llm_config.max_tokens,
            ):
                if not chunk:
                    continue
                # Same as /chat/stream: emit the raw text content directly.
                yield chunk
        except Exception as e:
            logger.error(f"agentChat/stream LLM 流式调用失败: {e}")
            # Surface the error inside the stream so the frontend can
            # display it (the HTTP status is already 200 by this point).
            yield f"[ERROR] 大模型调用失败: {str(e)}"

    return StreamingResponse(
        generate(),
        # NOTE(review): "text/stream" is not a registered MIME type —
        # presumably chosen to mirror /chat/stream; confirm whether
        # "text/event-stream" or "text/plain" was intended.
        media_type="text/stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )