117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
"""
|
|
回忆录状态服务
|
|
"""
|
|
import uuid
|
|
from typing import Dict, List
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from agents.state_schema import MemoirStateSchema, SlotData, default_state
|
|
from database.models import MemoirState as MemoirStateModel
|
|
|
|
|
|
def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Convert a ``MemoirState`` ORM row into a validated ``MemoirStateSchema``.

    Missing or unusable column values are replaced before validation:

    * a falsy ``stage_order`` falls back to ``default_state().stage_order``;
    * a falsy ``covered_stages`` becomes an empty list;
    * a ``slots`` value that is not a ``dict`` falls back to
      ``default_state().slots``.

    ``current_stage`` is passed through unchanged.
    """
    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order

    slots = model.slots
    if not isinstance(slots, dict):
        # Stored slots are unusable; start from the default layout instead.
        slots = default_state().slots

    payload = {
        "stage_order": stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": slots,
    }
    return MemoirStateSchema.model_validate(payload)
|
|
|
|
|
|
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the memoir state for ``user_id``, inserting a default row if none exists.

    When no row is found, a new ``MemoirState`` is built from ``default_state()``
    (each slot model is serialised with ``model_dump()`` for storage), added to
    the session, committed and refreshed.

    Args:
        user_id: Identifier of the user whose state is requested.
        db: Async SQLAlchemy session; it is committed when a row is inserted.

    Returns:
        The stored (or newly created) state converted by ``_coerce_state``.
    """
    query = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    existing = (await db.execute(query)).scalar_one_or_none()
    if existing:
        return _coerce_state(existing)

    template = default_state()
    # Slot models are stored as plain dicts: {stage: {slot_name: slot_dict}}.
    serialized_slots = {
        stage_name: {name: slot.model_dump() for name, slot in stage_slots.items()}
        for stage_name, stage_slots in template.slots.items()
    }
    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    await db.commit()
    await db.refresh(row)
    return _coerce_state(row)
|
|
|
|
|
|
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Store ``snippet`` under ``slots[stage][slot_name]`` and merge ``segment_ids``.

    If the user has no state row yet, one is created first via
    ``get_or_create_state``. The new snippet replaces any previous one. The
    segment ids are merged with the previously stored ids and de-duplicated,
    keeping first-seen order (previously stored ids come first). If
    ``current_stage`` is unset or empty, it is set to ``stage``.

    Args:
        user_id: Identifier of the user whose state is updated.
        stage: Key of the stage inside ``slots``.
        slot_name: Key of the slot inside that stage.
        snippet: Text stored in the slot, replacing the previous snippet.
        segment_ids: Ids to merge into the slot's stored ``segment_ids``.
        db: Async SQLAlchemy session; it is committed.

    Returns:
        The refreshed state converted by ``_coerce_state``.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if state is None:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Build fresh dicts instead of mutating ``state.slots`` in place. If the
    # column is a plain JSON type (the model is not visible here), SQLAlchemy
    # compares old vs. new value at flush time. An in-place-mutated object
    # compares equal to itself, so the UPDATE would be silently skipped.
    current_slots = state.slots if isinstance(state.slots, dict) else {}
    slots: Dict[str, Dict] = dict(current_slots)
    stage_slots = dict(slots.get(stage) or {})
    existing = stage_slots.get(slot_name) or {}

    # dict.fromkeys de-duplicates while preserving insertion order; a set
    # would store the ids in an arbitrary, hash-seed-dependent order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage] = stage_slots
    state.slots = slots
    state.current_stage = state.current_stage or stage

    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)
|
|
|
|
|
|
async def mark_stage_complete(user_id: str, stage: str, db: AsyncSession) -> MemoirStateSchema:
    """Record ``stage`` as covered and advance ``current_stage`` past it if needed.

    If the user has no state row yet, one is created first via
    ``get_or_create_state``, and the completion is still recorded.

    ``stage`` is appended to ``covered_stages`` unless already present. If
    ``stage`` is the current stage, ``current_stage`` moves to the next entry of
    ``stage_order``, falling back to ``default_state().stage_order`` when the
    stored order is empty. The last stage stays current. If ``stage`` is not in
    the order at all, ``current_stage`` is reset to
    ``default_state().current_stage``.

    Args:
        user_id: Identifier of the user whose state is updated.
        stage: Stage to mark as covered.
        db: Async SQLAlchemy session; it is committed.

    Returns:
        The refreshed state converted by ``_coerce_state``.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if state is None:
        # Create the default row, then fall through so the completion is not
        # silently dropped for a user's first call.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Copy before appending: mutating the stored list in place and assigning
    # the same object back may not be detected as a change for a plain JSON
    # column, so the UPDATE could be skipped.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
    state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            state.current_stage = default_state().current_stage
        else:
            # Clamp so the final stage remains current instead of overflowing.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)
|
|
|
|
|
|
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return ``empty_slots_for_current_stage()`` for the user's memoir state.

    The state is loaded, or created if missing, via ``get_or_create_state``,
    which may commit ``db``. The result is whatever
    ``MemoirStateSchema.empty_slots_for_current_stage`` reports (presumably the
    names of unfilled slots in the current stage; confirm against the schema).
    """
    memoir_state = await get_or_create_state(user_id, db)
    empty = memoir_state.empty_slots_for_current_stage()
    return empty
|
|
|
|
|
|
async def switch_stage(user_id: str, new_stage: str, db: AsyncSession) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to ``new_stage`` and commit.

    If the user has no state row yet, one is created first via
    ``get_or_create_state``, and the switch is then applied. ``new_stage`` is
    not validated against ``stage_order``.

    Args:
        user_id: Identifier of the user whose state is updated.
        new_stage: Stage to make current.
        db: Async SQLAlchemy session; it is committed.

    Returns:
        The refreshed state converted by ``_coerce_state``.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if state is None:
        # Create the default row, then fall through so the requested switch is
        # not silently dropped for a user's first call.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    state.current_stage = new_stage
    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)
|