feat(api): 统一 LLM JSON 调用层 llm_json_call,按域 Schema 迁移 chat/memoir agents

This commit is contained in:
Kevin
2026-04-03 13:34:27 +08:00
parent 41518bda11
commit 43d1689e9c
28 changed files with 1006 additions and 352 deletions

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
import pytest
from app.agents.chat.schemas import StageDetectionOutput
from app.agents.chat.stage_detection import detect_primary_life_stage
@@ -11,9 +12,9 @@ from app.agents.chat.stage_detection import detect_primary_life_stage
async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch) -> None:
called: list[int] = []
async def _fake_ainvoke(*_a: object, **_k: object) -> str:
async def _fake_allm(*_a: object, **_k: object) -> StageDetectionOutput:
called.append(1)
return '{"detected_stage": "career"}'
return StageDetectionOutput(detected_stage="career")
monkeypatch.setattr(
"app.agents.chat.stage_detection.settings.chat_stage_detection_enabled",
@@ -25,8 +26,8 @@ async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch)
True,
)
monkeypatch.setattr(
"app.agents.chat.stage_detection.ainvoke_json_object",
_fake_ainvoke,
"app.agents.chat.stage_detection.allm_json_call",
_fake_allm,
)
out = await detect_primary_life_stage(
"",

View File

@@ -8,6 +8,7 @@ import pytest
from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent
from app.core.config import settings
from app.core.llm_call import LLMCallError
def test_fidelity_fail_closed_on_parse_when_not_append(
@@ -18,8 +19,8 @@ def test_fidelity_fail_closed_on_parse_when_not_append(
agent = FidelityCheckAgent()
llm = MagicMock()
with patch(
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
side_effect=ValueError("simulated_bad_response"),
"app.agents.memoir.fidelity_check_agent.llm_json_call",
side_effect=LLMCallError("invoke", "simulated_bad_response"),
):
assert (
agent.passes(
@@ -40,8 +41,8 @@ def test_fidelity_fail_open_on_parse_when_append(
agent = FidelityCheckAgent()
llm = MagicMock()
with patch(
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
side_effect=ValueError("simulated_bad_response"),
"app.agents.memoir.fidelity_check_agent.llm_json_call",
side_effect=LLMCallError("invoke", "simulated_bad_response"),
):
assert (
agent.passes(
@@ -62,8 +63,8 @@ def test_fidelity_fail_open_global_flag_overrides_append(
agent = FidelityCheckAgent()
llm = MagicMock()
with patch(
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
side_effect=ValueError("simulated_bad_response"),
"app.agents.memoir.fidelity_check_agent.llm_json_call",
side_effect=LLMCallError("invoke", "simulated_bad_response"),
):
assert (
agent.passes(

View File

@@ -0,0 +1,150 @@
"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON)."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field
from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword
from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call
class _SmallOut(BaseModel):
answer: str = Field(default="")
score: int = 0
class _SyncFakeLlm:
def __init__(self, contents: list[str]) -> None:
self._contents = list(contents)
self._i = 0
def bind(self, **_kwargs: object):
return self
def invoke(self, _prompt: str):
c = self._contents[min(self._i, len(self._contents) - 1)]
self._i += 1
return AIMessage(content=c)
async def ainvoke(self, _prompt: str):
return self.invoke(_prompt)
class _CapturePromptLlm:
"""Records the prompt passed to ``invoke`` (after ``bind`` noop)."""
def __init__(self, content: str) -> None:
self.content = content
self.last_prompt: str | None = None
def bind(self, **_kwargs: object):
return self
def invoke(self, prompt: str):
self.last_prompt = prompt
return AIMessage(content=self.content)
async def ainvoke(self, prompt: str):
return self.invoke(prompt)
def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None:
out = ensure_json_object_prompt_has_json_keyword("仅中文,无英文字样")
assert "仅中文" in out
assert "json" in out.casefold()
def test_ensure_json_object_prompt_unchanged_when_json_present() -> None:
s = "只输出 JSON{}"
assert ensure_json_object_prompt_has_json_keyword(s) is s
def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None:
llm = _CapturePromptLlm('{"answer": "ok", "score": 2}')
llm_json_call(
llm,
'你是助手。输出:{"x":1}',
_SmallOut,
max_tokens=32,
agent="test_agent",
)
assert llm.last_prompt is not None
# 用户内容里含 JSON 示例但无字母 json仍应追加合规后缀
assert "json" in llm.last_prompt.casefold()
def test_llm_json_call_success() -> None:
llm = _SyncFakeLlm(['{"answer": "ok", "score": 2}'])
out = llm_json_call(
llm,
"prompt",
_SmallOut,
max_tokens=32,
agent="test_agent",
)
assert out.answer == "ok"
assert out.score == 2
def test_llm_json_call_empty_retry_then_success() -> None:
llm = _SyncFakeLlm(["", '{"answer": "retry", "score": 0}'])
out = llm_json_call(
llm,
"p",
_SmallOut,
max_tokens=8,
agent="t",
)
assert out.answer == "retry"
assert llm._i == 2
def test_llm_json_call_compat_strip() -> None:
raw = '```json\n{"answer": "fenced", "score": 1}\n```'
llm = _SyncFakeLlm([raw])
out = llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
assert out.answer == "fenced"
def test_llm_json_call_validation_fallback() -> None:
llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
out = llm_json_call(
llm,
"p",
_SmallOut,
max_tokens=8,
agent="t",
fallback_factory=lambda: _SmallOut(answer="fb"),
)
assert out.answer == "fb"
def test_llm_json_call_no_fallback_raises() -> None:
llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
with pytest.raises(LLMCallError) as ei:
llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
assert ei.value.kind == "validation"
@pytest.mark.asyncio
async def test_allm_json_call_parity_with_sync() -> None:
llm = _SyncFakeLlm(['{"answer": "async", "score": 7}'])
out = await allm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="at")
assert out.answer == "async"
assert out.score == 7
@pytest.mark.asyncio
async def test_allm_json_call_fallback_on_decode() -> None:
llm = _SyncFakeLlm(["not json at all {{{"])
out = await allm_json_call(
llm,
"p",
_SmallOut,
max_tokens=8,
agent="at",
fallback_factory=lambda: _SmallOut(answer="dec"),
)
assert out.answer == "dec"

View File

@@ -0,0 +1,13 @@
"""`STAGE_SLOT_KEYS` must stay aligned with `MemoirStateSchema.default_slots`."""
from __future__ import annotations
from app.agents.stage_constants import STAGE_SLOT_KEYS
from app.agents.state_schema import default_slots
def test_stage_slot_keys_matches_default_slots() -> None:
slots = default_slots()
assert set(STAGE_SLOT_KEYS.keys()) == set(slots.keys())
for stage, keys in STAGE_SLOT_KEYS.items():
assert tuple(slots[stage].keys()) == keys

View File

@@ -1,11 +1,11 @@
"""阶段 / 章节归一化纯函数VALID + normalize + chat_bucket"""
import json
from unittest.mock import MagicMock
import pytest
from app.agents.memoir.extraction_agent import ExtractionAgent
from app.agents.memoir.schemas import StateExtractionOutput
from app.agents.stage_constants import (
chat_bucket,
normalize_chapter_category,
@@ -47,10 +47,10 @@ def test_extraction_agent_normalizes_detected_stage(
agent = ExtractionAgent()
llm = MagicMock()
monkeypatch.setattr(
"app.agents.memoir.extraction_agent.invoke_json_object",
lambda *_a, **_k: json.dumps(
{"detected_stage": "career_early", "slots": {"job": "演员"}},
ensure_ascii=False,
"app.agents.memoir.extraction_agent.llm_json_call",
lambda *_a, **_k: StateExtractionOutput(
detected_stage="career_early",
slots={"job": "演员"},
),
)
r = agent.extract(
@@ -69,10 +69,10 @@ def test_extraction_agent_invalid_detected_falls_back(
agent = ExtractionAgent()
llm = MagicMock()
monkeypatch.setattr(
"app.agents.memoir.extraction_agent.invoke_json_object",
lambda *_a, **_k: json.dumps(
{"detected_stage": "llm_hallucination", "slots": {}},
ensure_ascii=False,
"app.agents.memoir.extraction_agent.llm_json_call",
lambda *_a, **_k: StateExtractionOutput(
detected_stage="llm_hallucination",
slots={},
),
)
r = agent.extract(
@@ -91,15 +91,12 @@ def test_extraction_agent_empty_slots_inherits_current_stage(
agent = ExtractionAgent()
llm = MagicMock()
monkeypatch.setattr(
"app.agents.memoir.extraction_agent.invoke_json_object",
lambda *_a, **_k: json.dumps(
{
"detected_stage": "childhood",
"slots": {},
"emotion": "neutral",
"is_new_chapter": False,
},
ensure_ascii=False,
"app.agents.memoir.extraction_agent.llm_json_call",
lambda *_a, **_k: StateExtractionOutput(
detected_stage="childhood",
slots={},
emotion="neutral",
is_new_chapter=False,
),
)
r = agent.extract(

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@@ -12,7 +11,7 @@ from app.agents.memoir.prompts import (
get_story_route_prompt,
story_route_merge_hint_for_category,
)
from app.agents.memoir.story_route_agent import StoryRouteAgent
from app.agents.memoir.story_route_agent import StoryRouteAgent, StoryRouteDecision
def test_route_prompt_beliefs_has_strong_container_and_no_uncertain_new():
@@ -52,15 +51,12 @@ def test_merge_hint_career_episodic():
def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
captured: dict[str, str] = {}
def fake_invoke(_llm, prompt: str, **_kwargs):
def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs):
captured["prompt"] = prompt
return json.dumps(
{
"decision": "append_story",
"target_story_id": "story-a",
"reason": "同一价值观补充",
},
ensure_ascii=False,
return StoryRouteDecision(
decision="append_story",
target_story_id="story-a",
reason="同一价值观补充",
)
cand = SimpleNamespace(
@@ -72,8 +68,8 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
chapter_links=[],
)
with patch(
"app.agents.memoir.story_route_agent.invoke_json_object",
side_effect=fake_invoke,
"app.agents.memoir.story_route_agent.llm_json_call",
side_effect=fake_llm_json,
):
agent = StoryRouteAgent()
d = agent.decide(
@@ -94,15 +90,12 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
def test_decide_career_mock_llm_new_story_and_prompt_episodic():
captured: dict[str, str] = {}
def fake_invoke(_llm, prompt: str, **_kwargs):
def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs):
captured["prompt"] = prompt
return json.dumps(
{
"decision": "new_story",
"target_story_id": None,
"reason": "另一段完全不同的任职经历",
},
ensure_ascii=False,
return StoryRouteDecision(
decision="new_story",
target_story_id=None,
reason="另一段完全不同的任职经历",
)
cand = SimpleNamespace(
@@ -114,8 +107,8 @@ def test_decide_career_mock_llm_new_story_and_prompt_episodic():
chapter_links=[],
)
with patch(
"app.agents.memoir.story_route_agent.invoke_json_object",
side_effect=fake_invoke,
"app.agents.memoir.story_route_agent.llm_json_call",
side_effect=fake_llm_json,
):
d = StoryRouteAgent().decide(
chapter_category="career_achievement",