"""Memoir state service: get_or_create_state, update_slot, mark_stage_complete, etc.

Used by the memoir service and the conversation websocket; Celery tasks use the
synchronous variants (``*_sync``).
"""

import uuid
from typing import Dict, List, cast

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.agents.stage_constants import (
    chat_bucket,
    normalize_chat_stage,
)
from app.agents.state_schema import (
    KnownFact,
    MemoirStateSchema,
    PersonaThread,
    SlotData,
    default_state,
)
from app.core.config import settings
from app.features.memoir.models import MemoirState as MemoirStateModel

# Reserved key inside the ``slots`` JSON column under which interview metadata
# (known facts, persona threads, recent questions) is stored next to the slots.
_INTERVIEW_STATE_META_KEY = "__interview_state__"


def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
    """Return a copy of ``raw`` with each top-level value copied one level deep.

    Working on a copy avoids mutating the dict held by the JSON column in
    place; assigning the very same (mutated) object back would not be marked
    dirty by the ORM.

    Returns an empty dict when ``raw`` is falsy or not a dict.
    """
    if not raw or not isinstance(raw, dict):
        return {}
    return {key: dict(value or {}) for key, value in raw.items()}


def _extract_interview_state_meta(
    raw_slots: Dict[str, Dict] | None,
) -> tuple[list[KnownFact], list[PersonaThread], list[str]]:
    """Read the interview metadata stored under the reserved key of ``raw_slots``.

    Returns:
        ``(known_facts, persona_threads, recent_questions)``. Each element is
        an empty list when the metadata (or that field) is missing or has an
        unexpected type. Non-dict entries of the first two lists are skipped;
        recent questions are stringified, stripped, and dropped when empty.
    """
    if not raw_slots or not isinstance(raw_slots, dict):
        return [], [], []
    meta = raw_slots.get(_INTERVIEW_STATE_META_KEY)
    if not isinstance(meta, dict):
        return [], [], []

    def _list_field(key: str) -> list:
        value = meta.get(key)
        return value if isinstance(value, list) else []

    known = _list_field("known_facts")
    persona = _list_field("persona_threads")
    recent = _list_field("recent_questions")
    return (
        [KnownFact.model_validate(x) for x in known if isinstance(x, dict)],
        [PersonaThread.model_validate(x) for x in persona if isinstance(x, dict)],
        [str(x).strip() for x in recent if str(x).strip()],
    )


def _inject_interview_state_meta(
    *,
    slots: Dict[str, Dict],
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
) -> Dict[str, Dict]:
    """Return a shallow copy of ``slots`` with the interview metadata entry set.

    The metadata is serialised via ``model_dump()`` and stored under the
    reserved metadata key, replacing any previous entry. ``slots`` itself is
    not modified.
    """
    out = dict(slots)
    out[_INTERVIEW_STATE_META_KEY] = cast(
        Dict,
        {
            "known_facts": [x.model_dump() for x in known_facts],
            "persona_threads": [x.model_dump() for x in persona_threads],
            "recent_questions": list(recent_questions),
        },
    )
    return out


def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Convert an ORM row into a validated ``MemoirStateSchema``.

    The interview metadata is split out of ``slots`` into dedicated schema
    fields, and defaults from ``default_state()`` are used when
    ``stage_order`` or ``slots`` are empty.
    """
    raw_slots = model.slots if isinstance(model.slots, dict) else None
    known_facts, persona_threads, recent_questions = _extract_interview_state_meta(
        raw_slots
    )
    clean_slots = dict(raw_slots) if raw_slots else dict(default_state().slots)
    # The metadata entry is not a real stage; keep it out of the schema slots.
    clean_slots.pop(_INTERVIEW_STATE_META_KEY, None)
    return MemoirStateSchema.model_validate(
        {
            "stage_order": model.stage_order or default_state().stage_order,
            "current_stage": model.current_stage,
            "covered_stages": model.covered_stages or [],
            "slots": clean_slots,
            "known_facts": known_facts,
            "persona_threads": persona_threads,
            "recent_questions": recent_questions,
        }
    )


def _new_default_state_model(user_id: str) -> MemoirStateModel:
    """Build (without persisting) a state row for ``user_id`` from ``default_state()``."""
    default = default_state()
    return MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=default.stage_order,
        current_stage=default.current_stage,
        covered_stages=default.covered_stages,
        slots={
            k: {sk: sv.model_dump() for sk, sv in v.items()}
            for k, v in default.slots.items()
        },
    )


async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the user's memoir state, creating and committing a default row if absent."""
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if state:
        return coerce_memoir_state(state)
    state = _new_default_state_model(user_id)
    db.add(state)
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)


def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Update ``state.current_stage`` according to the batch/bucket policy.

    * Not a memoir batch: always switch to ``stage_norm``.
    * Memoir batch with ``settings.memoir_extraction_updates_current_stage``
      disabled: leave the current stage untouched.
    * Memoir batch with the setting enabled: switch only when ``stage_norm``
      falls in the same ``chat_bucket`` as the current stage (a missing
      current stage is treated as ``"childhood"``).
    """
    if not memoir_batch:
        state.current_stage = stage_norm
        return
    if not settings.memoir_extraction_updates_current_stage:
        return
    current_bucket = chat_bucket(state.current_stage or "childhood")
    if chat_bucket(stage_norm) == current_bucket:
        state.current_stage = stage_norm


def _locked_state_stmt(user_id: str):
    """Build the ``SELECT ... FOR UPDATE`` statement for the user's state row."""
    return (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )


