feat(memoir): 回忆录分段两阶段管线(Phase1 分类 / Phase2 叙事)与配置、测试
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
回忆录状态服务:get_or_create_state、update_slot、mark_stage_complete 等。
|
||||
供 memoir service、conversation ws 使用;Celery 任务内使用同步版本(见 tasks/memoir_tasks)。
|
||||
供 memoir service、conversation ws 使用;Celery 任务内使用同步版本。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -8,12 +8,25 @@ from typing import Dict, List
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||||
from app.agents.stage_constants import (
|
||||
chat_bucket,
|
||||
normalize_chat_stage,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.features.memoir.models import MemoirState as MemoirStateModel
|
||||
|
||||
|
||||
def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema:
|
||||
def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]:
|
||||
"""浅拷贝 slots,避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。"""
|
||||
if not raw or not isinstance(raw, dict):
|
||||
return {}
|
||||
return {k: dict(v or {}) for k, v in raw.items()}
|
||||
|
||||
|
||||
def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema:
|
||||
return MemoirStateSchema.model_validate(
|
||||
{
|
||||
"stage_order": model.stage_order or default_state().stage_order,
|
||||
@@ -31,7 +44,7 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
if state:
|
||||
return _coerce_state(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
default = default_state()
|
||||
state = MemoirStateModel(
|
||||
@@ -48,7 +61,27 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche
|
||||
db.add(state)
|
||||
await db.commit()
|
||||
await db.refresh(state)
|
||||
return _coerce_state(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
|
||||
def _apply_current_stage_policy(
|
||||
state: MemoirStateModel,
|
||||
stage_norm: str,
|
||||
*,
|
||||
memoir_batch: bool,
|
||||
) -> None:
|
||||
"""按 memoir_extraction_updates_current_stage 与 chat_bucket 真值表更新 current_stage。"""
|
||||
current_from_db = state.current_stage or "childhood"
|
||||
if not memoir_batch:
|
||||
state.current_stage = stage_norm
|
||||
return
|
||||
|
||||
if not settings.memoir_extraction_updates_current_stage:
|
||||
return
|
||||
cur_b = chat_bucket(state.current_stage or current_from_db)
|
||||
new_b = chat_bucket(stage_norm)
|
||||
if new_b == cur_b:
|
||||
state.current_stage = stage_norm
|
||||
|
||||
|
||||
async def update_slot(
|
||||
@@ -58,8 +91,14 @@ async def update_slot(
|
||||
snippet: str,
|
||||
segment_ids: List[str],
|
||||
db: AsyncSession,
|
||||
*,
|
||||
memoir_batch: bool = False,
|
||||
) -> MemoirStateSchema:
|
||||
stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
|
||||
stmt = (
|
||||
select(MemoirStateModel)
|
||||
.where(MemoirStateModel.user_id == user_id)
|
||||
.with_for_update()
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
if not state:
|
||||
@@ -67,25 +106,35 @@ async def update_slot(
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one()
|
||||
|
||||
slots: Dict[str, Dict] = state.slots or {}
|
||||
stage_slots = slots.get(stage, {})
|
||||
current_from_db = state.current_stage or "childhood"
|
||||
stage_norm = normalize_chat_stage(
|
||||
stage,
|
||||
fallback=current_from_db,
|
||||
log_context={"user_id": user_id},
|
||||
)
|
||||
|
||||
slots = _slots_snapshot_for_merge(
|
||||
state.slots if isinstance(state.slots, dict) else None
|
||||
)
|
||||
stage_slots = dict(slots.get(stage_norm, {}) or {})
|
||||
existing = stage_slots.get(slot_name, {})
|
||||
|
||||
merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids})
|
||||
stage_slots[slot_name] = SlotData(
|
||||
snippet=snippet, segment_ids=merged_segment_ids
|
||||
).model_dump()
|
||||
slots[stage] = stage_slots
|
||||
slots[stage_norm] = stage_slots
|
||||
state.slots = slots
|
||||
state.current_stage = stage
|
||||
_apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
|
||||
await db.commit()
|
||||
await db.refresh(state)
|
||||
return _coerce_state(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
|
||||
async def mark_stage_complete(
|
||||
user_id: str, stage: str, db: AsyncSession
|
||||
) -> MemoirStateSchema:
|
||||
"""推进 covered_stages 并在当前阶段匹配时尝试进入下一阶段。当前无调用方,预留未来阶段推进逻辑。"""
|
||||
stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
@@ -106,7 +155,7 @@ async def mark_stage_complete(
|
||||
state.current_stage = default_state().current_stage
|
||||
await db.commit()
|
||||
await db.refresh(state)
|
||||
return _coerce_state(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
|
||||
async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
|
||||
@@ -117,13 +166,94 @@ async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]:
|
||||
async def switch_stage(
|
||||
user_id: str, new_stage: str, db: AsyncSession
|
||||
) -> MemoirStateSchema:
|
||||
stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
|
||||
stmt = (
|
||||
select(MemoirStateModel)
|
||||
.where(MemoirStateModel.user_id == user_id)
|
||||
.with_for_update()
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
if not state:
|
||||
return await get_or_create_state(user_id, db)
|
||||
await get_or_create_state(user_id, db)
|
||||
result = await db.execute(stmt)
|
||||
state = result.scalar_one()
|
||||
|
||||
state.current_stage = new_stage
|
||||
fb = state.current_stage or "childhood"
|
||||
state.current_stage = normalize_chat_stage(
|
||||
new_stage, fallback=fb, log_context={"user_id": user_id}
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(state)
|
||||
return _coerce_state(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
|
||||
def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
|
||||
stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id)
|
||||
result = db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
if state:
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
default = default_state()
|
||||
state = MemoirStateModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
stage_order=default.stage_order,
|
||||
current_stage=default.current_stage,
|
||||
covered_stages=default.covered_stages,
|
||||
slots={
|
||||
k: {sk: sv.model_dump() for sk, sv in v.items()}
|
||||
for k, v in default.slots.items()
|
||||
},
|
||||
)
|
||||
db.add(state)
|
||||
db.commit()
|
||||
db.refresh(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
|
||||
def update_slot_sync(
|
||||
user_id: str,
|
||||
stage: str,
|
||||
slot_name: str,
|
||||
snippet: str,
|
||||
segment_ids: List[str],
|
||||
db: Session,
|
||||
*,
|
||||
memoir_batch: bool = True,
|
||||
) -> MemoirStateSchema:
|
||||
stmt = (
|
||||
select(MemoirStateModel)
|
||||
.where(MemoirStateModel.user_id == user_id)
|
||||
.with_for_update()
|
||||
)
|
||||
result = db.execute(stmt)
|
||||
state = result.scalar_one_or_none()
|
||||
if not state:
|
||||
get_or_create_state_sync(user_id, db)
|
||||
result = db.execute(stmt)
|
||||
state = result.scalar_one()
|
||||
|
||||
current_from_db = state.current_stage or "childhood"
|
||||
stage_norm = normalize_chat_stage(
|
||||
stage,
|
||||
fallback=current_from_db,
|
||||
log_context={"user_id": user_id},
|
||||
)
|
||||
|
||||
slots = _slots_snapshot_for_merge(
|
||||
state.slots if isinstance(state.slots, dict) else None
|
||||
)
|
||||
stage_slots = dict(slots.get(stage_norm, {}) or {})
|
||||
existing = stage_slots.get(slot_name, {})
|
||||
|
||||
merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids})
|
||||
stage_slots[slot_name] = SlotData(
|
||||
snippet=snippet, segment_ids=merged_segment_ids
|
||||
).model_dump()
|
||||
slots[stage_norm] = stage_slots
|
||||
state.slots = slots
|
||||
_apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch)
|
||||
db.commit()
|
||||
db.refresh(state)
|
||||
return coerce_memoir_state(state)
|
||||
|
||||
Reference in New Issue
Block a user