# hyf-backend/th_agenter/llm/online/online_llm.py
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, BaseMessage
from typing import List, Optional, Any, Union
from langchain_core.outputs import ChatResult
from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
class OnlineLLM(BaseLLM):
    """Chat model backed by an OpenAI-compatible online API.

    Thin wrapper around LangChain's ``ChatOpenAI`` that plugs into the
    project's ``BaseLLM`` interface. The underlying client is created
    lazily on first use via :meth:`load_model`.
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Raise if the configuration is missing required credentials."""
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Instantiate the ChatOpenAI client from ``self.config``.

        NOTE(review): assumes config also provides ``model_name``,
        ``temperature``, ``max_tokens`` and ``base_url`` — confirm
        against the config schema.
        """
        # ChatOpenAI is imported at module level; no local re-import needed.
        self.model = ChatOpenAI(
            api_key=self.config.api_key,
            model_name=self.config.model_name,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            base_url=self.config.base_url,
        )

    @property
    def _llm_type(self) -> str:
        """Identifier for the model type (LangChain convention)."""
        return "openai"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate to the underlying LangChain model's ``_generate``."""
        if not self.model:
            self.load_model()
        # Reuse the underlying model's implementation.
        return self.model._generate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of :meth:`_generate`; delegates to the model."""
        if not self.model:
            self.load_model()
        return await self.model._agenerate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    # ---------------------- Custom convenience methods ----------------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Convenience method: accept a plain string prompt or a message list.

        Returns the text of the first generation.
        """
        if isinstance(prompt, str):
            messages = [HumanMessage(content=prompt)]
        else:
            messages = prompt
        result = self._generate(messages, **kwargs)
        return result.generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async convenience method: accept a string prompt or a message list.

        Returns the text of the first generation.
        """
        if isinstance(prompt, str):
            messages = [HumanMessage(content=prompt)]
        else:
            messages = prompt
        result = await self._agenerate(messages, **kwargs)
        return result.generations[0].text