# life-echo/api/app/agents/memoir/story_route_agent.py
# (copied from a web file view; last modified 2026-03-20 15:15:35 +08:00; 115 lines, 3.7 KiB)
# NOTE: the viewer flagged this file as containing ambiguous Unicode characters
# (characters that might be confused with others); some of them appear to have
# been lost in this copy.
"""
StoryRouteAgent: within a Celery batch, decide between new_story and append_story (the LLM replies in JSON).
"""
from __future__ import annotations
import json
from typing import Any, Literal
from pydantic import BaseModel, field_validator
from app.agents.memoir.prompts import get_story_route_prompt
from app.core.langchain_llm import bind_json_object_mode
from app.core.logging import get_logger
from app.features.story.models import Story
# Module-level logger, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class StoryRouteDecision(BaseModel):
    """Routing verdict for one batch: start a new story or append to an existing one.

    Instances are validated from the LLM's JSON reply. The agent also builds
    them directly for its fallback paths.
    """

    # "new_story" starts a fresh story; "append_story" extends target_story_id.
    decision: Literal["new_story", "append_story"]
    # Story to append to; only checked by the agent when decision == "append_story".
    target_story_id: str | None = None
    # Suggested title for a new story; surrounding whitespace is stripped and
    # blank titles become None.
    new_story_title: str | None = None
    # Free-form reason from the LLM's JSON, or an agent fallback code
    # such as "no_llm", "parse_error" or "invalid_target".
    reason: str | None = None

    @field_validator("target_story_id", mode="before")
    @classmethod
    def empty_str_to_none(cls, v: Any) -> str | None:
        """Normalize the target id before type validation.

        Strings are stripped and blank ones become None. Non-string values
        (e.g. a numeric id emitted by the LLM) are converted with ``str()``.
        """
        if v is None:
            return None
        if isinstance(v, str):
            return v.strip() or None
        return str(v)

    @field_validator("new_story_title", mode="before")
    @classmethod
    def normalize_title(cls, v: Any) -> Any:
        """Strip surrounding whitespace from the title; blank titles become None.

        Non-string values are passed through unchanged, so the field's normal
        ``str | None`` validation still accepts or rejects them as before.
        """
        if isinstance(v, str):
            return v.strip() or None
        return v
def _build_candidate_json(stories: list[Story], *, preview_chars: int = 220) -> str:
    """Serialize candidate stories into the JSON array shown to the routing LLM.

    Each entry holds the story id (as a string, consistent with the
    ``set[str]`` of valid ids that ``StoryRouteAgent.decide`` checks
    against), the title, a single-line preview of the story's
    ``canonical_markdown``, and the ``"title(category)"`` labels of its
    linked chapters.

    Args:
        stories: Candidate stories. ``canonical_markdown`` may be None and
            ``chapter_links`` may be missing, None or empty.
        preview_chars: Maximum preview length; when the text is longer it is
            cut to this length and ``…`` is appended. Negative values are
            treated as 0.

    Returns:
        A pretty-printed JSON string with non-ASCII characters kept verbatim.
    """
    limit = max(preview_chars, 0)
    rows: list[dict[str, Any]] = []
    for s in stories:
        # Collapse every whitespace run (newlines, \r\n, blank lines, tabs) into
        # one space so the preview budget is spent on content, not layout.
        md = " ".join((s.canonical_markdown or "").split())
        # Mark truncation so the LLM knows the preview is partial.
        preview = md[:limit] + ("…" if len(md) > limit else "")
        links: list[str] = []
        for cl in getattr(s, "chapter_links", None) or []:
            ch = getattr(cl, "chapter", None)
            if ch is None:
                continue
            cat = getattr(ch, "category", None) or ""
            tit = getattr(ch, "title", None) or ""
            links.append(f"{tit}({cat})")
        rows.append(
            {
                # str() keeps non-str ids (e.g. UUID) JSON-serializable and in
                # the string form the agent compares LLM answers against.
                "id": None if s.id is None else str(s.id),
                "title": s.title,
                "preview": preview,
                "linked_chapters": links,
            }
        )
    return json.dumps(rows, ensure_ascii=False, indent=2)
