108 lines
3.0 KiB
Python
108 lines
3.0 KiB
Python
"""Use-case oriented LLM gateway.
|
|
|
|
This is a small compatibility layer over the existing provider and JSON helper
|
|
functions. It gives new code a stable place to request model capabilities while
|
|
older agents continue to use LangChain directly during the transition.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, TypeVar
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from app.core.dependencies import get_llm_provider, get_llm_provider_fast
|
|
from app.core.llm_call import allm_json_call, llm_json_call
|
|
|
|
# Generic schema type for the JSON helpers: ``json_object`` / ``sync_json_object``
# take ``schema: type[T]`` and are annotated to return ``T``. The bound means the
# schema must be a pydantic ``BaseModel`` subclass.
T = TypeVar("T", bound=BaseModel)
|
|
|
|
|
|
@dataclass(frozen=True)
class LlmUseCase:
    """Immutable description of how one caller wants to use the LLM.

    ``LlmGateway`` reads these fields to choose a provider and to fill in
    call options the caller did not pass explicitly. The dataclass is
    frozen, so instances cannot be mutated after construction.
    """

    # Identifier forwarded as ``agent=`` to the JSON helpers.
    name: str
    # When True, LlmGateway uses ``get_llm_provider_fast()`` instead of
    # ``get_llm_provider()``.
    fast: bool = False
    # Token budget. ``chat_text`` uses it unless a per-call value is given;
    # the JSON methods fall back to 1024 when it is None or 0.
    max_tokens: int | None = None
    # Sampling temperature for ``chat_text``; 0.7 is used when this and the
    # per-call argument are both None. Not read by the JSON methods.
    temperature: float | None = None
    # Model name forwarded to ``provider.complete`` by ``chat_text``; not
    # read by the JSON methods.
    model: str | None = None
|
|
|
|
|
|
class LlmGateway:
    """Facade for text and JSON LLM calls.

    Callers describe their needs with an ``LlmUseCase``; the gateway picks
    the provider (regular or fast) and resolves per-call options.

    For ``chat_text`` the option precedence is: explicit keyword argument,
    then the use case's value, then the gateway default.
    """

    # Temperature used by chat_text when neither the call nor the use case
    # sets one. Subclasses may override.
    DEFAULT_TEMPERATURE: float = 0.7
    # max_tokens sent to the JSON helpers when the use case leaves it unset
    # (None or 0). Subclasses may override.
    DEFAULT_JSON_MAX_TOKENS: int = 1024

    def provider_for(self, use_case: LlmUseCase | None = None) -> Any:
        """Return the LLM provider for *use_case*.

        The fast provider is chosen only when a use case is given and its
        ``fast`` flag is set; otherwise the default provider is returned.
        """
        if use_case is not None and use_case.fast:
            return get_llm_provider_fast()
        return get_llm_provider()

    def langchain_llm_for(self, use_case: LlmUseCase | None = None) -> Any | None:
        """Return the provider's ``langchain_llm`` attribute, or None if absent."""
        return getattr(self.provider_for(use_case), "langchain_llm", None)

    @staticmethod
    def _resolve_option(
        explicit: Any,
        use_case: LlmUseCase | None,
        attr: str,
        default: Any = None,
    ) -> Any:
        """Resolve one call option: explicit value, then ``use_case.<attr>``, then *default*.

        ``None`` means "unset" at every level, so falsy-but-meaningful values
        such as an explicit ``0.0`` temperature are kept.
        """
        if explicit is not None:
            return explicit
        configured = getattr(use_case, attr) if use_case is not None else None
        return configured if configured is not None else default

    def _json_options(self, use_case: LlmUseCase) -> dict[str, Any]:
        """Build the keyword arguments shared by the async and sync JSON helpers."""
        return {
            # ``or`` (not an ``is None`` check): an unset *or zero* budget
            # falls back to the default.
            "max_tokens": use_case.max_tokens or self.DEFAULT_JSON_MAX_TOKENS,
            "agent": use_case.name,
        }

    async def chat_text(
        self,
        messages: list[dict],
        *,
        use_case: LlmUseCase | None = None,
        temperature: float | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
    ) -> str:
        """Send chat *messages* to the provider and return its completion.

        Args:
            messages: Chat messages forwarded unchanged to ``provider.complete``.
            use_case: Selects the provider and supplies fallback options.
            temperature: Overrides ``use_case.temperature``; resolves to
                ``DEFAULT_TEMPERATURE`` when both are None.
            model: Overrides ``use_case.model``; None is passed through when
                neither is set.
            max_tokens: Overrides ``use_case.max_tokens``; None is passed
                through when neither is set.

        Returns:
            The awaited result of ``provider.complete`` (annotated as ``str``).
        """
        provider = self.provider_for(use_case)
        return await provider.complete(
            messages,
            temperature=self._resolve_option(
                temperature, use_case, "temperature", self.DEFAULT_TEMPERATURE
            ),
            model=self._resolve_option(model, use_case, "model"),
            max_tokens=self._resolve_option(max_tokens, use_case, "max_tokens"),
        )

    async def json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
    ) -> T:
        """Request a JSON object matching *schema* via ``allm_json_call`` (async).

        ``max_tokens`` comes from the use case (or ``DEFAULT_JSON_MAX_TOKENS``)
        and ``agent`` is ``use_case.name``. ``use_case.temperature`` and
        ``use_case.model`` are not forwarded here.

        NOTE(review): if the provider has no ``langchain_llm`` attribute,
        None is passed as the LLM; confirm ``allm_json_call`` handles that.
        """
        return await allm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            fallback_factory=fallback_factory,
            **self._json_options(use_case),
        )

    def sync_json_object(
        self,
        prompt: str,
        schema: type[T],
        *,
        use_case: LlmUseCase,
        fallback_factory: Callable[[], T] | None = None,
    ) -> T:
        """Synchronous counterpart of ``json_object``, delegating to ``llm_json_call``.

        Uses the same ``max_tokens``/``agent`` resolution as ``json_object``.

        NOTE(review): if the provider has no ``langchain_llm`` attribute,
        None is passed as the LLM; confirm ``llm_json_call`` handles that.
        """
        return llm_json_call(
            self.langchain_llm_for(use_case),
            prompt,
            schema,
            fallback_factory=fallback_factory,
            **self._json_options(use_case),
        )
|
|
|
|
|
|
# Public API of this module: the gateway facade and its use-case configuration.
__all__ = ["LlmGateway", "LlmUseCase"]
|