74 lines
2.2 KiB
Python
74 lines
2.2 KiB
Python
"""StoryRouteAgent:LLM JSON 决策与非法 target 回退。"""
|
||
|
||
import unittest
|
||
from types import SimpleNamespace
|
||
from unittest.mock import Mock
|
||
|
||
from app.agents.memoir.story_route_agent import StoryRouteAgent
|
||
|
||
|
||
def _story_stub(sid: str, title: str = "T"):
|
||
return SimpleNamespace(
|
||
id=sid,
|
||
title=title,
|
||
canonical_markdown="预览正文",
|
||
chapter_links=[],
|
||
)
|
||
|
||
|
||
class StoryRouteAgentTest(unittest.TestCase):
    """Routing decisions produced by ``StoryRouteAgent.decide``."""

    @staticmethod
    def _llm_answering(content: str) -> Mock:
        """Return a mock LLM whose ``bind(...)`` result replies with ``content``.

        NOTE(review): this wiring assumes ``decide`` calls ``llm.bind(...)``,
        then ``.invoke(...)`` on the bound object, and reads ``.content`` from
        the reply — confirm against the agent implementation.
        """
        model = Mock()
        runnable = Mock()
        model.bind.return_value = runnable
        runnable.invoke.return_value = SimpleNamespace(content=content)
        return model

    @staticmethod
    def _route(agent, transcript: str, llm):
        """Call ``agent.decide`` with the shared single-candidate fixture.

        The chapter is "childhood"/"童年", the only candidate story is ``s1``,
        and ``s1`` is the only valid target id.
        """
        return agent.decide(
            chapter_category="childhood",
            chapter_title="童年",
            batch_transcript=transcript,
            candidate_stories=[_story_stub("s1")],
            llm=llm,
            valid_story_ids={"s1"},
        )

    def test_no_llm_returns_new_story(self):
        """With ``llm=None`` the agent opens a new story and proposes no title."""
        agent = StoryRouteAgent()
        result = self._route(agent, "hello", None)
        self.assertEqual(result.decision, "new_story")
        self.assertIsNone(result.new_story_title)

    def test_append_invalid_id_falls_back_to_new_story(self):
        """An append to an id outside ``valid_story_ids`` degrades to a new story."""
        agent = StoryRouteAgent()
        model = self._llm_answering(
            '{"decision":"append_story","target_story_id":"unknown"}'
        )
        result = self._route(agent, "hello", model)
        self.assertEqual(result.decision, "new_story")
        self.assertEqual(result.reason, "invalid_target")

    def test_append_valid_target(self):
        """An append to a known story id is passed through unchanged."""
        agent = StoryRouteAgent()
        model = self._llm_answering(
            '{"decision":"append_story","target_story_id":"s1"}'
        )
        result = self._route(agent, "more text", model)
        self.assertEqual(result.decision, "append_story")
        self.assertEqual(result.target_story_id, "s1")
|
||
if __name__ == "__main__":
    # Direct-execution entry point: load and run the TestCase classes defined
    # in this module ("__main__" is also unittest.main's default module).
    unittest.main(module="__main__")
|