class StoryRouteAgent:
    """Route a batch transcript to a new story or to an existing candidate story.

    The choice is delegated to an LLM bound to JSON-object mode. A missing
    LLM, an invocation or parsing failure, or an invalid append target all
    fall back to a ``"new_story"`` decision tagged with a ``reason`` code
    instead of raising.
    """

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the stripped text content of an LLM response message.

        ``content`` may be a plain string or a list of content parts: strings,
        or dicts carrying a ``"text"`` key, as in LangChain multi-part
        messages. Text parts are concatenated and other parts are ignored.
        A missing or None ``content`` yields ``""``.
        """
        content = getattr(response, "content", None)
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            pieces: list[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces).strip()
        return str(content).strip()

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Unwrap text enclosed in a Markdown code fence, e.g. ```json ... ```.

        This is defensive and tolerates a model that fences its JSON. Text that
        is not fenced on its first and last lines is returned unchanged.
        """
        lines = text.splitlines()
        if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
            return "\n".join(lines[1:-1]).strip()
        return text

    def decide(
        self,
        *,
        chapter_category: str,
        chapter_title: str,
        batch_transcript: str,
        candidate_stories: list[Story],
        llm: Any,
        valid_story_ids: set[str],
    ) -> StoryRouteDecision:
        """Ask the LLM whether ``batch_transcript`` starts a new story or extends one.

        Args:
            chapter_category: Category of the chapter, passed into the prompt.
            chapter_title: Title of the chapter, passed into the prompt.
            batch_transcript: Transcript text of the current batch.
            candidate_stories: Stories the batch may be appended to; serialized
                into the prompt as JSON.
            llm: LangChain-style chat model. A falsy value skips the LLM
                entirely and returns a ``reason="no_llm"`` fallback.
            valid_story_ids: Ids an ``append_story`` target must belong to.

        Returns:
            The validated decision. The fallback cases are:
            - ``reason="no_llm"``: no LLM was given.
            - ``reason="parse_error"``: invocation, JSON decoding or schema
              validation failed.
            - ``reason="invalid_target"``: ``append_story`` named an id outside
              ``valid_story_ids``; this becomes a ``new_story`` decision that
              keeps any suggested title.
            A blank ``new_story_title`` on a ``new_story`` decision is
            normalized to None.
        """
        if not llm:
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="no_llm",
            )
        payload = _build_candidate_json(candidate_stories)
        prompt = get_story_route_prompt(
            chapter_category=chapter_category,
            chapter_title=chapter_title,
            batch_transcript=batch_transcript,
            candidate_stories_json=payload,
        )
        try:
            json_llm = bind_json_object_mode(llm, max_tokens=1024)
            response = json_llm.invoke(prompt)
            raw = self._strip_code_fence(self._response_text(response))
            data = json.loads(raw)
            decision = StoryRouteDecision.model_validate(data)
        except Exception as e:
            # Deliberate best-effort boundary: any LLM or parsing failure
            # degrades to "start a new story" rather than propagating.
            logger.warning("StoryRouteAgent 解析失败: %s", e)
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="parse_error",
            )
        if decision.decision == "append_story":
            tid = decision.target_story_id
            # Never trust an id the LLM produced: it must be one of the allowed ids.
            if not tid or tid not in valid_story_ids:
                logger.warning(
                    "StoryRoute append 无效 target_story_id=%s,回退 new_story",
                    tid,
                )
                return StoryRouteDecision(
                    decision="new_story",
                    new_story_title=decision.new_story_title,
                    reason="invalid_target",
                )
        if decision.decision == "new_story" and not (
            decision.new_story_title and decision.new_story_title.strip()
        ):
            decision.new_story_title = None
        return decision