from typing import Any, List, Optional, Union

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI

from th_agenter.llm.base_llm import BaseLLM


class OnlineLLM(BaseLLM):
    """Online LLM backed by an OpenAI-compatible chat API through ``ChatOpenAI``.

    The ``ChatOpenAI`` client is built from ``self.config`` by ``load_model()``.
    If no client has been loaded yet, it is created lazily the first time a
    generation method runs. All generation requests are delegated to that client.
    """

    def __init__(self, config):
        """Initialise the wrapper by handing ``config`` to ``BaseLLM``.

        Args:
            config: Configuration object. This class reads its ``api_key``,
                ``model_name``, ``temperature``, ``max_tokens`` and ``base_url``
                attributes.
        """
        super().__init__(config)

    def _validate_config(self):
        """Check that an API key is configured.

        Raises:
            ValueError: If ``config.api_key`` is missing or empty.
        """
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Create the ``ChatOpenAI`` client from ``self.config`` and store it on ``self.model``."""
        self.model = ChatOpenAI(
            api_key=self.config.api_key,
            model_name=self.config.model_name,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            base_url=self.config.base_url,
        )

    def _ensure_model(self):
        """Return the underlying client, calling ``load_model()`` first if none is set."""
        # getattr with a default also covers the case where BaseLLM never
        # initialised a ``model`` attribute at all.
        if getattr(self, "model", None) is None:
            self.load_model()
        return self.model

    @property
    def _llm_type(self) -> str:
        """Identifier of the model type reported by this wrapper."""
        return "openai"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate synchronous generation to the underlying model's ``_generate``.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Callback manager, forwarded unchanged.
            **kwargs: Extra arguments, forwarded unchanged.

        Returns:
            The ``ChatResult`` produced by the underlying model.
        """
        return self._ensure_model()._generate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate asynchronous generation to the underlying model's ``_agenerate``.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Async callback manager, forwarded unchanged.
            **kwargs: Extra arguments, forwarded unchanged.

        Returns:
            The ``ChatResult`` produced by the underlying model.
        """
        return await self._ensure_model()._agenerate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )

    # ---------------------- Retained custom convenience methods ----------------------

    @staticmethod
    def _to_messages(prompt: Union[str, List[BaseMessage]]) -> List[BaseMessage]:
        """Wrap a plain-string prompt in a single ``HumanMessage``; pass lists through."""
        if isinstance(prompt, str):
            return [HumanMessage(content=prompt)]
        return prompt

    # NOTE(review): if BaseLLM derives from LangChain's BaseChatModel, this
    # ``generate`` shadows BaseChatModel.generate, which takes a list of message
    # lists and is what ``invoke()`` ultimately calls. Confirm that is intended.
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Generate a reply and return its text.

        Args:
            prompt: A plain string (sent as one human message) or a message list.
            **kwargs: Forwarded to ``_generate``, for example ``stop``.

        Returns:
            The text of the first generation in the result.

        Raises:
            IndexError: If the model returned no generations.
        """
        result = self._generate(self._to_messages(prompt), **kwargs)
        return result.generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Asynchronously generate a reply and return its text.

        Args:
            prompt: A plain string (sent as one human message) or a message list.
            **kwargs: Forwarded to ``_agenerate``, for example ``stop``.

        Returns:
            The text of the first generation in the result.

        Raises:
            IndexError: If the model returned no generations.
        """
        result = await self._agenerate(self._to_messages(prompt), **kwargs)
        return result.generations[0].text