Files
life-echo/api/app/agents/memoir/story_route_agent.py
Kevin b853b986dd fix(memoir): 改善 story 合并决策,少生碎片篇
以前模型只看到很短预览,还容易被引导成新建 story。现在优先用已有摘要、
按需带正文片段,并区分「像续写同一主题」和「像换了一件事」;
beliefs/summary 更鼓励接着写, career/童年等仍可按新事件新开。
2026-04-03 11:02:05 +08:00

209 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
StoryRouteAgentCelery 批次内判断 new_story vs append_storyJSON
"""
from __future__ import annotations
import json
from typing import Any, Literal
from pydantic import BaseModel, field_validator
from app.agents.memoir.prompts import (
get_story_batch_plan_prompt,
get_story_route_prompt,
)
from app.agents.memoir.story_route_payload import build_route_candidate_json
from app.core.config import settings
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
from app.features.story.models import Story
logger = get_logger(__name__)
# Above this many segments, batch planning is skipped (a single routing decision
# is used instead) so the prompt does not grow too large.
# NOTE(review): nothing in the visible part of this file reads this constant —
# presumably the caller enforces it; confirm.
PLAN_BATCH_MAX_SEGMENTS = 48
class StoryBatchPlanUnit(BaseModel):
    """One unit of a batched write: a contiguous run of segments sharing one decision."""

    # Ids of the segments covered by this unit.
    segment_ids: list[str]
    # Whether the unit starts a new story or extends an existing one.
    decision: Literal["new_story", "append_story"]
    # Id of the story to extend; normalised by the validator below.
    target_story_id: str | None = None
    new_story_title: str | None = None
    reason: str | None = None

    @field_validator("target_story_id", mode="before")
    @classmethod
    def empty_str_to_none_tid(cls, v: Any) -> str | None:
        """Normalise the raw target id: blank -> None, strings stripped, others stringified."""
        if isinstance(v, str):
            # Whitespace-only and empty strings both collapse to None.
            return v.strip() or None
        if v is None or v == "":
            return None
        return str(v)
class StoryBatchPlan(BaseModel):
    """Batch plan: the list of write units that a batch of segments is split into."""

    units: list[StoryBatchPlanUnit]
class StoryRouteDecision(BaseModel):
    """A single routing verdict: start a new story or append to an existing one."""

    # Which action to take.
    decision: Literal["new_story", "append_story"]
    # Id of the story to extend; normalised by the validator below.
    target_story_id: str | None = None
    new_story_title: str | None = None
    # Short explanation or fallback code attached to the verdict.
    reason: str | None = None

    @field_validator("target_story_id", mode="before")
    @classmethod
    def empty_str_to_none(cls, v: Any) -> str | None:
        """Normalise the raw target id: blank -> None, strings stripped, others stringified."""
        if isinstance(v, str):
            # Whitespace-only and empty strings both collapse to None.
            return v.strip() or None
        if v is None or v == "":
            return None
        return str(v)
def _build_segments_json_for_plan(
    segments: list[tuple[str, str]], *, text_preview_chars: int = 4000
) -> str:
    """Serialise batch segments into the JSON array embedded in the planning prompt.

    Args:
        segments: ``(segment_id, user_input_text)`` pairs in dictation order.
            A ``None``/empty text is emitted as an empty string.
        text_preview_chars: Maximum number of characters of each (stripped)
            text to keep. Negative values are treated as 0.

    Returns:
        A pretty-printed JSON array of ``{"id": ..., "text": ...}`` objects,
        with non-ASCII characters kept verbatim.
    """
    # Guard against a negative limit, which would otherwise slice from the end.
    limit = max(text_preview_chars, 0)
    rows: list[dict[str, str]] = []
    for sid, text in segments:
        body = (text or "").strip()
        if len(body) > limit:
            # Mark the cut explicitly so truncated text is distinguishable
            # from text that simply ends here.
            body = body[:limit] + "…"
        rows.append({"id": sid, "text": body})
    return json.dumps(rows, ensure_ascii=False, indent=2)
def validate_story_batch_plan(
    ordered_segment_ids: list[str],
    plan: StoryBatchPlan,
    valid_story_ids: set[str],
) -> tuple[bool, str | None]:
    """Check a batch plan against the batch it was generated for.

    Verifies that the units cover every segment exactly once and in the
    original order, and that every ``append_story`` unit targets a known
    story. Titles are generated later by NarrativeAgent, so
    ``new_story_title`` is not required at the routing stage.

    Returns:
        ``(True, None)`` when the plan is valid, otherwise ``(False, code)``
        where ``code`` is one of ``empty_units``, ``empty_unit_segment_ids``,
        ``duplicate_segment``, ``segment_mismatch``, ``invalid_append_target``.
    """
    units = plan.units
    if not units:
        return False, "empty_units"
    if any(not unit.segment_ids for unit in units):
        return False, "empty_unit_segment_ids"

    # Flatten in plan order so coverage and ordering can be checked together.
    planned_ids = [seg_id for unit in units for seg_id in unit.segment_ids]
    if len(set(planned_ids)) != len(planned_ids):
        return False, "duplicate_segment"
    if planned_ids != ordered_segment_ids:
        return False, "segment_mismatch"

    for unit in units:
        if unit.decision != "append_story":
            continue
        target = unit.target_story_id
        if not target or target not in valid_story_ids:
            return False, "invalid_append_target"
    return True, None
class StoryRouteAgent:
    """Asks an LLM whether batch content starts a new story or extends an existing one."""

    def decide(
        self,
        *,
        chapter_category: str,
        chapter_title: str,
        batch_transcript: str,
        candidate_stories: list[Story],
        llm: Any,
        valid_story_ids: set[str],
        story_meta: dict[str, dict[str, int]] | None = None,
    ) -> StoryRouteDecision:
        """Return one routing decision for the whole batch transcript.

        Falls back to a ``new_story`` decision when no LLM is supplied
        (``reason="no_llm"``), when the LLM call or response parsing fails
        (``reason="parse_error"``), or when an ``append_story`` answer names
        a target outside ``valid_story_ids`` (``reason="invalid_target"``).
        """
        if not llm:
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="no_llm",
            )

        candidates_json = build_route_candidate_json(
            candidate_stories, story_meta, settings
        )
        route_prompt = get_story_route_prompt(
            chapter_category=chapter_category,
            chapter_title=chapter_title,
            batch_transcript=batch_transcript,
            candidate_stories_json=candidates_json,
        )

        try:
            response_text = invoke_json_object(
                llm,
                route_prompt,
                max_tokens=1024,
                agent="StoryRouteAgent.decide",
            )
            parsed = json.loads(response_text.strip())
            routed = StoryRouteDecision.model_validate(parsed)
        except Exception as exc:
            # Best-effort: any failure here degrades to creating a new story.
            logger.warning("StoryRouteAgent 解析失败: {}", exc)
            return StoryRouteDecision(
                decision="new_story",
                new_story_title=None,
                reason="parse_error",
            )

        if routed.decision != "append_story":
            return routed

        target = routed.target_story_id
        if target and target in valid_story_ids:
            return routed

        # The LLM pointed at a story we do not know; keep its title suggestion
        # but create a new story instead.
        logger.warning(
            "StoryRoute append 无效 target_story_id={},回退 new_story",
            target,
        )
        return StoryRouteDecision(
            decision="new_story",
            new_story_title=routed.new_story_title,
            reason="invalid_target",
        )

    def plan_batch(
        self,
        *,
        chapter_category: str,
        chapter_title: str,
        segments: list[tuple[str, str]],
        candidate_stories: list[Story],
        llm: Any,
        valid_story_ids: set[str],
        story_meta: dict[str, dict[str, int]] | None = None,
    ) -> StoryBatchPlan | None:
        """Split the batch's segments into write units.

        Returns ``None`` when no LLM is supplied, when there are fewer than
        two segments, or when the response cannot be parsed or fails
        validation; the caller is expected to fall back to ``decide()``.
        """
        if len(segments) < 2 or not llm:
            return None

        candidates_json = build_route_candidate_json(
            candidate_stories, story_meta, settings
        )
        plan_prompt = get_story_batch_plan_prompt(
            chapter_category=chapter_category,
            chapter_title=chapter_title,
            segments_json=_build_segments_json_for_plan(segments),
            candidate_stories_json=candidates_json,
        )

        try:
            response_text = invoke_json_object(
                llm,
                plan_prompt,
                max_tokens=4096,
                agent="StoryRouteAgent.plan_batch",
            )
            parsed = json.loads(response_text.strip())
            batch_plan = StoryBatchPlan.model_validate(parsed)
        except Exception as exc:
            logger.warning("StoryRouteAgent.plan_batch 解析失败: {}", exc)
            return None

        segment_order = [segment[0] for segment in segments]
        is_valid, error_code = validate_story_batch_plan(
            segment_order, batch_plan, valid_story_ids
        )
        if is_valid:
            return batch_plan
        logger.warning("StoryRouteAgent.plan_batch 校验失败: {}", error_code)
        return None