Files
life-echo/api/tests/test_llm_json_call.py

151 lines
4.2 KiB
Python
Raw Normal View History

"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON)."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field
from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword
from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call
class _SmallOut(BaseModel):
    """Minimal two-field schema used as the JSON target in these tests."""

    # Field order (answer, then score) is kept stable on purpose.
    answer: str = ""
    score: int = Field(default=0)
class _SyncFakeLlm:
    """Fake chat model that replays canned replies in order.

    Each ``invoke`` wraps the next string from ``contents`` in an ``AIMessage``.
    Once the list is exhausted, the final entry keeps being returned.
    ``_i`` counts completed invocations; tests read it directly.
    """

    def __init__(self, contents: list[str]) -> None:
        self._responses = [*contents]
        self._i = 0

    def bind(self, **_kwargs: object):
        # Keyword options are ignored; the fake is its own bound model.
        return self

    def invoke(self, _prompt: str):
        last_index = len(self._responses) - 1
        reply = self._responses[min(self._i, last_index)]
        self._i = self._i + 1
        return AIMessage(content=reply)

    async def ainvoke(self, _prompt: str):
        # The async path shares the sync replay logic and the call counter.
        return self.invoke(_prompt)
class _CapturePromptLlm:
    """Fake chat model that remembers the most recent prompt it was given.

    ``bind`` is a no-op, so ``last_prompt`` holds exactly the text that reached
    ``invoke`` (``ainvoke`` delegates to it). Every call answers with the same
    fixed ``content``.
    """

    def __init__(self, content: str) -> None:
        self.last_prompt: str | None = None
        self.content = content

    def bind(self, **_kwargs: object):
        return self

    def invoke(self, prompt: str):
        self.last_prompt = prompt
        reply = AIMessage(content=self.content)
        return reply

    async def ainvoke(self, prompt: str):
        return self.invoke(prompt)
def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None:
    """A prompt lacking the word "json" gains it while keeping its own text."""
    chinese_only = "仅中文,无英文字样"
    result = ensure_json_object_prompt_has_json_keyword(chinese_only)
    assert "仅中文" in result
    assert "json" in result.casefold()
def test_ensure_json_object_prompt_unchanged_when_json_present() -> None:
    """A prompt already mentioning JSON (upper case here) is returned as-is.

    Identity (``is``) is asserted, not just equality: no copy should be made.
    """
    already_json = "只输出 JSON{}"
    returned = ensure_json_object_prompt_has_json_keyword(already_json)
    assert returned is already_json
def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None:
    """The prompt actually sent to the model gains the word "json" when absent.

    The injection must keep the caller's own prompt text intact. It must also
    not interfere with parsing the model's reply into the schema.
    """
    llm = _CapturePromptLlm('{"answer": "ok", "score": 2}')
    out = llm_json_call(
        llm,
        '你是助手。输出:{"x":1}',
        _SmallOut,
        max_tokens=32,
        agent="test_agent",
    )
    assert llm.last_prompt is not None
    # The user prompt contains a JSON example ({"x":1}) but never the letters
    # "json", so the compliance suffix must still be appended.
    assert "json" in llm.last_prompt.casefold()
    # The original instruction text must survive the keyword injection.
    assert "你是助手" in llm.last_prompt
    # The reply is still parsed normally after the prompt was rewritten.
    assert out.answer == "ok"
    assert out.score == 2
def test_llm_json_call_success() -> None:
    """A well-formed JSON reply is parsed straight into the schema."""
    fake = _SyncFakeLlm(['{"answer": "ok", "score": 2}'])
    result = llm_json_call(
        fake, "prompt", _SmallOut, max_tokens=32, agent="test_agent"
    )
    assert (result.answer, result.score) == ("ok", 2)
def test_llm_json_call_empty_retry_then_success() -> None:
    """An empty first reply triggers a retry; the second, valid reply wins."""
    fake = _SyncFakeLlm(["", '{"answer": "retry", "score": 0}'])
    result = llm_json_call(fake, "p", _SmallOut, max_tokens=8, agent="t")
    assert result.answer == "retry"
    # Exactly one retry: the fake was invoked twice in total.
    assert fake._i == 2
def test_llm_json_call_compat_strip() -> None:
    """JSON wrapped in a Markdown ```json code fence is still parsed."""
    fenced_reply = '```json\n{"answer": "fenced", "score": 1}\n```'
    result = llm_json_call(
        _SyncFakeLlm([fenced_reply]),
        "p",
        _SmallOut,
        max_tokens=8,
        agent="t",
    )
    assert result.answer == "fenced"
def test_llm_json_call_validation_fallback() -> None:
    """A schema-validation failure returns whatever ``fallback_factory`` builds."""

    def make_fallback() -> _SmallOut:
        return _SmallOut(answer="fb")

    fake = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
    result = llm_json_call(
        fake,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="t",
        fallback_factory=make_fallback,
    )
    assert result.answer == "fb"
def test_llm_json_call_no_fallback_raises() -> None:
    """Without a fallback, a validation failure surfaces as ``LLMCallError``."""
    bad_reply = '{"answer": "x", "score": "not_an_int"}'
    with pytest.raises(LLMCallError) as exc_info:
        llm_json_call(
            _SyncFakeLlm([bad_reply]), "p", _SmallOut, max_tokens=8, agent="t"
        )
    assert exc_info.value.kind == "validation"
@pytest.mark.asyncio
async def test_allm_json_call_parity_with_sync() -> None:
    """``allm_json_call`` yields the same result as ``llm_json_call``.

    Both entry points receive the same canned reply and must return the same
    parsed result.
    """
    reply = '{"answer": "async", "score": 7}'
    out = await allm_json_call(
        _SyncFakeLlm([reply]), "p", _SmallOut, max_tokens=8, agent="at"
    )
    assert out.answer == "async"
    assert out.score == 7
    # Feed the identical reply through the sync entry point. A fresh fake is
    # used so the two calls do not share a replay position.
    sync_out = llm_json_call(
        _SyncFakeLlm([reply]), "p", _SmallOut, max_tokens=8, agent="at"
    )
    assert out == sync_out
@pytest.mark.asyncio
async def test_allm_json_call_fallback_on_decode() -> None:
    """Async path: a reply that is not JSON at all yields the fallback result."""

    def make_fallback() -> _SmallOut:
        return _SmallOut(answer="dec")

    result = await allm_json_call(
        _SyncFakeLlm(["not json at all {{{"]),
        "p",
        _SmallOut,
        max_tokens=8,
        agent="at",
        fallback_factory=make_fallback,
    )
    assert result.answer == "dec"