Files
life-echo/api/app/features/memoir/state_service.py

309 lines
9.5 KiB
Python
Raw Normal View History

2026-01-21 22:31:03 +01:00
"""
Memoir state service: get_or_create_state, update_slot, mark_stage_complete.
Used by the memoir service and the conversation WebSocket; Celery tasks use the synchronous variants.
2026-01-21 22:31:03 +01:00
"""
2026-03-19 14:36:14 +08:00
2026-01-21 22:31:03 +01:00
import uuid
from typing import Dict, List
2026-01-21 22:31:03 +01:00
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
2026-01-21 22:31:03 +01:00
from app.agents.stage_constants import (
chat_bucket,
normalize_chat_stage,
)
from app.agents.state_schema import (
InterviewControlState,
KnownFact,
MemoirStateSchema,
PersonaThread,
SlotData,
default_state,
)
from app.core.config import settings
from app.features.memoir import _interview_meta_store as interview_meta
from app.features.memoir.models import MemoirState as MemoirStateModel
2026-01-21 22:31:03 +01:00
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
"""浅拷贝 slots避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
if not raw or not isinstance(raw, dict):
return {}
return {k: dict(v or {}) for k, v in raw.items()}
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Project an ORM row onto a validated ``MemoirStateSchema``.

    Interview control metadata is read and removed from ``slots`` through the
    ``interview_meta`` adapter, so this function never touches its storage
    layout directly. Missing ``stage_order`` / ``slots`` values fall back to
    those of ``default_state()``; a missing ``covered_stages`` becomes ``[]``.
    """
    stored_slots = model.slots
    if not isinstance(stored_slots, dict):
        stored_slots = None

    meta = interview_meta.read(stored_slots)

    slot_payload = interview_meta.strip(stored_slots)
    if not slot_payload:
        slot_payload = dict(default_state().slots)

    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order

    payload = {
        "stage_order": stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": slot_payload,
        "known_facts": meta.known_facts,
        "persona_threads": meta.persona_threads,
        "recent_questions": meta.recent_questions,
    }
    return MemoirStateSchema.model_validate(payload)
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the user's memoir state, inserting a default row when none exists.

    Args:
        user_id: Owner of the memoir state row.
        db: Async session used for the lookup and, if needed, the insert.

    Returns:
        The existing or newly committed row, projected through
        ``coerce_memoir_state``.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    defaults = default_state()

    # Slot objects are dumped to plain dicts so the value is JSON-serialisable.
    serialized_slots = {}
    for stage_key, stage_slots in defaults.slots.items():
        serialized_slots[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }

    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=defaults.stage_order,
        current_stage=defaults.current_stage,
        covered_stages=defaults.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Decide whether ``state.current_stage`` should become ``stage_norm``.

    * Not a memoir batch: ``current_stage`` is always overwritten.
    * Memoir batch with ``settings.memoir_extraction_updates_current_stage``
      disabled: ``current_stage`` is left alone.
    * Memoir batch with the setting enabled: ``current_stage`` is overwritten
      only when ``chat_bucket`` maps the current and the new stage to the
      same bucket.

    A missing ``current_stage`` is treated as ``"childhood"`` for the bucket
    comparison.
    """
    effective_current = state.current_stage or "childhood"
    if not memoir_batch:
        state.current_stage = stage_norm
    elif settings.memoir_extraction_updates_current_stage:
        bucket_now = chat_bucket(effective_current)
        bucket_incoming = chat_bucket(stage_norm)
        if bucket_incoming == bucket_now:
            state.current_stage = stage_norm
2026-01-21 22:31:03 +01:00
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Write one slot of the user's memoir state and persist it.

    The row is selected with ``FOR UPDATE``; if it does not exist yet it is
    created first and then re-selected. The slot's snippet is replaced, while
    its segment ids are merged with the ones already stored.

    Args:
        user_id: Owner of the memoir state row.
        stage: Requested stage; normalised with ``normalize_chat_stage``,
            falling back to the row's current stage (or ``"childhood"``).
        slot_name: Slot key inside the normalised stage.
        snippet: New snippet text for the slot.
        segment_ids: Segment ids to associate with the slot; ``None`` is
            treated as empty.
        db: Async session used for the read and the commit.
        memoir_batch: Forwarded to ``_apply_current_stage_policy``.

    Returns:
        The refreshed row projected through ``coerce_memoir_state``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # No row yet: create the default one, then select it again with FOR UPDATE.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so the assignment below hands the ORM a new object.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})

    # A slot stored as null (or any non-dict) is treated as having no ids.
    existing = stage_slots.get(slot_name)
    previous_ids = (existing.get("segment_ids") if isinstance(existing, dict) else None) or []

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # persist the ids in an arbitrary, run-dependent order.
    merged_segment_ids = list(dict.fromkeys([*previous_ids, *(segment_ids or [])]))

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots

    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)

    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
2026-01-21 22:31:03 +01:00
2026-03-19 14:36:14 +08:00
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Add ``stage`` to ``covered_stages`` and advance ``current_stage`` if it matches.

    Per the original note, this function currently has no callers and is
    reserved for future stage-progression logic.

    The row is selected with ``FOR UPDATE`` like the other read-modify-write
    functions in this module; a missing row is created first and then
    updated, so the stage is recorded for new users as well.

    Args:
        user_id: Owner of the memoir state row.
        stage: Stage to record as covered.
        db: Async session used for the read and the commit.

    Returns:
        The refreshed row projected through ``coerce_memoir_state``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Copy before appending: mutating the loaded list in place and assigning
    # the same object back may not be seen as a change by the ORM.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
    state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            # Stage is not part of the order: fall back to the default stage.
            state.current_stage = default_state().current_stage
        else:
            # Move to the next stage, staying on the last one at the end.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
2026-01-21 22:31:03 +01:00
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return ``empty_slots_for_current_stage()`` for the user's memoir state.

    The state is fetched through ``get_or_create_state``, so a missing row is
    created as a side effect.
    """
    return (await get_or_create_state(user_id, db)).empty_slots_for_current_stage()
