Files
life-echo/api/app/features/memoir/state_service.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

319 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Memoir state service: get_or_create_state, update_slot, mark_stage_complete, etc.
Used by the memoir service and the conversation WebSocket; Celery tasks use the synchronous (*_sync) variants.
"""
import uuid
from typing import Dict, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.agents.stage_constants import (
allowed_slot_names_for_stage,
chat_bucket,
normalize_chat_stage,
)
from app.agents.state_schema import (
KnownFact,
MemoirStateSchema,
PersonaThread,
SlotData,
default_state,
narrative_coverage_state,
)
from app.core.config import settings
from app.core.db import transactional, transactional_sync
from app.features.memoir.models import MemoirState as MemoirStateModel
from app.features.memoir.constants import memoir
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
"""浅拷贝 slots避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
if not raw or not isinstance(raw, dict):
return {}
return {k: dict(v or {}) for k, v in raw.items()}
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Project an ORM ``MemoirState`` row onto a ``MemoirStateSchema``.

    Interview control metadata (known facts, persona threads, recent
    questions) is read from its own JSON columns. Every column is read
    defensively:

    * A non-list JSON value counts as an empty list.
    * A non-dict or empty ``slots`` falls back to ``default_state().slots``.
    * An empty ``stage_order`` falls back to ``default_state().stage_order``.
    * Malformed list entries are dropped before validation.

    Args:
        model: The ORM row to read; it is not modified.

    Returns:
        A validated ``MemoirStateSchema``.
    """

    def _list_column(value: object) -> list:
        # JSON columns may hold anything (NULL, a dict, a scalar); only lists count.
        return value if isinstance(value, list) else []

    raw_slots = model.slots if isinstance(model.slots, dict) else None
    clean_slots = raw_slots or dict(default_state().slots)

    recent_questions = []
    for item in _list_column(model.recent_questions_json):
        # Skip JSON nulls explicitly: str(None) would otherwise turn them into
        # the literal question "None".
        if item is None:
            continue
        text = str(item).strip()
        if text:
            recent_questions.append(text)

    return MemoirStateSchema.model_validate(
        {
            "stage_order": model.stage_order or default_state().stage_order,
            "current_stage": model.current_stage,
            "covered_stages": model.covered_stages or [],
            "slots": clean_slots,
            "known_facts": [
                x
                for x in _list_column(model.known_facts_json)
                if isinstance(x, dict)
            ],
            "persona_threads": [
                x
                for x in _list_column(model.persona_threads_json)
                if isinstance(x, dict)
            ],
            "recent_questions": recent_questions,
        }
    )
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the memoir state of ``user_id``, inserting a default row if absent.

    The row is looked up by ``user_id``. When none exists, a new row is built
    from ``default_state()``, with slot models dumped to plain dicts for the
    JSON column. It is added inside ``transactional(db)`` and refreshed
    afterwards.

    Args:
        user_id: Owner of the memoir state.
        db: Async session used for the lookup and the insert.

    Returns:
        The row projected through ``coerce_memoir_state``.
    """
    lookup = await db.execute(
        select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    )
    existing = lookup.scalar_one_or_none()
    if existing:
        return coerce_memoir_state(existing)

    defaults = default_state()
    slot_payload = {}
    for stage_key, stage_slots in defaults.slots.items():
        slot_payload[stage_key] = {
            slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
        }
    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=defaults.stage_order,
        current_stage=defaults.current_stage,
        covered_stages=defaults.covered_stages,
        slots=slot_payload,
    )
    async with transactional(db):
        db.add(row)
    await db.refresh(row)
    return coerce_memoir_state(row)
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Decide whether a slot write also moves ``state.current_stage``.

    Truth table:

    * ``memoir_batch`` is false: always switch to ``stage_norm``.
    * ``memoir_batch`` is true and ``memoir.extraction_updates_current_stage``
      is off: leave ``current_stage`` untouched.
    * ``memoir_batch`` is true and the flag is on: switch only when
      ``chat_bucket(stage_norm)`` equals the bucket of the current stage. A
      missing current stage counts as ``"childhood"``.

    Mutates ``state`` in place; returns nothing.
    """
    if memoir_batch:
        if not memoir.extraction_updates_current_stage:
            return
        # Batch writes may only move the current stage within its own bucket.
        current_bucket = chat_bucket(state.current_stage or "childhood")
        if chat_bucket(stage_norm) != current_bucket:
            return
    state.current_stage = stage_norm
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Write ``snippet`` into one slot of a stage and merge its segment ids.

    Steps:

    1. Select the state row ``FOR UPDATE``, creating it first via
       ``get_or_create_state`` if it is missing.
    2. Normalise ``stage`` with ``normalize_chat_stage``, using the stored
       current stage (or ``"childhood"``) as fallback.
    3. If ``slot_name`` is not allowed for that stage, return the state
       unchanged.
    4. Otherwise replace the slot's snippet, merge its segment ids with
       ``segment_ids`` (order-preserving, de-duplicated), and update
       ``current_stage`` via ``_apply_current_stage_policy``.

    Args:
        user_id: Owner of the memoir state.
        stage: Raw stage name from the caller; normalised before use.
        slot_name: Slot to write.
        snippet: New snippet text; replaces the previous one.
        segment_ids: Segment ids backing the snippet; appended after the
            existing ids, skipping duplicates.
        db: Async session.
        memoir_batch: Selects the batch branch of the current-stage policy.

    Returns:
        The (possibly updated) state as ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()
    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )
    if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db):
        return coerce_memoir_state(state)
    async with transactional(db):
        # Work on a copy so the reassignment below is seen as a change by the ORM.
        slots = _slots_snapshot_for_merge(
            state.slots if isinstance(state.slots, dict) else None
        )
        stage_slots = dict(slots.get(stage_norm) or {})
        existing = stage_slots.get(slot_name)
        if not isinstance(existing, dict):
            # A stored slot may be null or malformed; treat it as empty.
            existing = {}
        # Order-preserving de-duplication: existing ids first, then new ones.
        # A set would reorder ids arbitrarily, and differently per process
        # because of string hash randomisation.
        merged_segment_ids = list(
            dict.fromkeys(
                [*(existing.get("segment_ids") or []), *(segment_ids or [])]
            )
        )
        stage_slots[slot_name] = SlotData(
            snippet=snippet, segment_ids=merged_segment_ids
        ).model_dump()
        slots[stage_norm] = stage_slots
        state.slots = slots
        _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    await db.refresh(state)
    return coerce_memoir_state(state)
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Record ``stage`` as covered and, if it is the current stage, advance.

    ``stage`` is appended to ``covered_stages`` once. When it is also the
    current stage, ``current_stage`` moves as follows:

    * It becomes the next entry of ``stage_order`` (falling back to
      ``default_state().stage_order`` when empty).
    * It stays on the last entry once the end is reached.
    * If ``stage`` is missing from the order, it resets to the default
      current stage.

    The row is locked ``FOR UPDATE`` and created first if missing, so the
    stage is also recorded for new users.

    Currently has no callers; reserved for future stage-advancement logic.

    Returns:
        The updated state as ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # Create the default row, then fall through so the stage is still marked.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()
    async with transactional(db):
        # Copy the list (as _slots_snapshot_for_merge does for slots).
        # Appending in place to the loaded JSON value and reassigning the same
        # object would keep the ORM from detecting the change.
        covered = list(state.covered_stages or [])
        if stage not in covered:
            covered.append(stage)
        state.covered_stages = covered
        stage_order = state.stage_order or default_state().stage_order
        if state.current_stage == stage:
            try:
                idx = stage_order.index(stage)
            except ValueError:
                state.current_stage = default_state().current_stage
            else:
                state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]
    await db.refresh(state)
    return coerce_memoir_state(state)
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return the empty slot names for the user's current stage.

    Loads (or creates) the state via ``get_or_create_state`` and delegates to
    ``narrative_coverage_state(...).empty_slots_for_current_stage()``.
    """
    memoir_state = await get_or_create_state(user_id, db)
    coverage = narrative_coverage_state(memoir_state)
    return coverage.empty_slots_for_current_stage()
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's current stage to the normalised ``new_stage``.

    The row is locked ``FOR UPDATE`` and created first if missing.
    ``new_stage`` is passed through ``normalize_chat_stage`` with the stored
    current stage (or ``"childhood"``) as the fallback.

    Returns:
        The updated state as ``MemoirStateSchema``.
    """
    locked_select = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_select)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_select)).scalar_one()
    async with transactional(db):
        fallback_stage = row.current_stage or "childhood"
        row.current_stage = normalize_chat_stage(
            new_stage,
            fallback=fallback_stage,
            log_context={"user_id": user_id},
        )
    await db.refresh(row)
    return coerce_memoir_state(row)
async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Overwrite the interview control metadata columns of ``user_id``.

    ``known_facts`` and ``persona_threads`` are stored as lists of their
    ``model_dump()`` dicts; ``recent_questions`` is stored as a plain list
    copy. Existing values are replaced, not merged. The row is locked
    ``FOR UPDATE`` and created first if missing.

    Returns:
        The updated state as ``MemoirStateSchema``.
    """
    locked_select = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_select)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_select)).scalar_one()
    async with transactional(db):
        row.known_facts_json = [fact.model_dump() for fact in known_facts]
        row.persona_threads_json = [
            thread.model_dump() for thread in persona_threads
        ]
        row.recent_questions_json = list(recent_questions)
    await db.refresh(row)
    return coerce_memoir_state(row)
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state`` for Celery tasks.

    Returns the state of ``user_id``. When no row exists, a default row built
    from ``default_state()`` is added inside ``transactional_sync(db)`` and
    then refreshed.

    Args:
        user_id: Owner of the memoir state.
        db: Synchronous session used for the lookup and the insert.

    Returns:
        The row projected through ``coerce_memoir_state``.
    """
    found = db.execute(
        select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    ).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    template = default_state()
    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        # Slot models are dumped to plain dicts for the JSON column.
        slots={
            stage_key: {
                slot_key: slot.model_dump()
                for slot_key, slot in stage_slots.items()
            }
            for stage_key, stage_slots in template.slots.items()
        },
    )
    with transactional_sync(db):
        db.add(row)
    db.refresh(row)
    return coerce_memoir_state(row)
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous counterpart of ``update_slot`` for Celery tasks.

    Steps:

    1. Select the state row ``FOR UPDATE``, creating it first via
       ``get_or_create_state_sync`` if missing.
    2. Normalise ``stage`` with ``normalize_chat_stage``, using the stored
       current stage (or ``"childhood"``) as fallback.
    3. If ``slot_name`` is not allowed for that stage, return the state
       unchanged.
    4. Otherwise replace the slot's snippet, merge its segment ids
       (order-preserving, de-duplicated), and update ``current_stage`` via
       ``_apply_current_stage_policy``.

    Note that ``memoir_batch`` defaults to ``True`` here, unlike the async
    version.

    Args:
        user_id: Owner of the memoir state.
        stage: Raw stage name; normalised before use.
        slot_name: Slot to write.
        snippet: New snippet text; replaces the previous one.
        segment_ids: Segment ids backing the snippet; appended after the
            existing ids, skipping duplicates.
        db: Synchronous session.
        memoir_batch: Selects the batch branch of the current-stage policy.

    Returns:
        The (possibly updated) state as ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()
    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )
    if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db):
        return coerce_memoir_state(state)
    with transactional_sync(db):
        # Work on a copy so the reassignment below is seen as a change by the ORM.
        slots = _slots_snapshot_for_merge(
            state.slots if isinstance(state.slots, dict) else None
        )
        stage_slots = dict(slots.get(stage_norm) or {})
        existing = stage_slots.get(slot_name)
        if not isinstance(existing, dict):
            # A stored slot may be null or malformed; treat it as empty.
            existing = {}
        # Order-preserving de-duplication: existing ids first, then new ones.
        # A set would reorder ids arbitrarily, and differently per process
        # because of string hash randomisation.
        merged_segment_ids = list(
            dict.fromkeys(
                [*(existing.get("segment_ids") or []), *(segment_ids or [])]
            )
        )
        stage_slots[slot_name] = SlotData(
            snippet=snippet, segment_ids=merged_segment_ids
        ).model_dump()
        slots[stage_norm] = stage_slots
        state.slots = slots
        _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    db.refresh(state)
    return coerce_memoir_state(state)