Files
life-echo/api/tests/test_story_route_agent.py
2026-03-20 15:15:35 +08:00

74 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""StoryRouteAgentLLM JSON 决策与非法 target 回退。"""
import unittest
from types import SimpleNamespace
from unittest.mock import Mock
from app.agents.memoir.story_route_agent import StoryRouteAgent
def _story_stub(sid: str, title: str = "T"):
return SimpleNamespace(
id=sid,
title=title,
canonical_markdown="预览正文",
chapter_links=[],
)
def _mock_llm(content: str) -> Mock:
    """Return a fake LLM whose ``bind(...).invoke(...)`` yields ``content``.

    These tests assume the agent talks to the model via ``llm.bind(...)``
    followed by ``.invoke(...)`` on the bound object, and reads the raw model
    text from the response's ``.content`` attribute.
    """
    llm = Mock()
    llm.bind.return_value.invoke.return_value = SimpleNamespace(content=content)
    return llm


class StoryRouteAgentTest(unittest.TestCase):
    """Routing decisions made by ``StoryRouteAgent.decide``.

    Covers the no-LLM default, rejection of an LLM-chosen target that is not in
    ``valid_story_ids``, and acceptance of a valid target.
    """

    def _decide(self, *, llm, batch_transcript="hello"):
        """Run ``decide`` for the "childhood" chapter with one candidate, ``s1``."""
        return StoryRouteAgent().decide(
            chapter_category="childhood",
            chapter_title="童年",
            batch_transcript=batch_transcript,
            candidate_stories=[_story_stub("s1")],
            llm=llm,
            valid_story_ids={"s1"},
        )

    def test_no_llm_returns_new_story(self):
        out = self._decide(llm=None)
        self.assertEqual(out.decision, "new_story")
        self.assertIsNone(out.new_story_title)

    def test_append_invalid_id_falls_back_to_new_story(self):
        # "unknown" is not in valid_story_ids, so the LLM's choice must be
        # rejected instead of trusted.
        llm = _mock_llm('{"decision":"append_story","target_story_id":"unknown"}')
        out = self._decide(llm=llm)
        self.assertEqual(out.decision, "new_story")
        self.assertEqual(out.reason, "invalid_target")
        # Ensure the fallback came from rejecting the LLM output, not from the
        # LLM never being consulted.
        llm.bind.return_value.invoke.assert_called()

    def test_append_valid_target(self):
        llm = _mock_llm('{"decision":"append_story","target_story_id":"s1"}')
        out = self._decide(llm=llm, batch_transcript="more text")
        self.assertEqual(out.decision, "append_story")
        self.assertEqual(out.target_story_id, "s1")
        llm.bind.return_value.invoke.assert_called()
if __name__ == "__main__":
unittest.main()