"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON)."""

from __future__ import annotations

import httpx
import pytest
from langchain_core.messages import AIMessage
from openai import APIStatusError
from pydantic import BaseModel, Field

from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword
from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call


class _SmallOut(BaseModel):
    """Minimal target schema shared by the tests in this module."""

    answer: str = Field(default="")
    score: int = 0


class _SyncFakeLlm:
    """Scripted fake LLM that replays canned responses in order.

    Each ``invoke``/``ainvoke`` call returns the next string from *contents*
    wrapped in an ``AIMessage``; once the script is exhausted, the last
    response is repeated. ``bind`` ignores its kwargs and returns the same
    instance.
    """

    def __init__(self, contents: list[str]) -> None:
        # Fail fast: an empty script would otherwise surface later as an
        # opaque IndexError (index -1) on the first invoke.
        if not contents:
            raise ValueError("_SyncFakeLlm needs at least one scripted response")
        self._contents = list(contents)
        self._i = 0

    @property
    def call_count(self) -> int:
        """Number of ``invoke``/``ainvoke`` calls made so far."""
        return self._i

    def bind(self, **_kwargs: object) -> _SyncFakeLlm:
        return self

    def invoke(self, _prompt: str) -> AIMessage:
        # Clamp the index so calls beyond the script keep returning the last entry.
        content = self._contents[min(self._i, len(self._contents) - 1)]
        self._i += 1
        return AIMessage(content=content)

    async def ainvoke(self, _prompt: str) -> AIMessage:
        return self.invoke(_prompt)


class _CapturePromptLlm:
    """Records the prompt passed to ``invoke`` (after ``bind`` noop)."""

    def __init__(self, content: str) -> None:
        self.content = content
        self.last_prompt: str | None = None

    def bind(self, **_kwargs: object) -> _CapturePromptLlm:
        return self

    def invoke(self, prompt: str) -> AIMessage:
        self.last_prompt = prompt
        return AIMessage(content=self.content)

    async def ainvoke(self, prompt: str) -> AIMessage:
        return self.invoke(prompt)


def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None:
    out = ensure_json_object_prompt_has_json_keyword("仅中文,无英文字样")
    assert "仅中文" in out
    assert "json" in out.casefold()


def test_ensure_json_object_prompt_unchanged_when_json_present() -> None:
    s = "只输出 JSON:{}"
    # Identity check: a prompt that already mentions JSON is returned as-is.
    assert ensure_json_object_prompt_has_json_keyword(s) is s


def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None:
    llm = _CapturePromptLlm('{"answer": "ok", "score": 2}')
    llm_json_call(
        llm,
        '你是助手。输出:{"x":1}',
        _SmallOut,
        max_tokens=32,
        agent="test_agent",
    )
    assert llm.last_prompt is not None
    # The prompt contains a JSON example but not the literal word "json";
    # the compliance suffix must still be appended.
    assert "json" in llm.last_prompt.casefold()


def test_llm_json_call_success() -> None:
    llm = _SyncFakeLlm(['{"answer": "ok", "score": 2}'])
    out = llm_json_call(
        llm,
        "prompt",
        _SmallOut,
        max_tokens=32,
        agent="test_agent",
    )
    assert out.answer == "ok"
    assert out.score == 2


def test_llm_json_call_empty_retry_then_success() -> None:
    llm = _SyncFakeLlm(["", '{"answer": "retry", "score": 0}'])
    out = llm_json_call(
        llm,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="t",
    )
    assert out.answer == "retry"
    # One empty response, then one successful retry.
    assert llm.call_count == 2


def test_llm_json_call_compat_strip() -> None:
    raw = '```json\n{"answer": "fenced", "score": 1}\n```'
    llm = _SyncFakeLlm([raw])
    out = llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
    assert out.answer == "fenced"


def test_llm_json_call_validation_fallback() -> None:
    llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
    out = llm_json_call(
        llm,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="t",
        fallback_factory=lambda: _SmallOut(answer="fb"),
    )
    assert out.answer == "fb"


def test_llm_json_call_no_fallback_raises() -> None:
    llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
    with pytest.raises(LLMCallError) as ei:
        llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
    assert ei.value.kind == "validation"


def _api_status_402() -> APIStatusError:
    """Build an OpenAI ``APIStatusError`` wrapping an HTTP 402 "Insufficient balance" response."""
    req = httpx.Request("POST", "https://api.deepseek.com/v1/chat/completions")
    resp = httpx.Response(
        402, request=req, json={"error": {"message": "Insufficient balance"}}
    )
    return APIStatusError("Payment required", response=resp, body=resp.json())


class _LlmInvokeRaises:
    """Fake LLM whose ``invoke``/``ainvoke`` always raise the 402 error above."""

    def bind(self, **_kwargs: object) -> _LlmInvokeRaises:
        return self

    def invoke(self, _prompt: str) -> object:
        raise _api_status_402()

    async def ainvoke(self, _prompt: str) -> object:
        return self.invoke(_prompt)


def test_llm_json_call_openai_status_maps_to_friendly_chinese() -> None:
    with pytest.raises(LLMCallError) as ei:
        llm_json_call(
            _LlmInvokeRaises(),
            "p",
            _SmallOut,
            max_tokens=8,
            agent="t",
        )
    assert ei.value.kind == "invoke"
    s = str(ei.value)
    assert "402" in s
    assert "余额" in s
    assert "DeepSeek" in s


@pytest.mark.asyncio
async def test_allm_json_call_parity_with_sync() -> None:
    llm = _SyncFakeLlm(['{"answer": "async", "score": 7}'])
    out = await allm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="at")
    assert out.answer == "async"
    assert out.score == 7


@pytest.mark.asyncio
async def test_allm_json_call_fallback_on_decode() -> None:
    llm = _SyncFakeLlm(["not json at all {{{"])
    out = await allm_json_call(
        llm,
        "p",
        _SmallOut,
        max_tokens=8,
        agent="at",
        fallback_factory=lambda: _SmallOut(answer="dec"),
    )
    assert out.answer == "dec"