2026-01-21 22:31:03 +01:00
|
|
|
|
"""
|
2026-03-18 17:18:23 +08:00
|
|
|
|
回忆录状态服务:get_or_create_state、update_slot、mark_stage_complete 等。
|
2026-04-02 16:37:14 +08:00
|
|
|
|
供 memoir service、conversation ws 使用;Celery 任务内使用同步版本。
|
2026-01-21 22:31:03 +01:00
|
|
|
|
"""
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-01-21 22:31:03 +01:00
|
|
|
|
import uuid
|
2026-04-08 21:36:12 +08:00
|
|
|
|
from typing import Dict, List, cast
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2026-04-02 16:37:14 +08:00
|
|
|
|
from sqlalchemy.orm import Session
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
2026-04-02 16:37:14 +08:00
|
|
|
|
from app.agents.stage_constants import (
|
|
|
|
|
|
chat_bucket,
|
|
|
|
|
|
normalize_chat_stage,
|
|
|
|
|
|
)
|
2026-04-08 21:36:12 +08:00
|
|
|
|
from app.agents.state_schema import (
|
|
|
|
|
|
KnownFact,
|
|
|
|
|
|
MemoirStateSchema,
|
|
|
|
|
|
PersonaThread,
|
|
|
|
|
|
SlotData,
|
|
|
|
|
|
default_state,
|
|
|
|
|
|
)
|
2026-04-02 16:37:14 +08:00
|
|
|
|
from app.core.config import settings
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from app.features.memoir.models import MemoirState as MemoirStateModel
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
2026-04-08 21:36:12 +08:00
|
|
|
|
# Reserved key inside the ``slots`` JSON column. Interview metadata
# (known_facts / persona_threads / recent_questions) is stored under it next to
# the per-stage slot dicts; coerce_memoir_state lifts it out and removes it from
# the slot mapping before building the schema.
_INTERVIEW_STATE_META_KEY = "__interview_state__"
|
|
|
|
|
|
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
2026-04-02 16:37:14 +08:00
|
|
|
|
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
|
|
|
|
|
|
"""浅拷贝 slots,避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
|
|
|
|
|
|
if not raw or not isinstance(raw, dict):
|
|
|
|
|
|
return {}
|
|
|
|
|
|
return {k: dict(v or {}) for k, v in raw.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-08 21:36:12 +08:00
|
|
|
|
def _extract_interview_state_meta(
    raw_slots: Dict[str, Dict] | None,
) -> tuple[list[KnownFact], list[PersonaThread], list[str]]:
    """Read the interview metadata stored under the reserved slots key.

    Returns ``(known_facts, persona_threads, recent_questions)``.

    - A missing or non-dict container, or a field that is not a list, yields
      an empty list.
    - Fact and thread entries that are not dicts are skipped; the rest go
      through ``model_validate``.
    - Recent questions are stringified and stripped; blank ones are dropped.
    """
    if not raw_slots or not isinstance(raw_slots, dict):
        return [], [], []
    meta = raw_slots.get(_INTERVIEW_STATE_META_KEY)
    if not isinstance(meta, dict):
        return [], [], []

    def _list_field(field_name: str) -> list:
        # Only a real list is trusted; anything else is treated as absent.
        candidate = meta.get(field_name)
        return candidate if isinstance(candidate, list) else []

    facts = [
        KnownFact.model_validate(entry)
        for entry in _list_field("known_facts")
        if isinstance(entry, dict)
    ]
    threads = [
        PersonaThread.model_validate(entry)
        for entry in _list_field("persona_threads")
        if isinstance(entry, dict)
    ]
    questions: list[str] = []
    for entry in _list_field("recent_questions"):
        text = str(entry).strip()
        if text:
            questions.append(text)
    return facts, threads, questions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _inject_interview_state_meta(
    *,
    slots: Dict[str, Dict],
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
) -> Dict[str, Dict]:
    """Return a copy of ``slots`` with interview metadata under the reserved key.

    ``slots`` itself is not mutated. Any previous metadata entry is replaced.
    Facts and threads are serialized with ``model_dump()``, and
    ``recent_questions`` is copied into a new list.
    """
    meta_payload = {
        "known_facts": [fact.model_dump() for fact in known_facts],
        "persona_threads": [thread.model_dump() for thread in persona_threads],
        "recent_questions": list(recent_questions),
    }
    return {**slots, _INTERVIEW_STATE_META_KEY: cast(Dict, meta_payload)}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-02 16:37:14 +08:00
|
|
|
|
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Convert a ``MemoirState`` ORM row into a validated ``MemoirStateSchema``.

    - The reserved interview-metadata entry is lifted out of ``slots`` into the
      schema's ``known_facts`` / ``persona_threads`` / ``recent_questions``
      fields and removed from the slot mapping.
    - Empty or non-dict ``slots`` and a falsy ``stage_order`` fall back to
      ``default_state()``.
    - A falsy ``covered_stages`` becomes ``[]``.
    """
    raw_slots = model.slots if isinstance(model.slots, dict) else None
    facts, threads, questions = _extract_interview_state_meta(raw_slots)

    if raw_slots:
        slot_map = dict(raw_slots)
    else:
        slot_map = dict(default_state().slots)
    # The metadata lives in its own schema fields, not among the stage slots.
    slot_map.pop(_INTERVIEW_STATE_META_KEY, None)

    payload = {
        "stage_order": model.stage_order or default_state().stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": slot_map,
        "known_facts": facts,
        "persona_threads": threads,
        "recent_questions": questions,
    }
    return MemoirStateSchema.model_validate(payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Return the user's memoir state, inserting a default row if none exists.

    A newly created row is committed and refreshed before conversion.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    # NOTE(review): two concurrent callers can both miss the row and both
    # insert; whether that errors depends on a user_id uniqueness constraint
    # not visible here.
    template = default_state()
    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        # The JSON column stores plain dicts, so dump each SlotData model.
        slots={
            stage_name: {
                slot_key: slot_value.model_dump()
                for slot_key, slot_value in stage_slots.items()
            }
            for stage_name, stage_slots in template.slots.items()
        },
    )
    db.add(row)
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Decide whether ``state.current_stage`` should move to ``stage_norm``.

    - Non-batch writes always set ``current_stage`` to ``stage_norm``.
    - Batch writes change it only when
      ``settings.memoir_extraction_updates_current_stage`` is enabled and
      ``stage_norm`` falls in the same ``chat_bucket`` as the current stage.
      A missing current stage counts as ``"childhood"``.
    """
    if memoir_batch:
        if not settings.memoir_extraction_updates_current_stage:
            return
        current_bucket = chat_bucket(state.current_stage or "childhood")
        if chat_bucket(stage_norm) != current_bucket:
            return
    state.current_stage = stage_norm
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Write ``snippet`` into one slot of one stage and merge its segment ids.

    - The row is locked with ``SELECT ... FOR UPDATE``; it is created first if
      missing.
    - ``stage`` is normalized via ``normalize_chat_stage``, falling back to the
      row's current stage (or ``"childhood"``).
    - The slot's snippet is replaced. Its ``segment_ids`` become the existing
      ids followed by the new ones, de-duplicated in first-seen order.
    - ``current_stage`` is then adjusted by ``_apply_current_stage_policy``.

    Returns the refreshed state.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so reassigning state.slots gives the ORM a new object and
    # the JSON column is marked dirty.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    # A slot stored as JSON null would make ``existing.get`` raise; treat it as empty.
    existing = stage_slots.get(slot_name) or {}
    # Order-preserving de-duplication. A set would reorder ids arbitrarily,
    # because str hashing is randomized per process.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Add ``stage`` to ``covered_stages`` and advance past it if it is current.

    If ``stage`` equals ``current_stage``:

    - ``current_stage`` moves to the next entry of ``stage_order``, staying on
      the last entry at the end of the order.
    - If ``stage`` is not in ``stage_order``, ``current_stage`` resets to
      ``default_state().current_stage``.

    The row is locked like the other writers and is created first if missing.
    There are currently no callers; this is reserved for future
    stage-advancement logic.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        # Create the row, then continue so the stage is actually marked.
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    # Build a new list. Appending to the list loaded from the JSON column and
    # reassigning the same object would not mark the column dirty, so the
    # change would be silently dropped on commit.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
    state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            state.current_stage = default_state().current_stage
        else:
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
|
2026-01-21 22:31:03 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """List the slots still unfilled in the user's current stage.

    Creates the default state row first if the user has none.
    """
    memoir_state = await get_or_create_state(user_id, db)
    unfilled = memoir_state.empty_slots_for_current_stage()
    return unfilled
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to the normalized ``new_stage``.

    The row is locked (created first if missing). Normalization falls back to
    the existing stage, or ``"childhood"`` when none is set.
    """
    locked_query = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_query)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_query)).scalar_one()

    previous_stage = row.current_stage or "childhood"
    row.current_stage = normalize_chat_stage(
        new_stage,
        fallback=previous_stage,
        log_context={"user_id": user_id},
    )
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-08 21:36:12 +08:00
|
|
|
|
async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Persist interview metadata into the user's ``slots`` JSON column.

    The row is locked (created first if missing). The stored metadata entry is
    replaced wholesale and the per-stage slots are kept. Returns the refreshed
    state.
    """
    locked_query = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_query)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_query)).scalar_one()

    current_slots = row.slots if isinstance(row.slots, dict) else None
    # A fresh dict is assigned so the ORM marks the JSON column dirty.
    row.slots = _inject_interview_state_meta(
        slots=_slots_snapshot_for_merge(current_slots),
        known_facts=known_facts,
        persona_threads=persona_threads,
        recent_questions=recent_questions,
    )
    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-02 16:37:14 +08:00
|
|
|
|
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state``.

    Returns the user's memoir state, inserting and committing a default row if
    none exists. The module docstring says Celery tasks use the sync versions.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return coerce_memoir_state(found)

    template = default_state()
    row = MemoirStateModel(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        # The JSON column stores plain dicts, so dump each SlotData model.
        slots={
            stage_name: {
                slot_key: slot_value.model_dump()
                for slot_key, slot_value in stage_slots.items()
            }
            for stage_name, stage_slots in template.slots.items()
        },
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return coerce_memoir_state(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous ``update_slot``; defaults to ``memoir_batch=True``.

    - Locks the row (creating it first if missing).
    - Normalizes ``stage`` with the row's current stage (or ``"childhood"``)
      as fallback.
    - Replaces the slot's snippet and merges ``segment_ids``: existing ids
      followed by new ones, de-duplicated in first-seen order.
    - Applies ``_apply_current_stage_policy``.

    Returns the refreshed state.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so reassigning state.slots gives the ORM a new object and
    # the JSON column is marked dirty.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    # A slot stored as JSON null would make ``existing.get`` raise; treat it as empty.
    existing = stage_slots.get(slot_name) or {}
    # Order-preserving de-duplication. A set would reorder ids arbitrarily,
    # because str hashing is randomized per process.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots
    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)
|