Files
life-echo/api/app/agents/memoir/batch_phase1_prep.py
Kevin 07c6478742 feat(api): 访谈路径轻量门控、Memoir Phase1 批处理与叙事/记忆管线加固
- 新增 utterance_substance:短时/应答/元话语可跳过记忆检索、阶段 LLM 与资料抽取 LLM;可配置
- 输入归一化:LLM 模式默认仅语音/ASR;配置项写入 .env.example
- Memoir Phase1:可选 batch LLM 一次性抽取+分类(失败回退逐段);Extraction 空槽位时阶段与 current_stage 对齐,prompt 约束收紧
- 叙事与忠实度:narrative_safety、证据重叠/场合锚点、标题 slots 与履历短语 grounded;fidelity 解析失败 fail-open 可配置
- 章节管线:锁 TTL 上调、锁竞争 Celery 重试、Phase2 immediate singleflight 等;story_pipeline_sync / chapter_compose / memoir_tasks 联动
- Memory:compaction / repo / summarizer / evidence 小修;事实 FTS 未命中是否回退最近事实可配置
- 新增 memoir_pipeline_trace;补充 memoir_reliability 文档与多项回归/门控测试
2026-04-03 10:12:59 +08:00

115 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Phase1 批处理:一次 LLM 调用完成多段的抽取 + 章节分类(与逐段循环语义对齐)。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, List
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.json_utils import extract_json_payload
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
from app.features.conversation.models import Segment
logger = get_logger(__name__)
# Slot names permitted for each memoir stage, keyed by stage name.
# NOTE(review): not referenced elsewhere in this view of the file; presumably
# imported by callers that validate extracted slots -- confirm before changing.
STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = {
    stage_name: frozenset(slot_names)
    for stage_name, slot_names in (
        (
            "childhood",
            ("place", "people", "daily_life", "emotion", "turning_event"),
        ),
        ("education", ("school", "city", "motivation", "challenge", "change")),
        ("career", ("job", "environment", "decision", "pressure", "growth")),
        (
            "family",
            ("relationship", "conflict", "support", "responsibility", "change"),
        ),
        ("belief", ("value", "regret", "pride", "lesson")),
    )
}
def _slots_snapshot(state: MemoirStateSchema) -> dict:
    """Return a compact ``{stage: {slot: snippet}}`` copy of ``state.slots``.

    Each slot entry is reduced to its ``snippet`` text, truncated to 120
    characters. Entries may be objects exposing a ``snippet`` attribute or
    plain dicts with a ``"snippet"`` key; anything else becomes ``""``.
    A missing/empty ``state.slots`` or stage bucket yields an empty dict.
    """

    def _snippet_text(entry: object) -> str:
        # Attribute-style entries take precedence over dict-style ones.
        if hasattr(entry, "snippet"):
            text = getattr(entry, "snippet", None) or ""
        elif isinstance(entry, dict):
            candidate = entry.get("snippet")
            # Only string snippets are accepted from dict entries.
            text = (candidate or "") if isinstance(candidate, str) else ""
        else:
            text = ""
        return (text or "")[:120]

    return {
        stage_name: {
            slot_key: _snippet_text(slot_value)
            for slot_key, slot_value in (stage_slots or {}).items()
        }
        for stage_name, stage_slots in (state.slots or {}).items()
    }
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable per-segment result parsed from the batch Phase1 LLM response."""

    # Stage name reported for the segment.
    detected_stage: str
    # Extracted slot values, keyed by slot name.
    slots: Dict[str, str]
    # Chapter category text exactly as returned, before any normalisation.
    chapter_category_raw: str
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Extract slots and classify chapters for all segments in one LLM call.

    Args:
        segments: Segments to process. Each contributes ``str(segment.id)``
            and its stripped ``user_input_text`` to the prompt.
        state: Memoir state supplying ``current_stage`` (``"childhood"`` when
            unset) and the slots snapshot embedded in the prompt.
        llm: Handle forwarded to ``invoke_json_object``; must be truthy.

    Returns:
        Mapping of segment id (as a string) to its parsed row. Empty when
        ``segments`` is empty, in which case no LLM call is made.

    Raises:
        ValueError: If ``llm`` is falsy, the decoded payload is not a JSON
            object, its ``segments`` value is not a list, or the returned id
            set differs from the input id set. ``json.loads`` failures also
            surface as ``ValueError`` (``json.JSONDecodeError``).
    """
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}

    fallback_stage = state.current_stage or "childhood"
    items = [(str(s.id), (s.user_input_text or "").strip()) for s in segments]
    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=fallback_stage,
        slots_snapshot=_slots_snapshot(state),
        segment_items=items,
    )
    raw = invoke_json_object(
        llm,
        prompt,
        max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens),
        agent="BatchPhase1Prep.run",
    )

    parsed = json.loads(extract_json_payload(raw))
    # A top-level list/string/number would otherwise fail with AttributeError
    # on ``.get``; report it as the same ValueError used for other bad shapes.
    if not isinstance(parsed, dict):
        raise ValueError("batch phase1: response must be a JSON object")
    rows = parsed.get("segments") or []
    if not isinstance(rows, list):
        raise ValueError("batch phase1: segments must be a list")

    by_id: Dict[str, BatchPhase1SegmentRow] = {}
    for row in rows:
        # Malformed rows are skipped; the id-set check below catches the gap.
        if not isinstance(row, dict):
            continue
        sid = str(row.get("id", "")).strip()
        if not sid:
            continue

        ds = str(row.get("detected_stage", "") or "").strip().lower()

        slots_raw = row.get("slots") or {}
        slots: Dict[str, str] = {}
        if isinstance(slots_raw, dict):
            for k, v in slots_raw.items():
                if not k or not isinstance(k, str):
                    continue
                # JSON null means "nothing extracted"; do not store the
                # literal text "None" as slot content.
                if v is None:
                    continue
                slots[k] = v if isinstance(v, str) else str(v)

        # "category" is accepted as an alias when "chapter_category" is absent.
        cat_raw = str(row.get("chapter_category", row.get("category", "")) or "")
        by_id[sid] = BatchPhase1SegmentRow(
            detected_stage=ds or fallback_stage,
            slots=slots,
            chapter_category_raw=cat_raw,
        )

    # The response must cover exactly the requested ids -- no more, no fewer.
    expected = {str(s.id) for s in segments}
    if by_id.keys() != expected:
        missing = expected - by_id.keys()
        extra = by_id.keys() - expected
        logger.warning("batch phase1 id mismatch missing={} extra={}", missing, extra)
        raise ValueError("batch phase1 response segment ids do not match input")
    return by_id