"""Memoir state service: get_or_create_state, update_slot, mark_stage_complete, etc.

Used by the memoir service and the conversation websocket; Celery tasks use a
synchronous counterpart instead (see tasks/memoir_tasks).

Every mutating function here commits the passed-in session and returns the
refreshed row converted to a ``MemoirStateSchema``.

NOTE(review): the container-valued columns (``slots``, ``covered_stages``) are
always reassigned as *new* objects rather than mutated in place. SQLAlchemy
decides whether a column changed by comparing the assigned value with the
previously loaded one; an object that was mutated in place and then reassigned
compares equal to itself, so the UPDATE could be skipped unless the model
column is wrapped in a ``sqlalchemy.ext.mutable`` type (model not visible
here -- confirm).
"""
import uuid
from typing import Dict, List, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.features.memoir.models import MemoirState as MemoirStateModel


def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Convert an ORM row into a validated ``MemoirStateSchema``.

    An empty/missing ``stage_order`` or a non-dict ``slots`` value falls back to
    the corresponding value from ``default_state()``; a missing
    ``covered_stages`` becomes ``[]``.
    """
    return MemoirStateSchema.model_validate(
        {
            "stage_order": model.stage_order or default_state().stage_order,
            "current_stage": model.current_stage,
            "covered_stages": model.covered_stages or [],
            "slots": model.slots if isinstance(model.slots, dict) else default_state().slots,
        }
    )


def _dump_slots(slots: Dict[str, Dict]) -> Dict[str, Dict[str, Dict]]:
    """Serialize a ``{stage: {slot_name: model}}`` mapping to plain dicts for storage.

    Each leaf value must expose pydantic's ``model_dump()``.
    """
    return {
        stage: {name: slot.model_dump() for name, slot in stage_slots.items()}
        for stage, stage_slots in slots.items()
    }


async def _load_model(user_id: str, db: AsyncSession) -> Optional[MemoirStateModel]:
    """Return the user's ``MemoirState`` row, or ``None`` if it does not exist."""
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def _get_or_create_model(user_id: str, db: AsyncSession) -> MemoirStateModel:
    """Return the user's ``MemoirState`` row, inserting and committing a default one if absent.

    NOTE(review): two concurrent first calls for the same user can both miss the
    SELECT and both INSERT; what happens then depends on constraints on
    ``user_id`` that are not visible here.
    """
    state = await _load_model(user_id, db)
    if state is not None:
        return state

    default = default_state()
    state = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=default.stage_order,
        current_stage=default.current_stage,
        covered_stages=default.covered_stages,
        slots=_dump_slots(default.slots),
    )
    db.add(state)
    await db.commit()
    await db.refresh(state)
    return state


async def _commit_and_coerce(state: MemoirStateModel, db: AsyncSession) -> MemoirStateSchema:
    """Commit pending changes, reload *state* from the database and return it as a schema."""
    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)


async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the user's memoir state, creating and committing the default state on first use."""
    return _coerce_state(await _get_or_create_model(user_id, db))


async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Write *snippet* into ``slots[stage][slot_name]`` and merge in *segment_ids*.

    The new snippet replaces any previous one. The stored segment ids become the
    union of the existing and the new ids, de-duplicated in first-seen order
    (existing ids first). If the user has no current stage yet, *stage* becomes
    the current stage. The state row is created first if missing.
    Commits the session.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    state = await _get_or_create_model(user_id, db)

    # Copy every level we modify so the value reassigned below is a new object
    # (see the module NOTE on change detection).
    slots: Dict[str, Dict] = dict(state.slots) if isinstance(state.slots, dict) else {}
    stage_slots: Dict[str, Dict] = dict(slots.get(stage) or {})
    existing = stage_slots.get(slot_name) or {}

    # dict.fromkeys de-duplicates while keeping insertion order; a set would
    # yield an arbitrary, per-process order for string ids.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *(segment_ids or [])])
    )

    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump()
    slots[stage] = stage_slots
    state.slots = slots
    state.current_stage = state.current_stage or stage
    return await _commit_and_coerce(state, db)


async def mark_stage_complete(user_id: str, stage: str, db: AsyncSession) -> MemoirStateSchema:
    """Record *stage* as covered and, if it is the current stage, advance past it.

    - *stage* is appended to ``covered_stages`` unless already present.
    - If *stage* is the current stage, the current stage moves to the next entry
      of ``stage_order`` (the last stage stays current). If *stage* is not in
      ``stage_order``, the current stage resets to
      ``default_state().current_stage``.
    - The state row is created first if missing, so the completion is applied
      even on first use. Commits the session.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    state = await _get_or_create_model(user_id, db)

    # New list rather than append-in-place (see the module NOTE).
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
    state.covered_stages = covered

    if state.current_stage == stage:
        stage_order = state.stage_order or default_state().stage_order
        try:
            idx = stage_order.index(stage)
        except ValueError:
            state.current_stage = default_state().current_stage
        else:
            # Clamp so completing the final stage keeps it as the current stage.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    return await _commit_and_coerce(state, db)


async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return ``empty_slots_for_current_stage()`` of the user's state (created if missing)."""
    state = await get_or_create_state(user_id, db)
    return state.empty_slots_for_current_stage()


async def switch_stage(user_id: str, new_stage: str, db: AsyncSession) -> MemoirStateSchema:
    """Set the user's current stage to *new_stage*, creating the state row first if missing.

    *new_stage* is not validated against ``stage_order``. Commits the session.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    state = await _get_or_create_model(user_id, db)
    state.current_stage = new_stage
    return await _commit_and_coerce(state, db)