Files
life-echo/api/app/features/memoir/state_service.py

260 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Memoir state service: get_or_create_state, update_slot, mark_stage_complete, etc.

Used by the memoir service and the conversation WebSocket (ws) handlers; Celery
tasks use the synchronous (``*_sync``) variants.
"""
import uuid
from typing import Dict, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.agents.stage_constants import (
chat_bucket,
normalize_chat_stage,
)
from app.core.config import settings
from app.features.memoir.models import MemoirState as MemoirStateModel
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
"""浅拷贝 slots避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
if not raw or not isinstance(raw, dict):
return {}
return {k: dict(v or {}) for k, v in raw.items()}
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Build a validated ``MemoirStateSchema`` from a persisted ``MemoirState`` row.

    Missing or empty columns are replaced before validation:

    - falsy ``stage_order`` / ``current_stage`` -> values from ``default_state()``
    - falsy ``covered_stages`` -> ``[]``
    - non-dict ``slots`` -> ``default_state().slots``

    Args:
        model: The ORM row to convert.

    Returns:
        The validated schema instance.
    """
    defaults = default_state()
    slots = model.slots if isinstance(model.slots, dict) else defaults.slots
    return MemoirStateSchema.model_validate(
        {
            "stage_order": model.stage_order or defaults.stage_order,
            # The writers in this module substitute a fallback for an empty
            # current_stage; do the same here instead of passing None through
            # to the schema.
            "current_stage": model.current_stage or defaults.current_stage,
            "covered_stages": model.covered_stages or [],
            "slots": slots,
        }
    )
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the memoir state for ``user_id``, inserting a default row when absent.

    If no row exists, one is built from ``default_state()`` (slot models are
    dumped to plain dicts for storage), then added, committed and refreshed
    before being converted.

    NOTE(review): two concurrent first calls can both miss the row and both
    insert; whether that fails depends on a unique constraint on ``user_id``
    that is not visible here.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    template = default_state()
    stored_slots = {}
    for stage_key, stage_slots in template.slots.items():
        stored_slots[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }
    created = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=stored_slots,
    )
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return coerce_memoir_state(created)
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Update ``state.current_stage`` after a slot write to ``stage_norm``.

    - Non-batch writes (``memoir_batch=False``) always move current_stage to
      ``stage_norm``.
    - Batch writes leave it alone unless
      ``settings.memoir_extraction_updates_current_stage`` is enabled, and even
      then only move it when ``stage_norm`` falls in the same ``chat_bucket``
      as the current stage (an empty current stage counts as "childhood").
    """
    if not memoir_batch:
        state.current_stage = stage_norm
        return
    if settings.memoir_extraction_updates_current_stage:
        current_bucket = chat_bucket(state.current_stage or "childhood")
        if chat_bucket(stage_norm) == current_bucket:
            state.current_stage = stage_norm
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Store ``snippet`` in ``slots[stage][slot_name]`` and merge its segment ids.

    The user's row is selected ``FOR UPDATE`` (created first if missing).
    ``stage`` is normalized with ``normalize_chat_stage``, falling back to the
    row's current stage (or "childhood"). Segment ids are de-duplicated while
    keeping first-seen order: previously stored ids first, then the new ones.

    Args:
        user_id: Owner of the memoir state.
        stage: Raw stage name; normalized before use.
        slot_name: Slot key within the stage.
        snippet: New snippet text; replaces any previous snippet.
        segment_ids: Segment ids to merge into the slot (``None`` is treated as
            empty).
        db: Async session; this function commits it.
        memoir_batch: When False, current_stage is set to the normalized stage.
            When True, it only changes if
            ``settings.memoir_extraction_updates_current_stage`` is enabled and
            the stage is in the same chat bucket as the current one.

    Returns:
        The committed state as a ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # get_or_create_state commits; re-select FOR UPDATE so the new row is
        # locked in the following transaction.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    stage_norm = normalize_chat_stage(
        stage,
        fallback=state.current_stage or "childhood",
        log_context={"user_id": user_id},
    )
    # Work on copies and reassign: in-place edits of the dicts held by the JSON
    # column would not be detected as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    existing = stage_slots.get(slot_name)
    if not isinstance(existing, dict):
        # A stored null (or malformed) slot is treated as empty.
        existing = {}
    # dict.fromkeys de-duplicates while preserving insertion order; a set would
    # produce a hash-dependent order that changes between processes.
    merged_segment_ids = list(
        dict.fromkeys(
            [*(existing.get("segment_ids") or []), *(segment_ids or [])]
        )
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Add ``stage`` to covered_stages and advance current_stage past it.

    If ``stage`` is the current stage, current_stage moves to the next entry of
    stage_order (staying on the last entry when ``stage`` is already last). If
    it is the current stage but absent from stage_order, current_stage is reset
    to ``default_state().current_stage``. If the user has no state yet, a
    default row is created first and the stage is then recorded on it.

    There are currently no callers; this is reserved for future stage
    progression logic.

    Args:
        user_id: Owner of the memoir state.
        stage: Stage to mark as covered (used as-is, not normalized).
        db: Async session; this function commits it.

    Returns:
        The committed state as a ``MemoirStateSchema``.
    """
    # Lock the row like update_slot / switch_stage do: this is a
    # read-modify-write of covered_stages.
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # Create the default row (get_or_create_state commits), then lock it so
        # the stage is still recorded for a brand-new user.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Copy before appending. Appending to the list already held by the JSON
    # column makes the old and new values the same object, so the ORM sees no
    # change and never writes covered_stages.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
        state.covered_stages = covered

    if state.current_stage == stage:
        stage_order = state.stage_order or default_state().stage_order
        try:
            idx = stage_order.index(stage)
        except ValueError:
            state.current_stage = default_state().current_stage
        else:
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return ``empty_slots_for_current_stage()`` for the user's memoir state.

    The state is loaded via ``get_or_create_state``, so a default row may be
    created (and committed) as a side effect. Presumably the result lists the
    unfilled slot names of the current stage; see ``MemoirStateSchema``.
    """
    memoir_state = await get_or_create_state(user_id, db)
    empty = memoir_state.empty_slots_for_current_stage()
    return empty
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's current_stage to the normalized ``new_stage`` and commit.

    The row is selected ``FOR UPDATE`` (created first if missing).
    ``new_stage`` is normalized with ``normalize_chat_stage``, falling back to
    the existing current stage (or "childhood").
    """
    locked_query = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_query)).scalar_one_or_none()
    if not row:
        # get_or_create_state commits, so select again under FOR UPDATE.
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_query)).scalar_one()

    previous_stage = row.current_stage or "childhood"
    row.current_stage = normalize_chat_stage(
        new_stage,
        fallback=previous_stage,
        log_context={"user_id": user_id},
    )
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state`` (for Celery tasks).

    Loads the user's row; if none exists, inserts one built from
    ``default_state()`` (slot models dumped to plain dicts), then commits and
    refreshes it.
    """
    row = db.execute(
        select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    ).scalar_one_or_none()
    if not row:
        defaults = default_state()
        row = MemoirStateModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            stage_order=defaults.stage_order,
            current_stage=defaults.current_stage,
            covered_stages=defaults.covered_stages,
            slots={
                stage_key: {
                    slot_key: slot.model_dump()
                    for slot_key, slot in stage_slots.items()
                }
                for stage_key, stage_slots in defaults.slots.items()
            },
        )
        db.add(row)
        db.commit()
        db.refresh(row)
    return coerce_memoir_state(row)
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous counterpart of ``update_slot`` (for Celery tasks).

    Stores ``snippet`` in ``slots[stage][slot_name]`` on the row selected
    ``FOR UPDATE`` (created first if missing). ``stage`` is normalized with
    ``normalize_chat_stage``. Segment ids are de-duplicated while keeping
    first-seen order: previously stored ids first, then the new ones.

    Note that ``memoir_batch`` defaults to True here (unlike the async version).
    With True, current_stage only changes when
    ``settings.memoir_extraction_updates_current_stage`` is enabled and the
    stage is in the same chat bucket as the current one.

    Returns:
        The committed state as a ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # get_or_create_state_sync commits; re-select FOR UPDATE so the new row
        # is locked in the following transaction.
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()

    stage_norm = normalize_chat_stage(
        stage,
        fallback=state.current_stage or "childhood",
        log_context={"user_id": user_id},
    )
    # Work on copies and reassign: in-place edits of the dicts held by the JSON
    # column would not be detected as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    existing = stage_slots.get(slot_name)
    if not isinstance(existing, dict):
        # A stored null (or malformed) slot is treated as empty.
        existing = {}
    # dict.fromkeys de-duplicates while preserving insertion order; a set would
    # produce a hash-dependent order that changes between processes.
    merged_segment_ids = list(
        dict.fromkeys(
            [*(existing.get("segment_ids") or []), *(segment_ids or [])]
        )
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)