hyf-backend/th_agenter/llm/local/local_llm.py

69 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from th_agenter.llm.base_llm import BaseLLM
class LocalLLM(BaseLLM):
    """LangChain chat-model adapter for a local LlamaCpp (llama.cpp) model.

    Wraps a plain completion model (not a chat model) behind LangChain's
    chat interface: incoming chat messages are flattened into a single
    ``[INST]``-style prompt string before invocation.
    """

    def __init__(self, config):
        """Store the local-model config; the model itself is loaded lazily.

        Args:
            config: configuration object; must expose ``model_path``,
                ``temperature``, ``max_tokens``, ``n_ctx`` and ``n_threads``.
        """
        super().__init__(config)
        self.local_config = config

    def _validate_config(self):
        """Raise ``ValueError`` if the mandatory ``model_path`` is missing."""
        if not self.local_config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the underlying LlamaCpp model from the stored config."""
        # Imported here so llama-cpp-python is only required when a local
        # model is actually used, not at module import time.
        from langchain_community.llms import LlamaCpp

        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        """LangChain model-type identifier."""
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously generate a chat completion.

        Lazily loads the model on first use, flattens the chat messages
        into a single prompt (LlamaCpp is a completion model, not a chat
        model), and wraps the text output in LangChain's standard
        ``ChatResult`` structure.
        """
        # NOTE(review): assumes BaseLLM initializes ``self.model`` to a
        # falsy value (e.g. None) before the first call — confirm in base.
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of :meth:`_generate`, using ``ainvoke``."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Flatten a LangChain message list into a local-model prompt string.

        Human turns are wrapped in Llama-style ``<s>[INST] ... [/INST]``
        markers; AI turns are appended verbatim.

        NOTE(review): SystemMessage (and any other message type) is
        silently dropped here — confirm whether system prompts should be
        folded into the first instruction block.
        """
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)