Files
life-echo/api/app/features/memoir/state_service.py
2026-03-19 14:36:40 +08:00

130 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录状态服务get_or_create_state、update_slot、mark_stage_complete 等。
供 memoir service、conversation ws 使用Celery 任务内使用同步版本(见 tasks/memoir_tasks
"""
import uuid
from typing import Dict, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.features.memoir.models import MemoirState as MemoirStateModel
def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Convert a persisted memoir-state row into a validated ``MemoirStateSchema``.

    Fallbacks applied before validation:
      * a falsy ``stage_order`` is replaced by ``default_state().stage_order``;
      * a falsy ``covered_stages`` is replaced by an empty list;
      * a ``slots`` value that is not a dict is replaced by
        ``default_state().slots``.

    ``current_stage`` is passed through unchanged.
    """
    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order

    raw_slots = model.slots
    if not isinstance(raw_slots, dict):
        raw_slots = default_state().slots

    payload = {
        "stage_order": stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": raw_slots,
    }
    return MemoirStateSchema.model_validate(payload)
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the memoir state for ``user_id``, inserting a default row if absent.

    Looks up the ``MemoirState`` row whose ``user_id`` matches. If one exists it
    is converted with ``_coerce_state`` and returned. Otherwise a new row is
    built from ``default_state()`` (with a fresh UUID4 string as ``id`` and
    every slot serialized via ``model_dump()``), added to the session,
    committed, refreshed, and then converted.

    Args:
        user_id: Identifier the row is looked up (and created) under.
        db: Async SQLAlchemy session; this function commits on the create path.

    Returns:
        The validated state schema for the existing or newly created row.
    """
    query = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    row = (await db.execute(query)).scalar_one_or_none()
    if row:
        return _coerce_state(row)

    template = default_state()

    # Slot objects are schema instances; store them as plain dicts.
    serialized_slots = {}
    for stage_name, stage_slots in template.slots.items():
        serialized_slots[stage_name] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }

    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    await db.commit()
    await db.refresh(row)
    return _coerce_state(row)
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Write ``snippet`` into a slot and merge its segment ids, then persist.

    If the user has no state row yet, one is created via
    ``get_or_create_state`` and re-fetched before the update is applied.

    The slot's ``snippet`` is overwritten. Its ``segment_ids`` become the
    de-duplicated union of the previously stored ids and ``segment_ids``,
    in first-seen order (stored ids first). ``current_stage`` is set to
    ``stage`` only when it is currently falsy.

    Args:
        user_id: Owner of the memoir state row.
        stage: Stage key under which the slot lives.
        slot_name: Slot key inside the stage.
        snippet: New snippet text for the slot.
        segment_ids: Segment ids to merge into the slot.
        db: Async SQLAlchemy session; this function commits.

    Returns:
        The validated state schema after the update.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Build fresh containers instead of mutating the loaded JSON value in
    # place: if the same (mutated) object is assigned back, the ORM can
    # compare it equal to the "old" value and skip the UPDATE.
    stored_slots = state.slots if isinstance(state.slots, dict) else {}
    slots: Dict[str, Dict] = dict(stored_slots)
    stage_slots = dict(slots.get(stage) or {})
    existing = stage_slots.get(slot_name) or {}

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *(segment_ids or [])])
    )

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage] = stage_slots
    state.slots = slots
    state.current_stage = state.current_stage or stage
    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Record ``stage`` as covered and advance ``current_stage`` if it was active.

    If the user has no state row yet, one is created via
    ``get_or_create_state`` and re-fetched, and the completion is then applied
    to it (previously the request was silently dropped in that case).

    ``stage`` is appended to ``covered_stages`` when not already present. When
    ``stage`` is the current stage, ``current_stage`` moves to the next entry
    of ``stage_order`` (staying on the last entry if already there); if
    ``stage`` is not in ``stage_order``, ``current_stage`` is reset to
    ``default_state().current_stage``.

    Args:
        user_id: Owner of the memoir state row.
        stage: Stage key to mark as complete.
        db: Async SQLAlchemy session; this function commits.

    Returns:
        The validated state schema after the update.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    covered = state.covered_stages or []
    if stage not in covered:
        # Assign a new list rather than appending in place: reassigning the
        # same mutated object can compare equal to the old value and be
        # skipped by the ORM's change detection.
        state.covered_stages = [*covered, stage]

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            # Unknown stage: fall back to the default starting stage.
            state.current_stage = default_state().current_stage
        else:
            # Clamp so completing the final stage keeps it current.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return the empty slots of the user's current stage.

    Loads (or creates) the user's state via ``get_or_create_state`` and
    delegates to ``MemoirStateSchema.empty_slots_for_current_stage()``.

    Args:
        user_id: Owner of the memoir state row.
        db: Async SQLAlchemy session passed through to ``get_or_create_state``.
    """
    current = await get_or_create_state(user_id, db)
    empty_slots = current.empty_slots_for_current_stage()
    return empty_slots
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to ``new_stage`` and persist it.

    If the user has no state row yet, one is created via
    ``get_or_create_state`` and re-fetched, and the switch is then applied to
    it (previously the requested switch was silently ignored in that case).

    ``new_stage`` is stored as given; it is not checked against
    ``stage_order`` here.

    Args:
        user_id: Owner of the memoir state row.
        new_stage: Stage key to make current.
        db: Async SQLAlchemy session; this function commits.

    Returns:
        The validated state schema after the update.
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    state.current_stage = new_stage
    await db.commit()
    await db.refresh(state)
    return _coerce_state(state)