Files
life-echo/api/app/features/memoir/state_service.py
Kevin 3121d1384d WIP: memory system improvements (in progress)
Interview/chat prompt layers, reply planner, style profiles, memory
injection, interview meta store, and related tests. Work not finished.

Made-with: Cursor
2026-04-22 16:56:28 +08:00

309 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Memoir state service: get_or_create_state, update_slot, mark_stage_complete, etc.
Used by the memoir service and the conversation ws; Celery tasks use the synchronous variants.
"""
import uuid
from typing import Dict, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.agents.stage_constants import (
chat_bucket,
normalize_chat_stage,
)
from app.agents.state_schema import (
InterviewControlState,
KnownFact,
MemoirStateSchema,
PersonaThread,
SlotData,
default_state,
)
from app.core.config import settings
from app.features.memoir import _interview_meta_store as interview_meta
from app.features.memoir.models import MemoirState as MemoirStateModel
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
"""浅拷贝 slots避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
if not raw or not isinstance(raw, dict):
return {}
return {k: dict(v or {}) for k, v in raw.items()}
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
    """Project an ORM row onto a validated ``MemoirStateSchema``.

    Interview control metadata is read through the ``interview_meta`` adapter,
    which is also used to strip that metadata out of the slots before they are
    handed to the schema.

    Args:
        model: The persisted memoir state row.

    Returns:
        The schema built from the row. Default slots / stage order from
        ``default_state()`` are substituted when the stored values are empty.
    """
    stored_slots = model.slots if isinstance(model.slots, dict) else None
    control = interview_meta.read(stored_slots)

    slots_without_meta = interview_meta.strip(stored_slots)
    if not slots_without_meta:
        slots_without_meta = dict(default_state().slots)

    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order

    payload = {
        "stage_order": stage_order,
        "current_stage": model.current_stage,
        "covered_stages": model.covered_stages or [],
        "slots": slots_without_meta,
        "known_facts": control.known_facts,
        "persona_threads": control.persona_threads,
        "recent_questions": control.recent_questions,
    }
    return MemoirStateSchema.model_validate(payload)
