Files
life-echo/api/app/core/llm_gateway.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

110 lines
3.2 KiB
Python

"""Use-case oriented LLM gateway.
This is a small compatibility layer over the existing provider and JSON helper
functions. It gives new code a stable place to request model capabilities while
older agents continue to use LangChain directly during the transition.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, TypeVar
from pydantic import BaseModel
from app.core.dependencies import get_llm_provider, get_llm_provider_fast
from app.core.llm_call import allm_json_call, llm_json_call
T = TypeVar("T", bound=BaseModel)
@dataclass(frozen=True)
class LlmUseCase:
    """Immutable description of one named LLM use case.

    Only ``name`` is required; every other field is optional and defaults
    to "not specified" (``False`` / ``None``).
    """

    # Identifier of the use case.
    name: str
    # True requests the "fast" provider variant; False requests the regular one.
    fast: bool = False
    # Optional token limit; None means the use case does not set one.
    max_tokens: int | None = None
    # Optional sampling temperature; None means the use case does not set one.
    temperature: float | None = None
    # Optional model identifier; None means the use case does not set one.
    model: str | None = None
class LlmGateway:
    """Facade for text and JSON LLM calls.

    Every method takes an optional or required ``LlmUseCase`` that selects
    the provider (fast vs. regular) and supplies per-use-case defaults for
    temperature, model and token budget.
    """

    # Temperature used by ``chat_text`` when neither the call nor the use
    # case specifies one.
    _DEFAULT_TEMPERATURE = 0.7
    # Token budget used by the JSON helpers when the use case has none.
    _DEFAULT_JSON_MAX_TOKENS = 1024

    def provider_for(self, use_case: LlmUseCase | None = None):
        """Return the provider matching *use_case*.

        The fast provider is returned only when a use case is given and its
        ``fast`` flag is set; otherwise the regular provider is returned.
        """
        if use_case and use_case.fast:
            return get_llm_provider_fast()
        return get_llm_provider()

    def langchain_llm_for(self, use_case: LlmUseCase | None = None) -> Any | None:
        """Return the provider's ``langchain_llm`` attribute, or ``None``.

        ``None`` is returned when the selected provider does not expose a
        ``langchain_llm`` attribute.
        """
        provider = self.provider_for(use_case)
        return getattr(provider, "langchain_llm", None)

    async def chat_text(
        self,
        messages: list[dict],
        *,
        use_case: LlmUseCase | None = None,
        temperature: float | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
    ) -> str:
        """Run a chat completion and return the provider's result.

        Each tunable is resolved with the same precedence: the explicit
        keyword argument wins, then the value on *use_case*, then the
        default (``0.7`` for temperature, ``None`` for model and
        max_tokens).

        Args:
            messages: Chat messages forwarded unchanged to the provider.
            use_case: Optional use case selecting the provider and defaults.
            temperature: Explicit sampling temperature override.
            model: Explicit model override.
            max_tokens: Explicit token-limit override.

        Returns:
            Whatever ``provider.complete(...)`` resolves to.
        """
        provider = self.provider_for(use_case)

        # Explicit argument > use-case value > built-in default.
        if temperature is None and use_case:
            temperature = use_case.temperature
        if temperature is None:
            temperature = self._DEFAULT_TEMPERATURE

        if model is None and use_case:
            model = use_case.model

        if max_tokens is None and use_case:
            max_tokens = use_case.max_tokens

        return await provider.complete(
            messages=messages,
            temperature=temperature,
            model=model,
            max_tokens=max_tokens,
        )

    async def json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
    ) -> T:
        """Asynchronously request a JSON result via ``allm_json_call``.

        Args:
            prompt: Prompt text forwarded to ``allm_json_call``.
            schema: Pydantic model class forwarded as the target schema.
            use_case: Supplies the provider choice, the ``agent`` label
                (its ``name``) and the token budget.
            fallback_factory: Optional zero-argument callable forwarded
                unchanged to ``allm_json_call``.

        Returns:
            The value returned by ``allm_json_call``.
        """
        # NOTE(review): langchain_llm_for() may return None; this relies on
        # allm_json_call accepting that — confirm against its implementation.
        return await allm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            # ``or`` is deliberate: both None and 0 fall back to the default.
            max_tokens=use_case.max_tokens or self._DEFAULT_JSON_MAX_TOKENS,
            agent=use_case.name,
            fallback_factory=fallback_factory,
        )

    def sync_json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
    ) -> T:
        """Synchronous counterpart of ``json_object`` using ``llm_json_call``.

        Takes the same arguments as ``json_object`` and forwards them the
        same way.

        Returns:
            The value returned by ``llm_json_call``.
        """
        # NOTE(review): langchain_llm_for() may return None; this relies on
        # llm_json_call accepting that — confirm against its implementation.
        return llm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            # ``or`` is deliberate: both None and 0 fall back to the default.
            max_tokens=use_case.max_tokens or self._DEFAULT_JSON_MAX_TOKENS,
            agent=use_case.name,
            fallback_factory=fallback_factory,
        )
__all__ = ["LlmGateway", "LlmUseCase"]