feat(api): 统一 LLM JSON 调用层 llm_json_call,按域 Schema 迁移 chat/memoir agents
This commit is contained in:
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.schemas import StageDetectionOutput
|
||||
from app.agents.chat.stage_detection import detect_primary_life_stage
|
||||
|
||||
|
||||
@@ -11,9 +12,9 @@ from app.agents.chat.stage_detection import detect_primary_life_stage
|
||||
async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
called: list[int] = []
|
||||
|
||||
async def _fake_ainvoke(*_a: object, **_k: object) -> str:
|
||||
async def _fake_allm(*_a: object, **_k: object) -> StageDetectionOutput:
|
||||
called.append(1)
|
||||
return '{"detected_stage": "career"}'
|
||||
return StageDetectionOutput(detected_stage="career")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.stage_detection.settings.chat_stage_detection_enabled",
|
||||
@@ -25,8 +26,8 @@ async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch)
|
||||
True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.stage_detection.ainvoke_json_object",
|
||||
_fake_ainvoke,
|
||||
"app.agents.chat.stage_detection.allm_json_call",
|
||||
_fake_allm,
|
||||
)
|
||||
out = await detect_primary_life_stage(
|
||||
"嗯",
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent
|
||||
from app.core.config import settings
|
||||
from app.core.llm_call import LLMCallError
|
||||
|
||||
|
||||
def test_fidelity_fail_closed_on_parse_when_not_append(
|
||||
@@ -18,8 +19,8 @@ def test_fidelity_fail_closed_on_parse_when_not_append(
|
||||
agent = FidelityCheckAgent()
|
||||
llm = MagicMock()
|
||||
with patch(
|
||||
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
|
||||
side_effect=ValueError("simulated_bad_response"),
|
||||
"app.agents.memoir.fidelity_check_agent.llm_json_call",
|
||||
side_effect=LLMCallError("invoke", "simulated_bad_response"),
|
||||
):
|
||||
assert (
|
||||
agent.passes(
|
||||
@@ -40,8 +41,8 @@ def test_fidelity_fail_open_on_parse_when_append(
|
||||
agent = FidelityCheckAgent()
|
||||
llm = MagicMock()
|
||||
with patch(
|
||||
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
|
||||
side_effect=ValueError("simulated_bad_response"),
|
||||
"app.agents.memoir.fidelity_check_agent.llm_json_call",
|
||||
side_effect=LLMCallError("invoke", "simulated_bad_response"),
|
||||
):
|
||||
assert (
|
||||
agent.passes(
|
||||
@@ -62,8 +63,8 @@ def test_fidelity_fail_open_global_flag_overrides_append(
|
||||
agent = FidelityCheckAgent()
|
||||
llm = MagicMock()
|
||||
with patch(
|
||||
"app.agents.memoir.fidelity_check_agent.invoke_json_object",
|
||||
side_effect=ValueError("simulated_bad_response"),
|
||||
"app.agents.memoir.fidelity_check_agent.llm_json_call",
|
||||
side_effect=LLMCallError("invoke", "simulated_bad_response"),
|
||||
):
|
||||
assert (
|
||||
agent.passes(
|
||||
|
||||
150
api/tests/test_llm_json_call.py
Normal file
150
api/tests/test_llm_json_call.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword
|
||||
from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call
|
||||
|
||||
|
||||
class _SmallOut(BaseModel):
|
||||
answer: str = Field(default="")
|
||||
score: int = 0
|
||||
|
||||
|
||||
class _SyncFakeLlm:
|
||||
def __init__(self, contents: list[str]) -> None:
|
||||
self._contents = list(contents)
|
||||
self._i = 0
|
||||
|
||||
def bind(self, **_kwargs: object):
|
||||
return self
|
||||
|
||||
def invoke(self, _prompt: str):
|
||||
c = self._contents[min(self._i, len(self._contents) - 1)]
|
||||
self._i += 1
|
||||
return AIMessage(content=c)
|
||||
|
||||
async def ainvoke(self, _prompt: str):
|
||||
return self.invoke(_prompt)
|
||||
|
||||
|
||||
class _CapturePromptLlm:
|
||||
"""Records the prompt passed to ``invoke`` (after ``bind`` noop)."""
|
||||
|
||||
def __init__(self, content: str) -> None:
|
||||
self.content = content
|
||||
self.last_prompt: str | None = None
|
||||
|
||||
def bind(self, **_kwargs: object):
|
||||
return self
|
||||
|
||||
def invoke(self, prompt: str):
|
||||
self.last_prompt = prompt
|
||||
return AIMessage(content=self.content)
|
||||
|
||||
async def ainvoke(self, prompt: str):
|
||||
return self.invoke(prompt)
|
||||
|
||||
|
||||
def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None:
|
||||
out = ensure_json_object_prompt_has_json_keyword("仅中文,无英文字样")
|
||||
assert "仅中文" in out
|
||||
assert "json" in out.casefold()
|
||||
|
||||
|
||||
def test_ensure_json_object_prompt_unchanged_when_json_present() -> None:
|
||||
s = "只输出 JSON:{}"
|
||||
assert ensure_json_object_prompt_has_json_keyword(s) is s
|
||||
|
||||
|
||||
def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None:
|
||||
llm = _CapturePromptLlm('{"answer": "ok", "score": 2}')
|
||||
llm_json_call(
|
||||
llm,
|
||||
'你是助手。输出:{"x":1}',
|
||||
_SmallOut,
|
||||
max_tokens=32,
|
||||
agent="test_agent",
|
||||
)
|
||||
assert llm.last_prompt is not None
|
||||
# 用户内容里含 JSON 示例但无字母 json,仍应追加合规后缀
|
||||
assert "json" in llm.last_prompt.casefold()
|
||||
|
||||
|
||||
def test_llm_json_call_success() -> None:
|
||||
llm = _SyncFakeLlm(['{"answer": "ok", "score": 2}'])
|
||||
out = llm_json_call(
|
||||
llm,
|
||||
"prompt",
|
||||
_SmallOut,
|
||||
max_tokens=32,
|
||||
agent="test_agent",
|
||||
)
|
||||
assert out.answer == "ok"
|
||||
assert out.score == 2
|
||||
|
||||
|
||||
def test_llm_json_call_empty_retry_then_success() -> None:
|
||||
llm = _SyncFakeLlm(["", '{"answer": "retry", "score": 0}'])
|
||||
out = llm_json_call(
|
||||
llm,
|
||||
"p",
|
||||
_SmallOut,
|
||||
max_tokens=8,
|
||||
agent="t",
|
||||
)
|
||||
assert out.answer == "retry"
|
||||
assert llm._i == 2
|
||||
|
||||
|
||||
def test_llm_json_call_compat_strip() -> None:
|
||||
raw = '```json\n{"answer": "fenced", "score": 1}\n```'
|
||||
llm = _SyncFakeLlm([raw])
|
||||
out = llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
|
||||
assert out.answer == "fenced"
|
||||
|
||||
|
||||
def test_llm_json_call_validation_fallback() -> None:
|
||||
llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
|
||||
out = llm_json_call(
|
||||
llm,
|
||||
"p",
|
||||
_SmallOut,
|
||||
max_tokens=8,
|
||||
agent="t",
|
||||
fallback_factory=lambda: _SmallOut(answer="fb"),
|
||||
)
|
||||
assert out.answer == "fb"
|
||||
|
||||
|
||||
def test_llm_json_call_no_fallback_raises() -> None:
|
||||
llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}'])
|
||||
with pytest.raises(LLMCallError) as ei:
|
||||
llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t")
|
||||
assert ei.value.kind == "validation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allm_json_call_parity_with_sync() -> None:
|
||||
llm = _SyncFakeLlm(['{"answer": "async", "score": 7}'])
|
||||
out = await allm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="at")
|
||||
assert out.answer == "async"
|
||||
assert out.score == 7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allm_json_call_fallback_on_decode() -> None:
|
||||
llm = _SyncFakeLlm(["not json at all {{{"])
|
||||
out = await allm_json_call(
|
||||
llm,
|
||||
"p",
|
||||
_SmallOut,
|
||||
max_tokens=8,
|
||||
agent="at",
|
||||
fallback_factory=lambda: _SmallOut(answer="dec"),
|
||||
)
|
||||
assert out.answer == "dec"
|
||||
13
api/tests/test_stage_slot_registry.py
Normal file
13
api/tests/test_stage_slot_registry.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""`STAGE_SLOT_KEYS` must stay aligned with `MemoirStateSchema.default_slots`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.stage_constants import STAGE_SLOT_KEYS
|
||||
from app.agents.state_schema import default_slots
|
||||
|
||||
|
||||
def test_stage_slot_keys_matches_default_slots() -> None:
|
||||
slots = default_slots()
|
||||
assert set(STAGE_SLOT_KEYS.keys()) == set(slots.keys())
|
||||
for stage, keys in STAGE_SLOT_KEYS.items():
|
||||
assert tuple(slots[stage].keys()) == keys
|
||||
@@ -1,11 +1,11 @@
|
||||
"""阶段 / 章节归一化纯函数(VALID + normalize + chat_bucket)。"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.memoir.extraction_agent import ExtractionAgent
|
||||
from app.agents.memoir.schemas import StateExtractionOutput
|
||||
from app.agents.stage_constants import (
|
||||
chat_bucket,
|
||||
normalize_chapter_category,
|
||||
@@ -47,10 +47,10 @@ def test_extraction_agent_normalizes_detected_stage(
|
||||
agent = ExtractionAgent()
|
||||
llm = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"app.agents.memoir.extraction_agent.invoke_json_object",
|
||||
lambda *_a, **_k: json.dumps(
|
||||
{"detected_stage": "career_early", "slots": {"job": "演员"}},
|
||||
ensure_ascii=False,
|
||||
"app.agents.memoir.extraction_agent.llm_json_call",
|
||||
lambda *_a, **_k: StateExtractionOutput(
|
||||
detected_stage="career_early",
|
||||
slots={"job": "演员"},
|
||||
),
|
||||
)
|
||||
r = agent.extract(
|
||||
@@ -69,10 +69,10 @@ def test_extraction_agent_invalid_detected_falls_back(
|
||||
agent = ExtractionAgent()
|
||||
llm = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"app.agents.memoir.extraction_agent.invoke_json_object",
|
||||
lambda *_a, **_k: json.dumps(
|
||||
{"detected_stage": "llm_hallucination", "slots": {}},
|
||||
ensure_ascii=False,
|
||||
"app.agents.memoir.extraction_agent.llm_json_call",
|
||||
lambda *_a, **_k: StateExtractionOutput(
|
||||
detected_stage="llm_hallucination",
|
||||
slots={},
|
||||
),
|
||||
)
|
||||
r = agent.extract(
|
||||
@@ -91,15 +91,12 @@ def test_extraction_agent_empty_slots_inherits_current_stage(
|
||||
agent = ExtractionAgent()
|
||||
llm = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"app.agents.memoir.extraction_agent.invoke_json_object",
|
||||
lambda *_a, **_k: json.dumps(
|
||||
{
|
||||
"detected_stage": "childhood",
|
||||
"slots": {},
|
||||
"emotion": "neutral",
|
||||
"is_new_chapter": False,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
"app.agents.memoir.extraction_agent.llm_json_call",
|
||||
lambda *_a, **_k: StateExtractionOutput(
|
||||
detected_stage="childhood",
|
||||
slots={},
|
||||
emotion="neutral",
|
||||
is_new_chapter=False,
|
||||
),
|
||||
)
|
||||
r = agent.extract(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -12,7 +11,7 @@ from app.agents.memoir.prompts import (
|
||||
get_story_route_prompt,
|
||||
story_route_merge_hint_for_category,
|
||||
)
|
||||
from app.agents.memoir.story_route_agent import StoryRouteAgent
|
||||
from app.agents.memoir.story_route_agent import StoryRouteAgent, StoryRouteDecision
|
||||
|
||||
|
||||
def test_route_prompt_beliefs_has_strong_container_and_no_uncertain_new():
|
||||
@@ -52,15 +51,12 @@ def test_merge_hint_career_episodic():
|
||||
def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
|
||||
captured: dict[str, str] = {}
|
||||
|
||||
def fake_invoke(_llm, prompt: str, **_kwargs):
|
||||
def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs):
|
||||
captured["prompt"] = prompt
|
||||
return json.dumps(
|
||||
{
|
||||
"decision": "append_story",
|
||||
"target_story_id": "story-a",
|
||||
"reason": "同一价值观补充",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
return StoryRouteDecision(
|
||||
decision="append_story",
|
||||
target_story_id="story-a",
|
||||
reason="同一价值观补充",
|
||||
)
|
||||
|
||||
cand = SimpleNamespace(
|
||||
@@ -72,8 +68,8 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
|
||||
chapter_links=[],
|
||||
)
|
||||
with patch(
|
||||
"app.agents.memoir.story_route_agent.invoke_json_object",
|
||||
side_effect=fake_invoke,
|
||||
"app.agents.memoir.story_route_agent.llm_json_call",
|
||||
side_effect=fake_llm_json,
|
||||
):
|
||||
agent = StoryRouteAgent()
|
||||
d = agent.decide(
|
||||
@@ -94,15 +90,12 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload():
|
||||
def test_decide_career_mock_llm_new_story_and_prompt_episodic():
|
||||
captured: dict[str, str] = {}
|
||||
|
||||
def fake_invoke(_llm, prompt: str, **_kwargs):
|
||||
def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs):
|
||||
captured["prompt"] = prompt
|
||||
return json.dumps(
|
||||
{
|
||||
"decision": "new_story",
|
||||
"target_story_id": None,
|
||||
"reason": "另一段完全不同的任职经历",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
return StoryRouteDecision(
|
||||
decision="new_story",
|
||||
target_story_id=None,
|
||||
reason="另一段完全不同的任职经历",
|
||||
)
|
||||
|
||||
cand = SimpleNamespace(
|
||||
@@ -114,8 +107,8 @@ def test_decide_career_mock_llm_new_story_and_prompt_episodic():
|
||||
chapter_links=[],
|
||||
)
|
||||
with patch(
|
||||
"app.agents.memoir.story_route_agent.invoke_json_object",
|
||||
side_effect=fake_invoke,
|
||||
"app.agents.memoir.story_route_agent.llm_json_call",
|
||||
side_effect=fake_llm_json,
|
||||
):
|
||||
d = StoryRouteAgent().decide(
|
||||
chapter_category="career_achievement",
|
||||
|
||||
Reference in New Issue
Block a user