Files
life-echo/api/app/features/memoir/state_service.py
Kevin 59d4b19d7d feat(api): 回忆录管线简化、路由延迟池与相关加固
- Phase1/2:移除 MemoirOrchestrator.run 与 process_memoir_segments 别名;文档改为 process_memoir_phase1。
- 槽位校验集中到 stage_constants(filter_stage_slots),批处理与顺序路径及 state_service 写库一致。
- StoryRoute:no_llm/parse_error/invalid_target 保守 new_story;短篇护栏不覆盖这些 fallback。
- Phase2 低置信单路径可选延迟(StoryPipelineResult.deferred):不写 Chapter/Story,Segment 记录 defer 元数据,冷却内不重复消费;上限后停自动重试,Phase1 同类目新段唤醒池内段。
- Alembic 0017:segments 表 narrative_defer_* 列。
- ProfileAgent:经 LlmGateway/注入 Provider 统一聊天与 JSON,新增测试。
- ImagePromptOrchestrator:LLM 初始化失败可依配置降级或硬失败;补充策略测试。
- 配套单测与 README/本地开发文档表述更新。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-06 13:18:02 +08:00

317 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录状态服务：get_or_create_state、update_slot、mark_stage_complete 等。
供 memoir service、conversation ws 使用；Celery 任务内使用同步版本。
"""
import uuid
from typing import Dict, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.agents.stage_constants import (
allowed_slot_names_for_stage,
chat_bucket,
normalize_chat_stage,
)
from app.agents.state_schema import (
KnownFact,
MemoirStateSchema,
PersonaThread,
SlotData,
default_state,
narrative_coverage_state,
)
from app.core.config import settings
from app.features.memoir.models import MemoirState as MemoirStateModel
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
"""浅拷贝 slots避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
if not raw or not isinstance(raw, dict):
return {}
return {k: dict(v or {}) for k, v in raw.items()}
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Project an ORM row onto a validated ``MemoirStateSchema``.

    Every field is read from its own column on ``model``. Columns holding an
    unexpected type are replaced by a neutral value before validation:

    * ``slots``: a missing, non-dict or empty value falls back to the slots of
      ``default_state()``.
    * ``known_facts_json`` / ``persona_threads_json``: non-list values count as
      empty; non-dict entries are dropped.
    * ``recent_questions_json``: non-list values count as empty; entries are
      stringified and stripped, and blank results are dropped.

    Args:
        model: The persisted memoir state row.

    Returns:
        The schema object produced by ``MemoirStateSchema.model_validate``.
    """

    def _list_or_empty(value) -> list:
        # JSON columns may come back as None or another type; only lists are used.
        return value if isinstance(value, list) else []

    slots = model.slots if isinstance(model.slots, dict) else None
    if not slots:
        slots = dict(default_state().slots)

    known_facts = [
        entry
        for entry in _list_or_empty(model.known_facts_json)
        if isinstance(entry, dict)
    ]
    persona_threads = [
        entry
        for entry in _list_or_empty(model.persona_threads_json)
        if isinstance(entry, dict)
    ]

    recent_questions = []
    for entry in _list_or_empty(model.recent_questions_json):
        text = str(entry).strip()
        if text:
            recent_questions.append(text)

    payload = {
        "stage_order": model.stage_order or default_state().stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": slots,
        "known_facts": known_facts,
        "persona_threads": persona_threads,
        "recent_questions": recent_questions,
    }
    return MemoirStateSchema.model_validate(payload)
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Load the memoir state for ``user_id``, creating a default row if absent.

    When no row exists, a new one is built from ``default_state()`` (with a
    fresh UUID as primary key), added to the session and committed.

    Args:
        user_id: Owner of the memoir state.
        db: Async session used for the lookup and, if needed, the insert.

    Returns:
        The existing or newly created state as a ``MemoirStateSchema``.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing:
        return coerce_memoir_state(existing)

    template = default_state()
    # Slot objects are schema models; store them as plain dicts in the JSON column.
    serialized_slots = {}
    for stage_key, stage_slots in template.slots.items():
        serialized_slots[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }

    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Decide whether a slot write also moves ``state.current_stage``.

    Truth table implemented here:

    * ``memoir_batch`` is false: always set ``current_stage`` to ``stage_norm``.
    * ``memoir_batch`` is true and
      ``settings.memoir_extraction_updates_current_stage`` is false: leave
      ``current_stage`` untouched.
    * ``memoir_batch`` is true and the setting is true: set ``current_stage``
      only when ``chat_bucket`` maps the stored stage (or ``"childhood"`` when
      none is stored) and ``stage_norm`` to the same bucket.

    Args:
        state: ORM row to mutate in place.
        stage_norm: Already-normalized stage name of the write.
        memoir_batch: Whether the write comes from the memoir batch path.
    """
    baseline_stage = state.current_stage or "childhood"
    if memoir_batch:
        if not settings.memoir_extraction_updates_current_stage:
            return
        stored_bucket = chat_bucket(baseline_stage)
        incoming_bucket = chat_bucket(stage_norm)
        if incoming_bucket != stored_bucket:
            return
    state.current_stage = stage_norm
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Write one slot of the user's memoir state and commit.

    The state row is selected with ``with_for_update()``; if it does not exist
    yet it is created via ``get_or_create_state`` and selected again. ``stage``
    is normalized with ``normalize_chat_stage`` (falling back to the stored
    current stage, or ``"childhood"``). If ``slot_name`` is not among
    ``allowed_slot_names_for_stage(...)`` the state is returned unchanged and
    nothing is committed.

    Args:
        user_id: Owner of the memoir state.
        stage: Raw stage name of the slot.
        slot_name: Slot to overwrite inside the stage.
        snippet: New snippet text stored for the slot.
        segment_ids: Segment ids to merge into the slot's existing ids.
        db: Async session; committed on a successful write.
        memoir_batch: Forwarded to ``_apply_current_stage_policy``, which
            decides whether ``current_stage`` follows this write.

    Returns:
        The state after the write (or the untouched state when the slot name
        is rejected).
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # No row yet: create the default one, then repeat the locking select.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )
    if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db):
        # NOTE(review): this path returns without commit/rollback, so the
        # transaction opened by the locking select presumably stays open until
        # the caller ends it - confirm callers close the session promptly.
        return coerce_memoir_state(state)

    # Work on a copy so that assigning it back is seen as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm, {}) or {})

    # The stored entry may be null or malformed; only a dict carries prior ids.
    existing = stage_slots.get(slot_name)
    previous_ids = (
        existing.get("segment_ids") if isinstance(existing, dict) else None
    ) or []
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # persisted list is deterministic (a set would scramble it per process).
    merged_segment_ids = list(dict.fromkeys([*previous_ids, *(segment_ids or [])]))

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Record ``stage`` as covered and advance ``current_stage`` if it matches.

    If the user has no state row, a default one is created through
    ``get_or_create_state`` and returned as-is (the stage is not marked).
    Otherwise ``stage`` is added to ``covered_stages`` when missing, and if it
    equals ``current_stage`` the current stage moves to the next entry of
    ``stage_order`` (staying on the last entry when already at the end). A
    stage that is not part of ``stage_order`` resets ``current_stage`` to the
    default state's current stage.

    Per the original author's note this function currently has no callers and
    is reserved for future stage-progression logic.

    Args:
        user_id: Owner of the memoir state.
        stage: Stage name to mark as covered.
        db: Async session; committed after the update.

    Returns:
        The state after the update.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        return await get_or_create_state(user_id, db)

    # Copy before appending: mutating the list object already stored on the
    # JSON column and assigning the same object back is not reliably detected
    # as a change by the ORM, so the update could be silently skipped.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
        state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            # Stage unknown to the configured order: fall back to the default.
            state.current_stage = default_state().current_stage
        else:
            # Clamp so the last stage stays current once everything is covered.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return the empty slots of the user's current stage.

    The state is loaded (or created) with ``get_or_create_state`` and the
    result is whatever ``empty_slots_for_current_stage()`` reports on the
    object built by ``narrative_coverage_state``.

    Args:
        user_id: Owner of the memoir state.
        db: Async session used to load the state.
    """
    snapshot = await get_or_create_state(user_id, db)
    coverage = narrative_coverage_state(snapshot)
    return coverage.empty_slots_for_current_stage()
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to the normalized ``new_stage``.

    The row is selected with ``with_for_update()`` and created first when it
    does not exist. ``new_stage`` goes through ``normalize_chat_stage`` with
    the stored stage (or ``"childhood"``) as fallback before being stored.

    Args:
        user_id: Owner of the memoir state.
        new_stage: Requested stage name.
        db: Async session; committed after the update.

    Returns:
        The state after the switch.
    """
    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        # Create the default row, then fetch it again through the locking select.
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    previous_stage = row.current_stage or "childhood"
    row.current_stage = normalize_chat_stage(
        new_stage,
        fallback=previous_stage,
        log_context={"user_id": user_id},
    )
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Overwrite the interview metadata columns of the user's state and commit.

    The row is selected with ``with_for_update()`` and created first when it
    does not exist. The three JSON columns are replaced wholesale, not merged.

    Args:
        user_id: Owner of the memoir state.
        known_facts: Facts stored into ``known_facts_json`` via ``model_dump()``.
        persona_threads: Threads stored into ``persona_threads_json`` via
            ``model_dump()``.
        recent_questions: Questions stored (as a new list) into
            ``recent_questions_json``.
        db: Async session; committed after the update.

    Returns:
        The state after the update.
    """

    def _dump_all(items) -> list:
        # Schema objects are persisted as plain dicts.
        return [item.model_dump() for item in items]

    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        # Create the default row, then fetch it again through the locking select.
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    row.known_facts_json = _dump_all(known_facts)
    row.persona_threads_json = _dump_all(persona_threads)
    row.recent_questions_json = [*recent_questions]
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state``.

    Loads the memoir state for ``user_id``; when no row exists, one is built
    from ``default_state()`` (with a fresh UUID as primary key), added to the
    session and committed.

    Args:
        user_id: Owner of the memoir state.
        db: Synchronous session used for the lookup and, if needed, the insert.

    Returns:
        The existing or newly created state as a ``MemoirStateSchema``.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    existing = db.execute(lookup).scalar_one_or_none()
    if existing:
        return coerce_memoir_state(existing)

    template = default_state()
    # Slot objects are schema models; store them as plain dicts in the JSON column.
    serialized_slots = {}
    for stage_key, stage_slots in template.slots.items():
        serialized_slots[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }

    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return coerce_memoir_state(row)
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous slot write; unlike ``update_slot``, ``memoir_batch`` defaults to true.

    The state row is selected with ``with_for_update()``; if it does not exist
    yet it is created via ``get_or_create_state_sync`` and selected again.
    ``stage`` is normalized with ``normalize_chat_stage`` (falling back to the
    stored current stage, or ``"childhood"``). If ``slot_name`` is not among
    ``allowed_slot_names_for_stage(...)`` the state is returned unchanged and
    nothing is committed.

    Args:
        user_id: Owner of the memoir state.
        stage: Raw stage name of the slot.
        slot_name: Slot to overwrite inside the stage.
        snippet: New snippet text stored for the slot.
        segment_ids: Segment ids to merge into the slot's existing ids.
        db: Synchronous session; committed on a successful write.
        memoir_batch: Forwarded to ``_apply_current_stage_policy``, which
            decides whether ``current_stage`` follows this write.

    Returns:
        The state after the write (or the untouched state when the slot name
        is rejected).
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # No row yet: create the default one, then repeat the locking select.
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )
    if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db):
        # NOTE(review): this path returns without commit/rollback, so the
        # transaction opened by the locking select presumably stays open until
        # the caller ends it - confirm callers close the session promptly.
        return coerce_memoir_state(state)

    # Work on a copy so that assigning it back is seen as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm, {}) or {})

    # The stored entry may be null or malformed; only a dict carries prior ids.
    existing = stage_slots.get(slot_name)
    previous_ids = (
        existing.get("segment_ids") if isinstance(existing, dict) else None
    ) or []
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # persisted list is deterministic (a set would scramble it per process).
    merged_segment_ids = list(dict.fromkeys([*previous_ids, *(segment_ids or [])]))

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)