async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSchema:
    """Fetch the memoir state for ``user_id``, inserting a default row if absent.

    Args:
        user_id: Owner of the memoir state.
        db: Async session used for the lookup and, when needed, the insert.

    Returns:
        The stored (or newly created) state as a ``MemoirStateSchema``.

    Note:
        Creating a row commits the session.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    row = (await db.execute(lookup)).scalar_one_or_none()
    if not row:
        template = default_state()
        # Slot objects are dumped to plain dicts before being stored on the row.
        serialized_slots = {}
        for stage_key, stage_slots in template.slots.items():
            serialized_slots[stage_key] = {
                slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
            }
        row = MemoirStateModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            stage_order=template.stage_order,
            current_stage=template.current_stage,
            covered_stages=template.covered_stages,
            slots=serialized_slots,
        )
        db.add(row)
        await db.commit()
        await db.refresh(row)
    return coerce_memoir_state(row)
def _apply_current_stage_policy(
    state: MemoirStateModel,
    stage_norm: str,
    *,
    memoir_batch: bool,
) -> None:
    """Decide whether ``state.current_stage`` should follow ``stage_norm``.

    Rules, as implemented here:
      * Not a memoir batch: always move ``current_stage`` to ``stage_norm``.
      * Memoir batch with ``settings.memoir_extraction_updates_current_stage``
        disabled: leave ``current_stage`` untouched.
      * Memoir batch with the setting enabled: move only when ``stage_norm``
        and the current stage map to the same ``chat_bucket``.

    Args:
        state: Row whose ``current_stage`` may be reassigned in place.
        stage_norm: Already-normalised target stage.
        memoir_batch: Whether the update originates from batch extraction.
    """
    # A missing current stage is treated as "childhood" for bucket comparison.
    effective_current = state.current_stage or "childhood"
    if memoir_batch:
        if settings.memoir_extraction_updates_current_stage:
            bucket_now = chat_bucket(effective_current)
            bucket_incoming = chat_bucket(stage_norm)
            if bucket_incoming == bucket_now:
                state.current_stage = stage_norm
        return
    state.current_stage = stage_norm
async def update_slot(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: AsyncSession,
    *,
    memoir_batch: bool = False,
) -> MemoirStateSchema:
    """Write ``snippet`` into one slot of the user's memoir state and persist it.

    The row is selected with ``with_for_update()``; if it does not exist yet it
    is created through ``get_or_create_state`` and selected again.

    Args:
        user_id: Owner of the memoir state.
        stage: Requested stage; passed through ``normalize_chat_stage`` with
            the row's current stage (or ``"childhood"``) as fallback.
        slot_name: Slot key inside the normalised stage.
        snippet: New snippet text; replaces any previous snippet for the slot.
        segment_ids: Segment ids to associate with the slot. They are merged
            with the ids already stored, without duplicates.
        db: Async session; committed before returning.
        memoir_batch: Forwarded to ``_apply_current_stage_policy`` to decide
            whether ``current_stage`` follows the written stage.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        await get_or_create_state(user_id, db)
        result = await db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so that assigning it back is seen as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    # A slot persisted as null must not break the merge.
    existing = stage_slots.get(slot_name) or {}
    # dict.fromkeys de-duplicates while keeping first-seen order (stored ids
    # first, then new ones); a set would yield an arbitrary, unstable order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots

    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def mark_stage_complete(
    user_id: str, stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Add ``stage`` to ``covered_stages`` and advance past it if it is current.

    According to the original author's note this function has no callers yet
    and is reserved for future stage-progression logic.

    Args:
        user_id: Owner of the memoir state.
        stage: Stage to mark as covered.
        db: Async session; committed before returning.

    Returns:
        The refreshed state as a ``MemoirStateSchema``. When no row exists, a
        default state is created and returned as-is (``stage`` is not marked).
    """
    stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    result = await db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        return await get_or_create_state(user_id, db)

    # Copy before appending: mutating the loaded list in place and assigning
    # the same object back may not be detected as a change by the ORM, so the
    # new stage could silently fail to persist.
    covered = list(state.covered_stages or [])
    if stage not in covered:
        covered.append(stage)
        state.covered_stages = covered

    stage_order = state.stage_order or default_state().stage_order
    if state.current_stage == stage:
        try:
            idx = stage_order.index(stage)
        except ValueError:
            # Stage is not part of the configured order: fall back to default.
            state.current_stage = default_state().current_stage
        else:
            # Clamp at the last stage so the final stage stays current.
            state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)]

    await db.commit()
    await db.refresh(state)
    return coerce_memoir_state(state)
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
    """Return the result of ``empty_slots_for_current_stage()`` for the user's state.

    The state is loaded (or created) through ``get_or_create_state``.
    """
    snapshot = await get_or_create_state(user_id, db)
    missing = snapshot.empty_slots_for_current_stage()
    return missing
async def switch_stage(
    user_id: str, new_stage: str, db: AsyncSession
) -> MemoirStateSchema:
    """Set the user's ``current_stage`` to the normalised form of ``new_stage``.

    The row is selected with ``with_for_update()`` and created first when
    missing. ``new_stage`` goes through ``normalize_chat_stage`` with the
    existing current stage (or ``"childhood"``) as fallback.

    Args:
        user_id: Owner of the memoir state.
        new_stage: Requested stage.
        db: Async session; committed before returning.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        # Create the default row, then select it again under the same lock clause.
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    fallback_stage = row.current_stage or "childhood"
    target_stage = normalize_chat_stage(
        new_stage, fallback=fallback_stage, log_context={"user_id": user_id}
    )
    row.current_stage = target_stage

    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
async def save_interview_state_meta(
    user_id: str,
    *,
    known_facts: list[KnownFact],
    persona_threads: list[PersonaThread],
    recent_questions: list[str],
    db: AsyncSession,
) -> MemoirStateSchema:
    """Persist interview control metadata alongside the user's slots.

    The metadata is bundled into an ``InterviewControlState`` and written into
    a copy of the stored slots through ``interview_meta.write``; the result
    replaces ``state.slots``.

    Args:
        user_id: Owner of the memoir state.
        known_facts: Facts to store.
        persona_threads: Persona threads to store.
        recent_questions: Recent questions to store.
        db: Async session; committed before returning.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    locked_lookup = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    row = (await db.execute(locked_lookup)).scalar_one_or_none()
    if not row:
        await get_or_create_state(user_id, db)
        row = (await db.execute(locked_lookup)).scalar_one()

    # Copy first so the assignment below hands the ORM a new object.
    slots_copy = _slots_snapshot_for_merge(
        row.slots if isinstance(row.slots, dict) else None
    )
    control_state = InterviewControlState(
        known_facts=known_facts,
        persona_threads=persona_threads,
        recent_questions=recent_questions,
    )
    row.slots = interview_meta.write(slots_copy, control=control_state)

    await db.commit()
    await db.refresh(row)
    return coerce_memoir_state(row)
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronous counterpart of ``get_or_create_state``.

    Args:
        user_id: Owner of the memoir state.
        db: Synchronous session used for the lookup and, when needed, the insert.

    Returns:
        The stored (or newly created) state as a ``MemoirStateSchema``.

    Note:
        Creating a row commits the session.
    """
    lookup = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
    row = db.execute(lookup).scalar_one_or_none()
    if not row:
        template = default_state()
        # Slot objects are dumped to plain dicts before being stored on the row.
        serialized_slots = {}
        for stage_key, stage_slots in template.slots.items():
            serialized_slots[stage_key] = {
                slot_key: slot.model_dump() for slot_key, slot in stage_slots.items()
            }
        row = MemoirStateModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            stage_order=template.stage_order,
            current_stage=template.current_stage,
            covered_stages=template.covered_stages,
            slots=serialized_slots,
        )
        db.add(row)
        db.commit()
        db.refresh(row)
    return coerce_memoir_state(row)
