Files
life-echo/api/app/agents/memoir/batch_phase1_prep.py
Kevin 59d4b19d7d feat(api): 回忆录管线简化、路由延迟池与相关加固
- Phase1/2:移除 MemoirOrchestrator.run 与 process_memoir_segments 别名;文档改为 process_memoir_phase1。
- 槽位校验集中到 stage_constants(filter_stage_slots),批处理与顺序路径及 state_service 写库一致。
- StoryRoute:no_llm/parse_error/invalid_target 保守 new_story;短篇护栏不覆盖这些 fallback。
- Phase2 低置信单路径可选延迟(StoryPipelineResult.deferred):不写 Chapter/Story,Segment 记录 defer 元数据,冷却内不重复消费;上限后停自动重试,Phase1 同类目新段唤醒池内段。
- Alembic 0017:segments 表 narrative_defer_* 列。
- ProfileAgent:经 LlmGateway/注入 Provider 统一聊天与 JSON,新增测试。
- ImagePromptOrchestrator:LLM 初始化失败可依配置降级或硬失败;补充策略测试。
- 配套单测与 README/本地开发文档表述更新。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-06 13:18:02 +08:00

176 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
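Phase1 batch processing: a single LLM call performs extraction and chapter classification for multiple segments (semantically aligned with the per-segment loop).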
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
from app.agents.memoir.schemas import BatchPhase1LLMOutput
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.llm_call import LLMCallError, llm_json_call
from app.core.logging import get_logger
from app.features.conversation.models import Segment
logger = get_logger(__name__)
def _slots_snapshot(state: MemoirStateSchema) -> dict:
snap: dict = {}
for stage, buckets in (state.slots or {}).items():
snap[stage] = {}
for k, v in (buckets or {}).items():
if hasattr(v, "snippet"):
sn = getattr(v, "snippet", None) or ""
elif isinstance(v, dict):
sn = (
(v.get("snippet") or "")
if isinstance(v.get("snippet"), str)
else ""
)
else:
sn = ""
snap[stage][k] = (sn or "")[:120]
return snap
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable, normalized Phase1 result for one segment of a batch LLM response."""

    # Stage label; run_batch_phase1_prep stores it stripped and lower-cased,
    # falling back to the state's current stage (or "childhood") when blank.
    detected_stage: str
    # Slot name -> slot text; run_batch_phase1_prep keeps only non-empty string
    # keys and coerces non-string values with str().
    slots: Dict[str, str]
    # Chapter category as returned by the LLM, stringified ("" when missing).
    chapter_category_raw: str
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run one batched Phase1 LLM call over ``segments`` and index rows by id.

    Args:
        segments: Segments to process; each contributes
            ``(str(id), stripped user_input_text)`` to the prompt.
        state: Memoir state supplying the current stage and slot snapshot.
        llm: LLM handle forwarded to ``llm_json_call``; must be truthy.

    Returns:
        Mapping of segment id (as ``str``) to its normalized row. Empty when
        ``segments`` is empty. The key set always equals the input id set.

    Raises:
        ValueError: If ``llm`` is falsy, the LLM call raises ``LLMCallError``,
            the response carries no rows, or the returned id set differs from
            the input id set.
    """
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}

    segment_items = [
        (str(segment.id), (segment.user_input_text or "").strip())
        for segment in segments
    ]
    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=state.current_stage or "childhood",
        slots_snapshot=_slots_snapshot(state),
        segment_items=segment_items,
    )

    try:
        llm_output = llm_json_call(
            llm,
            prompt,
            BatchPhase1LLMOutput,
            max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens),
            agent="BatchPhase1Prep.run",
        )
    except LLMCallError as exc:
        logger.warning("batch phase1 LLM 解析失败: {}", exc)
        raise ValueError("batch phase1: llm parse failed") from exc

    llm_rows = llm_output.segments
    if not llm_rows:
        raise ValueError("batch phase1: segments must be a non-empty list")

    rows_by_id: Dict[str, BatchPhase1SegmentRow] = {}
    for llm_row in llm_rows:
        segment_id = str(llm_row.id).strip()
        if not segment_id:
            # Skip rows whose id is blank after stripping.
            continue

        stage = str(llm_row.detected_stage or "").strip().lower()

        # Keep only non-empty string keys; stringify any non-string value.
        clean_slots: Dict[str, str] = {}
        for key, value in (llm_row.slots or {}).items():
            if key and isinstance(key, str):
                clean_slots[key] = value if isinstance(value, str) else str(value)

        category = str(llm_row.chapter_category or "")
        rows_by_id[segment_id] = BatchPhase1SegmentRow(
            # A blank stage falls back to the state's current stage.
            detected_stage=stage if stage else (state.current_stage or "childhood"),
            slots=clean_slots,
            chapter_category_raw=category,
        )

    expected_ids = {str(segment.id) for segment in segments}
    returned_ids = rows_by_id.keys()
    if returned_ids == expected_ids:
        return rows_by_id

    logger.warning(
        "batch phase1 id mismatch missing={} extra={}",
        expected_ids - returned_ids,
        returned_ids - expected_ids,
    )
    raise ValueError("batch phase1 response segment ids do not match input")
def _run_batch_phase1_prep_chunk_with_bisect(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Process one chunk, halving it recursively whenever the batch call fails.

    The whole chunk is first sent through ``run_batch_phase1_prep``. If that
    raises ``ValueError`` (the original note cites truncated LLM output as an
    example), the chunk is split in two and each half is retried the same way,
    down to single segments.

    Args:
        segments: The chunk of segments to process.
        state: Memoir state forwarded to ``run_batch_phase1_prep``.
        llm: LLM handle forwarded to ``run_batch_phase1_prep``.

    Returns:
        Mapping of segment id (as ``str``) to its row for every segment in
        the chunk.

    Raises:
        ValueError: If a chunk of at most one segment fails, or if the merged
            halves do not cover exactly the input ids.
    """
    try:
        return run_batch_phase1_prep(segments, state, llm)
    except ValueError:
        split_at = len(segments) // 2
        # Nothing left to split: re-raise the failure as-is.
        if len(segments) <= 1 or split_at < 1:
            raise

        combined = dict(
            _run_batch_phase1_prep_chunk_with_bisect(segments[:split_at], state, llm)
        )
        combined.update(
            _run_batch_phase1_prep_chunk_with_bisect(segments[split_at:], state, llm)
        )

        if combined.keys() == {str(segment.id) for segment in segments}:
            return combined
        raise ValueError(
            "batch phase1 chunked bisect merge: segment ids do not match input"
        ) from None
def run_batch_phase1_prep_chunked(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    chunk_size: int,
    on_chunk: Callable[[int, int], None] | None = None,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run Phase1 batch prep over ``segments`` in slices of ``chunk_size``.

    Each slice goes through ``_run_batch_phase1_prep_chunk_with_bisect``, which
    halves a failing slice down to single segments. Per the original note this
    is meant to dovetail with the orchestrator's outer per-segment fallback
    (not visible in this file).

    Args:
        segments: All segments to process.
        state: Memoir state forwarded to the per-chunk call.
        llm: LLM handle forwarded to the per-chunk call.
        chunk_size: Maximum segments per LLM call; values below 1 are
            treated as 1.
        on_chunk: Optional callback invoked as
            ``on_chunk(chunk_idx, total_chunks)`` after each chunk is merged;
            ``chunk_idx`` is 1-based.

    Returns:
        Mapping of segment id (as ``str``) to its row, covering exactly the
        input ids. Empty when ``segments`` is empty.

    Raises:
        ValueError: If a chunk fails even after bisecting, or the merged id
            set differs from the input id set.
    """
    if not segments:
        return {}

    step = 1 if chunk_size < 1 else chunk_size
    segment_count = len(segments)
    total_chunks = max(1, math.ceil(segment_count / step))

    results: Dict[str, BatchPhase1SegmentRow] = {}
    for chunk_no, start in enumerate(range(0, segment_count, step), start=1):
        chunk = segments[start : start + step]
        logger.info(
            "event=batch_phase1_chunk chunk_idx={}/{} segment_count={} batch_path=chunked "
            "msg=Phase1 批处理分块调用",
            chunk_no,
            total_chunks,
            len(chunk),
        )
        results.update(_run_batch_phase1_prep_chunk_with_bisect(chunk, state, llm))
        if on_chunk is not None:
            on_chunk(chunk_no, total_chunks)

    expected_ids = {str(segment.id) for segment in segments}
    if results.keys() == expected_ids:
        return results

    logger.warning(
        "batch phase1 chunked id mismatch missing={} extra={}",
        expected_ids - results.keys(),
        results.keys() - expected_ids,
    )
    raise ValueError("batch phase1 chunked: merged segment ids do not match input")