Files
life-echo/api/app/core/llm_gateway.py
2026-04-30 16:22:55 +08:00

108 lines
3.0 KiB
Python

"""Use-case oriented LLM gateway.
This is a small compatibility layer over the existing provider and JSON helper
functions. It gives new code a stable place to request model capabilities while
older agents continue to use LangChain directly during the transition.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, TypeVar
from pydantic import BaseModel
from app.core.dependencies import get_llm_provider, get_llm_provider_fast
from app.core.llm_call import allm_json_call, llm_json_call
# Generic pydantic model type. `json_object` / `sync_json_object` take a model
# class `schema: type[T]` and are annotated to return a `T` instance.
T = TypeVar("T", bound=BaseModel)
@dataclass(eq=True, frozen=True)
class LlmUseCase:
    """Immutable, value-comparable description of one kind of LLM request.

    Because the dataclass is frozen (and compares by value), instances can be
    shared freely and used as dict keys. Optional fields left as ``None`` mean
    "no preference": the caller's explicit argument or a gateway default is
    used instead.
    """

    # Label for this use case; the JSON helpers report it as ``agent``.
    name: str
    # When True, the gateway picks the fast provider instead of the default one.
    fast: bool = False
    # Token budget for the response; None means "not set here".
    max_tokens: int | None = None
    # Sampling temperature; None means "not set here".
    temperature: float | None = None
    # Model identifier forwarded to text completions; None means no override.
    model: str | None = None
class LlmGateway:
    """Facade for text and JSON LLM calls.

    Each per-call setting is resolved from up to three layers, and the first
    non-None value wins:

    1. an explicit keyword argument to the call,
    2. the matching field on the ``LlmUseCase``,
    3. a gateway default (the class attributes below).

    ``None`` is the only "unset" marker, so falsy but explicit values such as
    ``0`` or ``0.0`` are honoured consistently by every method. The defaults
    are class attributes so a subclass can change them without touching call
    sites.
    """

    # Temperature for ``chat_text`` when neither the call nor the use case sets one.
    DEFAULT_TEMPERATURE = 0.7
    # Token budget for the JSON helpers when neither the call nor the use case sets one.
    DEFAULT_JSON_MAX_TOKENS = 1024

    @staticmethod
    def _resolve(override: Any, configured: Any, default: Any) -> Any:
        """Return the first of *override*, *configured*, *default* that is not None."""
        if override is not None:
            return override
        if configured is not None:
            return configured
        return default

    def provider_for(self, use_case: LlmUseCase | None = None) -> Any:
        """Return the LLM provider to use for *use_case*.

        The fast provider (``get_llm_provider_fast``) is returned when the use
        case sets ``fast=True``. Otherwise, including when *use_case* is None,
        the default provider from ``get_llm_provider`` is returned.
        """
        if use_case is not None and use_case.fast:
            return get_llm_provider_fast()
        return get_llm_provider()

    def langchain_llm_for(self, use_case: LlmUseCase | None = None) -> Any | None:
        """Return the selected provider's ``langchain_llm`` attribute.

        Returns None when the provider chosen by ``provider_for`` has no
        ``langchain_llm`` attribute.
        """
        provider = self.provider_for(use_case)
        return getattr(provider, "langchain_llm", None)

    async def chat_text(
        self,
        messages: list[dict],
        *,
        use_case: LlmUseCase | None = None,
        temperature: float | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
    ) -> str:
        """Send chat *messages* to the selected provider and return its reply.

        Args:
            messages: Chat messages, forwarded unchanged to ``provider.complete``.
            use_case: Optional use case. It selects the fast/default provider
                and supplies fallback values for the settings below.
            temperature: Overrides ``use_case.temperature``. Falls back to
                ``DEFAULT_TEMPERATURE`` when both are None.
            model: Overrides ``use_case.model``. None when both are unset.
            max_tokens: Overrides ``use_case.max_tokens``. None when both are unset.

        Returns:
            The awaited result of ``provider.complete`` (annotated as ``str``).
        """
        provider = self.provider_for(use_case)
        if use_case is not None:
            configured_temperature = use_case.temperature
            configured_model = use_case.model
            configured_max_tokens = use_case.max_tokens
        else:
            configured_temperature = configured_model = configured_max_tokens = None
        return await provider.complete(
            messages,
            temperature=self._resolve(
                temperature, configured_temperature, self.DEFAULT_TEMPERATURE
            ),
            model=self._resolve(model, configured_model, None),
            max_tokens=self._resolve(max_tokens, configured_max_tokens, None),
        )

    async def json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
        max_tokens: int | None = None,
    ) -> T:
        """Request a structured result matching *schema* via ``allm_json_call``.

        The LangChain LLM passed to the helper comes from ``langchain_llm_for``.
        It may be None when the provider exposes no ``langchain_llm``; how the
        helper treats that case is not visible here.

        Args:
            prompt: Prompt text, forwarded unchanged.
            schema: Pydantic model class describing the expected result.
            use_case: Selects the provider. Its ``name`` is reported to the
                helper as ``agent``.
            fallback_factory: Forwarded unchanged. Presumably the helper uses
                it when the model output cannot be parsed; confirm in
                ``app.core.llm_call``.
            max_tokens: Optional per-call override of ``use_case.max_tokens``.
                When both are None, ``DEFAULT_JSON_MAX_TOKENS`` is used.

        Returns:
            The helper's awaited result (annotated as an instance of *schema*).

        Note:
            ``use_case.temperature`` and ``use_case.model`` are not forwarded.
            The helper only receives the LLM object, prompt, schema, token
            budget, agent name and fallback factory.
        """
        return await allm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            max_tokens=self._resolve(
                max_tokens, use_case.max_tokens, self.DEFAULT_JSON_MAX_TOKENS
            ),
            agent=use_case.name,
            fallback_factory=fallback_factory,
        )

    def sync_json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
        max_tokens: int | None = None,
    ) -> T:
        """Blocking counterpart of ``json_object``; delegates to ``llm_json_call``.

        Arguments are resolved exactly as in ``json_object``:
        - the LLM comes from ``langchain_llm_for(use_case)`` and may be None;
        - the token budget is the first non-None of *max_tokens*,
          ``use_case.max_tokens`` and ``DEFAULT_JSON_MAX_TOKENS``;
        - ``agent`` is ``use_case.name``.

        As in ``json_object``, ``use_case.temperature`` and ``use_case.model``
        are not forwarded.
        """
        return llm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            max_tokens=self._resolve(
                max_tokens, use_case.max_tokens, self.DEFAULT_JSON_MAX_TOKENS
            ),
            agent=use_case.name,
            fallback_factory=fallback_factory,
        )
# Public API of this module: the only names exported by a star-import.
__all__ = ["LlmGateway", "LlmUseCase"]