from typing import Any, List, Optional

from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration


class LocalLLM(BaseLLM):
    def __init__(self, config):
        super().__init__(config)
        self.local_config = config
        self.model = None  # created lazily by load_model() on first generation

    def _validate_config(self):
        if not self.local_config.model_path:
            raise ValueError("LocalLLM requires model_path to be configured")

    def load_model(self):
        from langchain_community.llms import LlamaCpp

        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self.model:
            self.load_model()
        # Adapt the call for LlamaCpp, which is a completion model, not a chat model
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap the raw completion in a ChatResult (LangChain's standard chat output)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Format a list of LangChain messages into a prompt for the local model."""
        prompt_parts = []
        for msg in messages:
            # Llama-2/Mistral instruct-style template; message types other than
            # Human/AI (e.g. system messages) are silently dropped here
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)
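

# Usage sketch (illustrative, assumptions flagged): `LocalLLMConfig` below is a
# hypothetical stand-in for whatever config object th_agenter actually passes
# in; it only mirrors the attributes LocalLLM reads, and the model path is a
# placeholder. This exercises the prompt-format/invoke path end to end.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class LocalLLMConfig:  # hypothetical: the real config class lives in th_agenter
        model_path: str
        temperature: float = 0.7
        max_tokens: int = 512
        n_ctx: int = 4096
        n_threads: int = 8

    config = LocalLLMConfig(model_path="/path/to/model.gguf")  # placeholder path
    llm = LocalLLM(config)

    # _generate is the LangChain-internal hook implemented above; calling it
    # directly here keeps the sketch independent of BaseLLM's public interface
    result = llm._generate([HumanMessage(content="Hello!")])
    print(result.generations[0].message.content)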