async def _get_locked_state(user_id: str, db: AsyncSession) -> MemoirStateModel:
    """Load the user's state row with ``FOR UPDATE``, creating it first if missing."""
    stmt = _locked_state_stmt(user_id)
    state = (await db.execute(stmt)).scalar_one_or_none()
    if not state:
        # get_or_create_state commits the new row; re-select it under the lock.
        await get_or_create_state(user_id, db)
        state = (await db.execute(stmt)).scalar_one()
    return state


def _get_locked_state_sync(user_id: str, db: Session) -> MemoirStateModel:
    """Synchronous counterpart of ``_get_locked_state``."""
    stmt = _locked_state_stmt(user_id)
    state = db.execute(stmt).scalar_one_or_none()
    if not state:
        # get_or_create_state_sync commits the new row; re-select it under the lock.
        get_or_create_state_sync(user_id, db)
        state = db.execute(stmt).scalar_one()
    return state


def _merge_slot_into_state(
    state: MemoirStateModel,
    *,
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    memoir_batch: bool,
) -> None:
    """Write one slot into ``state.slots`` and apply the current-stage policy.

    The stage is normalised with ``normalize_chat_stage`` (falling back to the
    current stage, or ``"childhood"`` when unset). The slot's snippet is
    replaced; its ``segment_ids`` become the union of the stored ids and
    ``segment_ids``, de-duplicated in first-seen order so the persisted list
    is deterministic. The caller is responsible for committing.
    """
    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )
    # Work on a copy so that assigning it back is seen as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm, {}) or {})

    existing = stage_slots.get(slot_name)
    # A stored slot may be missing, null, or malformed; treat those as "no ids".
    prior_ids = (existing.get("segment_ids") or []) if isinstance(existing, dict) else []
    merged_segment_ids = list(dict.fromkeys([*prior_ids, *(segment_ids or [])]))

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)


async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Set a slot's snippet, merge its segment ids, and commit.

    The state row is selected ``FOR UPDATE`` (and created first if missing).
    ``memoir_batch`` controls how ``current_stage`` is updated; see
    ``_apply_current_stage_policy``.

    Returns:
        The refreshed state after the commit.
    """
    state = await _get_locked_state(user_id, db)
    _merge_slot_into_state(
        state,
        user_id=user_id,
        stage=stage,
        slot_name=slot_name,
        snippet=snippet,
        segment_ids=segment_ids,
        memoir_batch=memoir_batch,
    )
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)


async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Add ``stage`` to ``covered_stages`` and advance past it if it is current.

    Currently has no callers; reserved for future stage-progression logic.

    When no state row exists, a default one is created and returned without
    marking anything as covered. When ``stage`` is the current stage, the
    current stage moves to the next entry of ``stage_order`` (staying on the
    last entry at the end); if ``stage`` is not in ``stage_order`` the current
    stage is reset to the default.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        return await get_or_create_state(user_id, db)

    # Copy before appending: mutating the loaded list in place and assigning
    # the same object back is not detected as a change by the ORM, so the
    # update would silently not be persisted.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
    state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            state.current_stage = default_state().current_stage
        else:
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)


async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return the result of ``empty_slots_for_current_stage()`` for the user's state.

    The state is created with defaults if it does not exist yet.
    """
    state = await get_or_create_state(user_id, db)
    return state.empty_slots_for_current_stage()


async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set ``current_stage`` to the normalised ``new_stage`` and commit.

    Normalisation falls back to the existing current stage (or ``"childhood"``
    when unset). The state row is selected ``FOR UPDATE`` and created first if
    missing.
    """
    state = await _get_locked_state(user_id, db)
    fallback_stage = state.current_stage or "childhood"
    state.current_stage = normalize_chat_stage(
        new_stage, fallback=fallback_stage, log_context={"user_id": user_id}
    )
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)


async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Persist the interview metadata inside the ``slots`` column and commit.

    The metadata entry is replaced wholesale; the other slots are carried
    over. The state row is selected ``FOR UPDATE`` and created first if
    missing.
    """
    state = await _get_locked_state(user_id, db)
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    state.slots = _inject_interview_state_meta(
        slots=slots,
        known_facts=known_facts,
        persona_threads=persona_threads,
        recent_questions=recent_questions,
    )
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)


def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state``."""
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if state:
        return coerce_memoir_state(state)
    state = _new_default_state_model(user_id)
    db.add(state)
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)


def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous counterpart of ``update_slot``.

    Note that ``memoir_batch`` defaults to ``True`` here, whereas the async
    ``update_slot`` defaults to ``False``.
    """
    state = _get_locked_state_sync(user_id, db)
    _merge_slot_into_state(
        state,
        user_id=user_id,
        stage=stage,
        slot_name=slot_name,
        snippet=snippet,
        segment_ids=segment_ids,
        memoir_batch=memoir_batch,
    )
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)