Files
life-echo/api/tests/test_llm_gateway.py

63 lines
1.7 KiB
Python
Raw Normal View History

from __future__ import annotations
import pytest
from app.core.llm_gateway import LlmGateway, LlmUseCase
class _FakeProvider:
def __init__(self, name: str) -> None:
self.name = name
self.langchain_llm = f"lc-{name}"
self.complete_calls: list[dict] = []
async def complete(self, messages, **kwargs) -> str:
self.complete_calls.append({"messages": messages, **kwargs})
return f"ok-{self.name}"
def test_llm_gateway_selects_default_or_fast_provider(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``langchain_llm_for`` picks the default provider unless the use case sets ``fast=True``."""
    from app.core import llm_gateway as gateway_mod

    providers = {
        "default": _FakeProvider("default"),
        "fast": _FakeProvider("fast"),
    }
    # Swap both provider factories for fakes before the gateway is built.
    monkeypatch.setattr(gateway_mod, "get_llm_provider", lambda: providers["default"])
    monkeypatch.setattr(gateway_mod, "get_llm_provider_fast", lambda: providers["fast"])

    gateway = LlmGateway()
    fast_use_case = LlmUseCase("memory", fast=True)

    # No use case given -> default provider's marker.
    assert gateway.langchain_llm_for() == "lc-default"
    # fast=True use case -> fast provider's marker.
    assert gateway.langchain_llm_for(fast_use_case) == "lc-fast"
@pytest.mark.asyncio
async def test_llm_gateway_chat_text_applies_use_case_defaults(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``chat_text`` forwards the use case's options to the default provider.

    The ``temperature``, ``model`` and ``max_tokens`` carried by the use case
    must reach the provider's ``complete`` call. The provider's reply must be
    returned as the text.
    """
    from app.core import llm_gateway as gateway_mod

    provider = _FakeProvider("default")
    # Stub the fast factory as well. Otherwise a gateway that resolves both
    # providers would reach the real implementation from inside a unit test.
    fast_provider = _FakeProvider("fast")
    monkeypatch.setattr(gateway_mod, "get_llm_provider", lambda: provider)
    monkeypatch.setattr(gateway_mod, "get_llm_provider_fast", lambda: fast_provider)

    text = await LlmGateway().chat_text(
        [{"role": "user", "content": "hi"}],
        use_case=LlmUseCase(
            "chat",
            max_tokens=99,
            temperature=0.2,
            model="model-a",
        ),
    )

    assert text == "ok-default"
    assert provider.complete_calls == [
        {
            "messages": [{"role": "user", "content": "hi"}],
            "temperature": 0.2,
            "model": "model-a",
            "max_tokens": 99,
        }
    ]
    # A use case that does not set fast=True must not hit the fast provider.
    assert fast_provider.complete_calls == []