Files
life-echo/api/tests/test_llm_json_call.py

151 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON)."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field
from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword
from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call
class _SmallOut(BaseModel):
    """Tiny two-field schema passed to the JSON-call helpers throughout these tests."""

    # Both fields have defaults, so an instance can be built from only some of them.
    answer: str = ""
    score: int = Field(default=0)
class _SyncFakeLlm:
    """Scripted LLM stand-in that replays canned reply strings in order.

    Each ``invoke`` hands out the next scripted string; once the script is
    used up, the final entry is repeated. ``_i`` counts the calls made so far,
    which lets a test assert how many attempts took place.
    """

    def __init__(self, contents: list[str]) -> None:
        self._i = 0
        # Keep a private copy so the caller's list is never shared.
        self._contents = [*contents]

    def bind(self, **_kwargs: object):
        """Ignore every binding option and return this same fake."""
        return self

    def invoke(self, _prompt: str):
        """Return the next scripted reply wrapped in an ``AIMessage``."""
        last = len(self._contents) - 1
        # Clamp to the last entry once the script has been exhausted.
        position = self._i if self._i < last else last
        reply = self._contents[position]
        self._i += 1
        return AIMessage(content=reply)

    async def ainvoke(self, _prompt: str):
        """Async entry point; simply delegates to ``invoke``."""
        return self.invoke(_prompt)
class _CapturePromptLlm:
    """Fake LLM that remembers the most recent prompt given to ``invoke``.

    ``bind`` is a no-op returning the same object, so the prompt is captured
    on this instance even when options were bound first.
    """

    def __init__(self, content: str) -> None:
        # No prompt has been seen until ``invoke`` is called.
        self.last_prompt: str | None = None
        self.content = content

    def bind(self, **_kwargs: object):
        """Discard all keyword options and return this same fake."""
        return self

    def invoke(self, prompt: str):
        """Store ``prompt`` and answer with the fixed content as an ``AIMessage``."""
        self.last_prompt = prompt
        reply = AIMessage(content=self.content)
        return reply

    async def ainvoke(self, prompt: str):
        """Async entry point; simply delegates to ``invoke``."""
        return self.invoke(prompt)
def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None:
    """A prompt without "json" still holds its original text and now contains the keyword."""
    result = ensure_json_object_prompt_has_json_keyword("仅中文,无英文字样")
    # Part of the original wording must survive.
    assert "仅中文" in result
    # The keyword is matched case-insensitively.
    assert "json" in result.casefold()
def test_ensure_json_object_prompt_unchanged_when_json_present() -> None:
    """A prompt that already contains "JSON" is returned as the very same object."""
    prompt = "只输出 JSON{}"
    returned = ensure_json_object_prompt_has_json_keyword(prompt)
    # Identity, not just equality: the helper must not build a new string.
    assert returned is prompt
def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None:
    """The prompt reaching the LLM contains "json" even when the caller's prompt did not."""
    capture = _CapturePromptLlm('{"answer": "ok", "score": 2}')
    llm_json_call(capture, '你是助手。输出:{"x":1}', _SmallOut, max_tokens=32, agent="test_agent")
    sent = capture.last_prompt
    assert sent is not None
    # The user content holds a JSON example but not the letters "json"; the
    # compliance suffix should still be appended.
    assert "json" in sent.casefold()
def test_llm_json_call_success() -> None:
    """A well-formed JSON reply is parsed into the requested schema."""
    fake = _SyncFakeLlm(['{"answer": "ok", "score": 2}'])
    parsed = llm_json_call(fake, "prompt", _SmallOut, max_tokens=32, agent="test_agent")
    assert parsed.answer == "ok"
    assert parsed.score == 2
def test_llm_json_call_empty_retry_then_success() -> None:
    """An empty first reply is followed by a second call whose JSON is used."""
    scripted = ["", '{"answer": "retry", "score": 0}']
    fake = _SyncFakeLlm(scripted)
    parsed = llm_json_call(fake, "p", _SmallOut, max_tokens=8, agent="t")
    assert parsed.answer == "retry"
    # The fake's call counter shows exactly two invocations.
    assert fake._i == 2
def test_llm_json_call_compat_strip() -> None:
    """A reply wrapped in a Markdown ``json`` code fence is still parsed."""
    fenced = '```json\n{"answer": "fenced", "score": 1}\n```'
    fake = _SyncFakeLlm([fenced])
    parsed = llm_json_call(
        fake,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="t",
    )
    assert parsed.answer == "fenced"
def test_llm_json_call_validation_fallback() -> None:
    """When the reply does not fit the schema, the ``fallback_factory`` result is returned."""

    def _make_fallback() -> _SmallOut:
        return _SmallOut(answer="fb")

    # ``score`` is a non-numeric string, which the ``int`` field cannot accept.
    fake = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
    parsed = llm_json_call(
        fake, "p", _SmallOut, max_tokens=8, agent="t", fallback_factory=_make_fallback
    )
    assert parsed.answer == "fb"
def test_llm_json_call_no_fallback_raises() -> None:
    """With no fallback, a schema-invalid reply raises ``LLMCallError`` of kind "validation"."""
    fake = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
    with pytest.raises(LLMCallError) as caught:
        llm_json_call(fake, "p", _SmallOut, max_tokens=8, agent="t")
    # Inspected after the ``with`` block: code placed after the raising call
    # inside the block would never run.
    assert caught.value.kind == "validation"
@pytest.mark.asyncio
async def test_allm_json_call_parity_with_sync() -> None:
    """The async helper parses a valid JSON reply into the schema, like the sync one."""
    fake = _SyncFakeLlm(['{"answer": "async", "score": 7}'])
    parsed = await allm_json_call(
        fake,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="at",
    )
    assert parsed.answer == "async"
    assert parsed.score == 7
@pytest.mark.asyncio
async def test_allm_json_call_fallback_on_decode() -> None:
    """The async helper returns the ``fallback_factory`` result when the reply is not JSON."""

    def _make_fallback() -> _SmallOut:
        return _SmallOut(answer="dec")

    fake = _SyncFakeLlm(["not json at all {{{"])
    parsed = await allm_json_call(
        fake, "p", _SmallOut, max_tokens=8, agent="at", fallback_factory=_make_fallback
    )
    assert parsed.answer == "dec"