def update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
    *,
    memoir_batch: bool = True,
) -> MemoirStateSchema:
    """Synchronous variant of ``update_slot``; ``memoir_batch`` defaults to True.

    The row is selected with ``with_for_update()``; if it does not exist yet it
    is created through ``get_or_create_state_sync`` and selected again.

    Args:
        user_id: Owner of the memoir state.
        stage: Requested stage; passed through ``normalize_chat_stage`` with
            the row's current stage (or ``"childhood"``) as fallback.
        slot_name: Slot key inside the normalised stage.
        snippet: New snippet text; replaces any previous snippet for the slot.
        segment_ids: Segment ids to associate with the slot. They are merged
            with the ids already stored, without duplicates.
        db: Synchronous session; committed before returning.
        memoir_batch: Forwarded to ``_apply_current_stage_policy`` to decide
            whether ``current_stage`` follows the written stage.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    stmt = (
        select(MemoirStateModel)
        .where(MemoirStateModel.user_id == user_id)
        .with_for_update()
    )
    result = db.execute(stmt)
    state = result.scalar_one_or_none()
    if not state:
        get_or_create_state_sync(user_id, db)
        result = db.execute(stmt)
        state = result.scalar_one()

    current_from_db = state.current_stage or "childhood"
    stage_norm = normalize_chat_stage(
        stage,
        fallback=current_from_db,
        log_context={"user_id": user_id},
    )

    # Work on a copy so that assigning it back is seen as a change by the ORM.
    slots = _slots_snapshot_for_merge(
        state.slots if isinstance(state.slots, dict) else None
    )
    stage_slots = dict(slots.get(stage_norm) or {})
    # A slot persisted as null must not break the merge.
    existing = stage_slots.get(slot_name) or {}
    # dict.fromkeys de-duplicates while keeping first-seen order (stored ids
    # first, then new ones); a set would yield an arbitrary, unstable order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage_norm] = stage_slots
    state.slots = slots

    _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
    db.commit()
    db.refresh(state)
    return coerce_memoir_state(state)