Files
life-echo/api/app/agents/memoir/story_route_agent.py
Kevin a3f61fcc0f feat(api+app): 对话阶段化、回忆录流水线与客户端会话体验
- DB: segments 用户输入文本(Alembic 0002)
- Chat: 阶段检测/阶段提示/回复限制,编排与访谈/画像 prompts 调整
- Memoir: 忠实度检查 agent,叙事与分类等链路更新
- Core: agent 日志、Alembic 启动、LangChain/日志/配置等
- Story: time_hints;Memory 检索与相关测试
- Expo: 助手头像、会话页与消息拆分、实时会话与文案/i18n
- Docs/scripts/tests: 迁移脚本、LLM JSON/记忆检索文档、新增单测
2026-03-26 12:13:36 +08:00

236 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
StoryRouteAgentCelery 批次内判断 new_story vs append_storyJSON
"""
from __future__ import annotations
import json
from typing import Any, Literal
from pydantic import BaseModel, field_validator
from app.agents.memoir.prompts import (
get_story_batch_plan_prompt,
get_story_route_prompt,
)
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
from app.features.story.models import Story
logger = get_logger(__name__)
# Batch-planning size cap: batches with more segments than this skip batch
# planning and use single-shot routing instead, so the prompt stays bounded.
PLAN_BATCH_MAX_SEGMENTS: int = 48
class StoryBatchPlanUnit(BaseModel):
    """One write unit of a batch plan: a contiguous block of segments routed together."""

    # IDs of the segments in this unit, in dictation order.
    segment_ids: list[str]
    decision: Literal["new_story", "append_story"]
    # Story to append to; validate_story_batch_plan requires it (and requires it
    # to be a known story ID) when decision == "append_story".
    target_story_id: str | None = None
    # Title of the story to create; validate_story_batch_plan requires a
    # non-blank value when decision == "new_story". Stripped; blank -> None.
    new_story_title: str | None = None
    # Free-form rationale returned by the LLM.
    reason: str | None = None

    @field_validator("target_story_id", mode="before")
    @classmethod
    def empty_str_to_none_tid(cls, v: Any) -> str | None:
        """Normalize the target ID: None/blank -> None, strings stripped, other values stringified."""
        if v is None or v == "":
            return None
        if isinstance(v, str):
            return v.strip() or None
        return str(v)

    @field_validator("new_story_title", mode="before")
    @classmethod
    def blank_title_to_none(cls, v: Any) -> Any:
        """Strip string titles and map blank ones to None; non-strings go on to type validation."""
        if isinstance(v, str):
            return v.strip() or None
        return v


class StoryBatchPlan(BaseModel):
    """A batch plan: write units whose segment IDs, concatenated, are expected to
    reproduce the batch's segments in dictation order (checked by
    ``validate_story_batch_plan``)."""

    units: list[StoryBatchPlanUnit]
class StoryRouteDecision(BaseModel):
    """Routing decision for a whole batch, as produced by ``StoryRouteAgent.decide``."""

    decision: Literal["new_story", "append_story"]
    # Story to append to; only meaningful when decision == "append_story".
    target_story_id: str | None = None
    # Suggested title for a new story. Stripped; blank -> None.
    new_story_title: str | None = None
    # LLM rationale, or a fallback code set by StoryRouteAgent.decide
    # ("no_llm", "parse_error", "invalid_target").
    reason: str | None = None

    @field_validator("target_story_id", mode="before")
    @classmethod
    def empty_str_to_none(cls, v: Any) -> str | None:
        """Normalize the target ID: None/blank -> None, strings stripped, other values stringified."""
        if v is None or v == "":
            return None
        if isinstance(v, str):
            return v.strip() or None
        return str(v)

    @field_validator("new_story_title", mode="before")
    @classmethod
    def blank_title_to_none(cls, v: Any) -> Any:
        """Strip string titles and map blank ones to None; non-strings go on to type validation."""
        if isinstance(v, str):
            return v.strip() or None
        return v
def _build_candidate_json(stories: list[Story], *, preview_chars: int = 220) -> str:
rows: list[dict[str, Any]] = []
for s in stories:
md = (s.canonical_markdown or "").strip().replace("\n", " ")
preview = md[:preview_chars] + ("" if len(md) > preview_chars else "")
links: list[str] = []
for cl in getattr(s, "chapter_links", None) or []:
ch = getattr(cl, "chapter", None)
if ch is None:
continue
cat = getattr(ch, "category", None) or ""
tit = getattr(ch, "title", None) or ""
links.append(f"{tit}({cat})")
rows.append(
{
"id": s.id,
"title": s.title,
"preview": preview,
"linked_chapters": links,
}
)
return json.dumps(rows, ensure_ascii=False, indent=2)
def _build_segments_json_for_plan(
segments: list[tuple[str, str]], *, text_preview_chars: int = 4000
) -> str:
"""segments: (id, user_input_text) 按口述顺序。"""
rows: list[dict[str, str]] = []
for sid, text in segments:
t = (text or "").strip()
if len(t) > text_preview_chars:
t = t[:text_preview_chars] + ""
rows.append({"id": sid, "text": t})
return json.dumps(rows, ensure_ascii=False, indent=2)
def validate_story_batch_plan(
    ordered_segment_ids: list[str],
    plan: StoryBatchPlan,
    valid_story_ids: set[str],
) -> tuple[bool, str | None]:
    """Check a batch plan against the batch it was generated for.

    Checks run in this order; the first failure wins:

    1. ``empty_units``: the plan has no units.
    2. ``empty_unit_segment_ids``: some unit lists no segments.
    3. ``duplicate_segment``: a segment ID appears more than once.
    4. ``segment_mismatch``: the units' segment IDs, concatenated, differ from
       ``ordered_segment_ids`` (missing/extra IDs or wrong order).
    5. Per unit, in plan order: ``invalid_append_target`` when an
       ``append_story`` unit's target is missing or not in
       ``valid_story_ids``; ``missing_new_title`` when any other unit has a
       blank ``new_story_title``.

    Returns:
        ``(True, None)`` if the plan is valid, otherwise ``(False, error_code)``.
    """
    units = plan.units
    if not units:
        return False, "empty_units"
    if any(not unit.segment_ids for unit in units):
        return False, "empty_unit_segment_ids"

    covered = [sid for unit in units for sid in unit.segment_ids]
    if len(set(covered)) != len(covered):
        return False, "duplicate_segment"
    if covered != ordered_segment_ids:
        return False, "segment_mismatch"

    for unit in units:
        if unit.decision == "append_story":
            target = unit.target_story_id
            if not (target and target in valid_story_ids):
                return False, "invalid_append_target"
        elif not (unit.new_story_title or "").strip():
            return False, "missing_new_title"
    return True, None
class StoryRouteAgent:
    """Decide how a batch of dictated segments is written into stories.

    ``decide`` routes a whole batch transcript as one unit (``new_story`` or
    ``append_story``); ``plan_batch`` asks the LLM to split an ordered batch
    into several contiguous write units. Exceptions from the LLM call, JSON
    decoding and schema validation are caught and logged: ``decide`` then
    falls back to a ``new_story`` decision and ``plan_batch`` returns ``None``.
    """

    @staticmethod
    def _normalize_title(title: str | None) -> str | None:
        """Return *title* stripped of surrounding whitespace, or ``None`` if missing/blank."""
        if title is None:
            return None
        return title.strip() or None

    def decide(
        self,
        *,
        chapter_category: str,
        chapter_title: str,
        batch_transcript: str,
        candidate_stories: list[Story],
        llm: Any,
        valid_story_ids: set[str],
    ) -> StoryRouteDecision:
        """Decide whether the batch transcript starts a new story or extends an existing one.

        Args:
            chapter_category: Category of the target chapter (passed to the prompt).
            chapter_title: Title of the target chapter (passed to the prompt).
            batch_transcript: The batch's dictated text, routed as one unit.
            candidate_stories: Existing stories the LLM may choose to append to.
            llm: Model handle for ``invoke_json_object``; falsy means no LLM.
            valid_story_ids: IDs an ``append_story`` decision may target.

        Returns:
            The LLM's decision, or a ``new_story`` fallback with reason
            ``"no_llm"`` (no model), ``"parse_error"`` (call/parse/validation
            raised) or ``"invalid_target"`` (append target missing or not in
            ``valid_story_ids``). On every path ``new_story_title`` is either a
            stripped, non-empty string or ``None``.
        """
        if not llm:
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="no_llm",
            )
        payload = _build_candidate_json(candidate_stories)
        prompt = get_story_route_prompt(
            chapter_category=chapter_category,
            chapter_title=chapter_title,
            batch_transcript=batch_transcript,
            candidate_stories_json=payload,
        )
        try:
            raw = invoke_json_object(
                llm,
                prompt,
                max_tokens=1024,
                agent="StoryRouteAgent.decide",
            ).strip()
            data = json.loads(raw)
            decision = StoryRouteDecision.model_validate(data)
        except Exception as e:
            # Deliberate best-effort boundary: any failure in the LLM call,
            # JSON decoding or schema validation degrades to a new story
            # instead of failing the whole batch.
            logger.warning("StoryRouteAgent 解析失败: {}", e)
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="parse_error",
            )
        # Normalize once, up front, so the invalid-target fallback below also
        # returns a stripped title or None (never a whitespace-only string).
        title = self._normalize_title(decision.new_story_title)
        if decision.decision == "append_story":
            tid = decision.target_story_id
            if not tid or tid not in valid_story_ids:
                logger.warning(
                    "StoryRoute append 无效 target_story_id={},回退 new_story",
                    tid,
                )
                return StoryRouteDecision(
                    decision="new_story",
                    new_story_title=title,
                    reason="invalid_target",
                )
        decision.new_story_title = title
        return decision

    def plan_batch(
        self,
        *,
        chapter_category: str,
        chapter_title: str,
        segments: list[tuple[str, str]],
        candidate_stories: list[Story],
        llm: Any,
        valid_story_ids: set[str],
    ) -> StoryBatchPlan | None:
        """Split this batch's segments into contiguous write units.

        Args:
            chapter_category: Category of the target chapter (passed to the prompt).
            chapter_title: Title of the target chapter (passed to the prompt).
            segments: ``(segment_id, user_input_text)`` pairs in dictation order.
            candidate_stories: Existing stories units may append to.
            llm: Model handle for ``invoke_json_object``; falsy means no LLM.
            valid_story_ids: IDs an ``append_story`` unit may target.

        Returns:
            The validated plan (unit titles stripped, blank -> None), or
            ``None`` when batch planning does not apply (no LLM, fewer than two
            segments, more than ``PLAN_BATCH_MAX_SEGMENTS`` segments) or when
            the LLM output fails parsing or ``validate_story_batch_plan``. On
            ``None`` the caller is expected to fall back to ``decide()``.
        """
        # Planning needs an LLM and at least two segments to split; oversized
        # batches are skipped so the prompt stays bounded.
        if not llm or len(segments) < 2 or len(segments) > PLAN_BATCH_MAX_SEGMENTS:
            return None
        payload = _build_candidate_json(candidate_stories)
        segments_json = _build_segments_json_for_plan(segments)
        prompt = get_story_batch_plan_prompt(
            chapter_category=chapter_category,
            chapter_title=chapter_title,
            segments_json=segments_json,
            candidate_stories_json=payload,
        )
        try:
            raw = invoke_json_object(
                llm,
                prompt,
                max_tokens=4096,
                agent="StoryRouteAgent.plan_batch",
            ).strip()
            data = json.loads(raw)
            plan = StoryBatchPlan.model_validate(data)
        except Exception as e:
            # Best-effort: any failure means "no plan"; the caller falls back.
            logger.warning("StoryRouteAgent.plan_batch 解析失败: {}", e)
            return None
        ordered = [sid for sid, _text in segments]
        ok, err = validate_story_batch_plan(ordered, plan, valid_story_ids)
        if not ok:
            logger.warning("StoryRouteAgent.plan_batch 校验失败: {}", err)
            return None
        for unit in plan.units:
            unit.new_story_title = self._normalize_title(unit.new_story_title)
        return plan