2026-03-19 14:36:14 +08:00
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to the normalised ``new_stage``.

    The row is selected with ``FOR UPDATE`` and created first when missing.
    ``normalize_chat_stage`` receives the row's current stage (or
    ``"childhood"``) as its fallback.

    Returns:
        The refreshed row projected through ``coerce_memoir_state``.
    """
    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    fallback_stage = row.current_stage or "childhood"
    row.current_stage = normalize_chat_stage(
        new_stage,
        fallback=fallback_stage,
        log_context={"user_id": user_id},
    )

    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Persist interview control metadata alongside the user's slots.

    The row is selected with ``FOR UPDATE`` and created first when missing.
    The metadata is wrapped in an ``InterviewControlState`` and merged into a
    copy of ``slots`` by ``interview_meta.write``; the result replaces the
    row's ``slots`` value.

    Returns:
        The refreshed row projected through ``coerce_memoir_state``.
    """
    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    # Copy first so the column is assigned a new object rather than edited in place.
    slots_copy = _slots_snapshot_for_merge(
        row.slots if isinstance(row.slots, dict) else None
    )
    control_state = InterviewControlState(
        known_facts=known_facts,
        persona_threads=persona_threads,
        recent_questions=recent_questions,
    )
    row.slots = interview_meta.write(slots_copy, control=control_state)

    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state``.

    Returns the user's memoir state, inserting and committing a row built
    from ``default_state()`` when none exists.

    Args:
        user_id: Owner of the memoir state row.
        db: Synchronous session used for the lookup and, if needed, the insert.

    Returns:
        The existing or newly committed row, projected through
        ``coerce_memoir_state``.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    defaults = default_state()

    # Slot objects are dumped to plain dicts so the value is JSON-serialisable.
    serialized_slots = {}
    for stage_key, stage_slots in defaults.slots.items():
        serialized_slots[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }

    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=defaults.stage_order,
        current_stage=defaults.current_stage,
        covered_stages=defaults.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return coerce_memoir_state(row)
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous variant of ``update_slot``; ``memoir_batch`` defaults to True.

    The row is selected with ``FOR UPDATE``; if it does not exist yet it is
    created first and then re-selected. The slot's snippet is replaced, while
    its segment ids are merged with the ones already stored.

    Args:
        user_id: Owner of the memoir state row.
        stage: Requested stage; normalised with ``normalize_chat_stage``,
            falling back to the row's current stage (or ``"childhood"``).
        slot_name: Slot key inside the normalised stage.
        snippet: New snippet text for the slot.
        segment_ids: Segment ids to associate with the slot; ``None`` is
            treated as empty.
        db: Synchronous session used for the read and the commit.
        memoir_batch: Forwarded to ``_apply_current_stage_policy``.

    Returns:
        The refreshed row projected through ``coerce_memoir_state``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # No row yet: create the default one, then select it again with FOR UPDATE.
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so the assignment below hands the ORM a new object.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})

    # A slot stored as null (or any non-dict) is treated as having no ids.
    existing = stage_slots.get(slot_name)
    previous_ids = (existing.get("segment_ids") if isinstance(existing, dict) else None) or []

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # persist the ids in an arbitrary, run-dependent order.
    merged_segment_ids = list(dict.fromkeys([*previous_ids, *(segment_ids or [])]))

    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots

    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)

    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)