feat(memoir): 回忆录分段两阶段管线(Phase1 分类 / Phase2 叙事)与配置、测试

This commit is contained in:
Kevin
2026-04-02 16:37:14 +08:00
parent 3ae39838c0
commit 6b930808a3
27 changed files with 1550 additions and 430 deletions

View File

@@ -9,13 +9,16 @@ from typing import Dict, List, Set
import redis
from celery import shared_task
from sqlalchemy import select
from celery.exceptions import Retry
from celery.result import AsyncResult
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_profile import format_user_profile_context
from app.agents.memoir import MemoirOrchestrator
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.agents.stage_constants import normalize_chapter_category
from app.agents.state_schema import MemoirStateSchema, default_state
from app.core.chapter_pipeline_lock import (
acquire_chapter_pipeline_lock as _acquire_chapter_lock,
)
@@ -26,7 +29,9 @@ from app.core.config import settings
from app.core.db import get_sync_db
from app.core.dependencies import get_llm_provider, get_llm_provider_fast
from app.core.logging import get_logger
from app.features.conversation.models import Segment
from app.features.conversation.models import Conversation, Segment
from app.tasks.celery_app import celery_app
from app.features.memoir.cover_eligibility import (
chapter_needs_cover_enqueue,
)
@@ -46,7 +51,10 @@ from app.features.memoir.memoir_images.settings import MemoirImageSettings
from app.features.memoir.models import (
Book,
MemoirImage,
MemoirState,
)
from app.features.memoir.state_service import (
get_or_create_state_sync,
update_slot_sync,
)
from app.features.memoir.story_pipeline_sync import (
run_story_pipeline_for_category_batch,
@@ -187,109 +195,361 @@ def _memoir_image_from_asset(
)
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
"""将数据库模型转换为 Schema"""
return MemoirStateSchema.model_validate(
{
"stage_order": model.stage_order or default_state().stage_order,
"current_stage": model.current_stage,
"covered_stages": model.covered_stages or [],
"slots": model.slots
if isinstance(model.slots, dict)
else default_state().slots,
}
)
def _phase2_timeout_task_id(user_id: str, chapter_category: str) -> str:
return f"phase2-timeout-{user_id}-{chapter_category}"
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
"""同步获取或创建状态"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if state:
return _coerce_state(state)
default = default_state()
state = MemoirState(
id=str(uuid.uuid4()),
user_id=user_id,
stage_order=default.stage_order,
current_stage=default.current_stage,
covered_stages=default.covered_stages,
slots={
k: {sk: sv.model_dump() for sk, sv in v.items()}
for k, v in default.slots.items()
},
)
db.add(state)
db.commit()
db.refresh(state)
return _coerce_state(state)
def _revoke_phase2_timeout(user_id: str, chapter_category: str) -> None:
tid = _phase2_timeout_task_id(user_id, chapter_category)
try:
AsyncResult(tid, app=celery_app).revoke(terminate=False)
except Exception as e:
logger.debug(
"event=phase2_timeout_revoke_skipped task_id={} exc={}",
tid,
e,
)
def _update_slot_sync(
user_id: str,
stage: str,
slot_name: str,
snippet: str,
segment_ids: List[str],
def _should_trigger_phase2(
db: Session,
) -> MemoirStateSchema:
"""同步更新 slot"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if not state:
_get_or_create_state_sync(user_id, db)
result = db.execute(stmt)
state = result.scalar_one()
user_id: str,
chapter_category: str,
current_segment_chars: int,
) -> bool:
if current_segment_chars >= int(settings.memoir_narrative_immediate_char_threshold):
return True
user_convs = select(Conversation.id).where(
Conversation.user_id == user_id,
Conversation.deleted_at.is_(None),
)
stmt = select(
func.count(Segment.id),
func.coalesce(func.sum(func.length(Segment.user_input_text)), 0),
).where(
Segment.conversation_id.in_(user_convs),
Segment.topic_category == chapter_category,
Segment.narrated.is_(False),
Segment.skip_narrative.is_(False),
)
row = db.execute(stmt).one()
count, total_chars = int(row[0] or 0), int(row[1] or 0)
if count >= int(settings.memoir_narrative_batch_min_segments):
return True
if total_chars >= int(settings.memoir_narrative_batch_min_chars):
return True
return False
slots: Dict[str, Dict] = state.slots or {}
stage_slots = slots.get(stage, {})
existing = stage_slots.get(slot_name, {})
merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids})
stage_slots[slot_name] = SlotData(
snippet=snippet, segment_ids=merged_segment_ids
).model_dump()
slots[stage] = stage_slots
state.slots = slots
state.current_stage = stage
db.commit()
db.refresh(state)
return _coerce_state(state)
def _schedule_phase2_timeout(user_id: str, chapter_category: str) -> None:
"""Reset countdown for Phase 2 narrative for one category."""
_revoke_phase2_timeout(user_id, chapter_category)
countdown = float(max(1.0, settings.memoir_narrative_batch_max_wait_seconds))
celery_app.send_task(
"app.tasks.memoir_tasks.process_memoir_phase2",
args=[user_id, chapter_category],
countdown=countdown,
task_id=_phase2_timeout_task_id(user_id, chapter_category),
)
logger.info(
"event=phase2_timeout_scheduled user_id={} chapter_category={} countdown={}",
user_id,
chapter_category,
countdown,
)
def _dispatch_phase2_immediate(user_id: str, chapter_category: str) -> None:
_revoke_phase2_timeout(user_id, chapter_category)
celery_app.send_task(
"app.tasks.memoir_tasks.process_memoir_phase2",
args=[user_id, chapter_category],
)
logger.info(
"event=phase2_dispatched_immediate user_id={} chapter_category={}",
user_id,
chapter_category,
)
def dispatch_pending_memoir_phase2_for_user(user_id: str) -> None:
"""会话结束等场景:为该用户所有待叙事类目各发一条 Phase2幂等"""
try:
with get_sync_db() as db:
user_convs = select(Conversation.id).where(
Conversation.user_id == user_id,
Conversation.deleted_at.is_(None),
)
stmt = (
select(Segment.topic_category)
.where(
Segment.conversation_id.in_(user_convs),
Segment.narrated.is_(False),
Segment.skip_narrative.is_(False),
Segment.topic_category.isnot(None),
)
.distinct()
)
cats = [r[0] for r in db.execute(stmt).all() if r[0]]
for chapter_category in cats:
_revoke_phase2_timeout(user_id, chapter_category)
celery_app.send_task(
"app.tasks.memoir_tasks.process_memoir_phase2",
args=[user_id, chapter_category],
)
logger.info(
"event=phase2_dispatched_flush user_id={} chapter_category={}",
user_id,
chapter_category,
)
except Exception as e:
logger.error(
"event=phase2_flush_failed user_id={} exc_type={} exc={}",
user_id,
type(e).__name__,
e,
)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def process_memoir_phase2(self, user_id: str, chapter_category: str):
"""Phase 2叙事 / 路由 / 忠实度 / 标题;按类目加锁,消费未叙事且非 skip 的 segments。"""
task_id = self.request.id
logger.info(
"event=memoir_phase2_start user_id={} task_id={} chapter_category={}",
user_id,
task_id,
chapter_category,
)
try:
with get_sync_db() as db:
user_convs = select(Conversation.id).where(
Conversation.user_id == user_id,
Conversation.deleted_at.is_(None),
)
stmt = (
select(Segment)
.where(
Segment.conversation_id.in_(user_convs),
Segment.topic_category == chapter_category,
Segment.narrated.is_(False),
Segment.skip_narrative.is_(False),
)
.order_by(Segment.created_at)
)
category_segments = list(db.execute(stmt).scalars().all())
if not category_segments:
logger.info(
"event=memoir_phase2_noop user_id={} chapter_category={}",
user_id,
chapter_category,
)
return {"status": "noop"}
llm = _get_llm()
user_obj = db.get(User, user_id)
user_profile = ""
user_birth_year = None
background_voice = "default"
user_occupation = ""
if user_obj:
user_birth_year = user_obj.birth_year
user_profile = format_user_profile_context(
birth_year=user_obj.birth_year,
birth_place=user_obj.birth_place,
grew_up_place=user_obj.grew_up_place,
occupation=user_obj.occupation,
)
background_voice = infer_background_voice(user_obj.occupation)
user_occupation = user_obj.occupation or ""
image_settings = MemoirImageSettings.from_env()
story_dispatch_ids: Set[str] = set()
chapters_to_enqueue: Set[str] = set()
affected_chapter_ids: Set[str] = set()
lock_handle = _acquire_chapter_lock(
user_id, chapter_category, ttl_seconds=_chapter_lock_ttl()
)
if lock_handle is None:
logger.warning(
"event=memoir_phase2_lock_busy user_id={} chapter_category={}",
user_id,
chapter_category,
)
raise self.retry(countdown=10)
try:
# 锁内再查一次,避免等待锁期间状态已变
category_segments = list(db.execute(stmt).scalars().all())
if not category_segments:
return {"status": "noop"}
state = get_or_create_state_sync(user_id, db)
chapter, needs_cover, disp = run_story_pipeline_for_category_batch(
db,
user_id=user_id,
chapter_category=chapter_category,
category_segments=category_segments,
state=state,
user_profile=user_profile,
user_birth_year=user_birth_year,
llm=llm,
background_voice=background_voice,
occupation=user_occupation,
)
story_dispatch_ids |= disp
db.flush()
if chapter is None:
logger.error(
"event=memoir_phase2_no_chapter user_id={} chapter_category={}",
user_id,
chapter_category,
)
db.rollback()
raise self.retry(
exc=RuntimeError("story_pipeline returned no chapter"),
countdown=30,
)
db.refresh(chapter)
affected_chapter_ids.add(chapter.id)
needs_cover_enqueue = (
image_settings.enabled and chapter_needs_cover_enqueue(chapter)
)
stmt_book = (
select(Book)
.where(Book.user_id == user_id)
.order_by(Book.updated_at.desc())
)
result_book = db.execute(stmt_book)
book = result_book.scalar_one_or_none()
if not book:
book = Book(
id=str(uuid.uuid4()),
user_id=user_id,
title="我的回忆录",
total_pages=0,
total_words=0,
cover_image_url=None,
)
db.add(book)
book.has_update = True
book.last_update_chapter_id = chapter.id
if needs_cover_enqueue:
chapters_to_enqueue.add(chapter.id)
for seg in category_segments:
seg.narrated = True
seg.processed = True
db.commit()
from app.features.story.post_commit import (
enqueue_story_post_commit_effects,
)
pc = enqueue_story_post_commit_effects(
user_id=user_id,
story_ids=set(story_dispatch_ids),
chapter_ids=affected_chapter_ids,
trigger_source="pipeline_phase2",
need_compaction=True,
compaction_extra={
"pipeline_run_id": str(task_id),
"story_dispatch_ids": sorted(story_dispatch_ids),
"chapters_to_enqueue": sorted(chapters_to_enqueue),
"chapter_category": chapter_category,
},
)
logger.info(
"event=story_post_commit user_id={} trigger=pipeline_phase2 "
"enqueued_story_image_count={} enqueued_chapter_recompose_count={} "
"compaction_scheduled={} errors={}",
user_id,
pc.enqueued_story_image_count,
pc.enqueued_chapter_recompose_count,
pc.compaction_scheduled,
pc.errors,
)
from app.tasks.chapter_cover_enqueue import (
try_enqueue_generate_chapter_cover,
)
for chapter_id in sorted(chapters_to_enqueue):
if try_enqueue_generate_chapter_cover(
chapter_id, source="pipeline_phase2"
):
logger.info(f"派发章节封面任务: chapter={chapter_id}")
logger.info(
"event=memoir_phase2_done user_id={} task_id={} chapter_category={} "
"segment_count={}",
user_id,
task_id,
chapter_category,
len(category_segments),
)
return {
"status": "success",
"chapter_category": chapter_category,
"segments": len(category_segments),
}
finally:
_release_chapter_lock(lock_handle)
except Retry:
raise
except Exception as e:
logger.error(
"event=memoir_phase2_failed user_id={} chapter_category={} exc={}",
user_id,
chapter_category,
e,
)
raise self.retry(exc=e) from e
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
def process_memoir_phase1(self, user_id: str, segment_ids: List[str]):
"""
处理回忆录段落的 Celery 任务
Args:
user_id: 用户 ID
segment_ids: 段落 ID 列表
Phase 1记忆 ingest + 抽取/分类;持久化 topic_category / skip_narrative
按需派发 Phase 2阈值或延迟兜底
"""
task_id = self.request.id
logger.info(
f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}"
"event=memoir_phase1_start user_id={} task_id={} segments={}",
user_id,
task_id,
len(segment_ids),
)
# 更新任务状态为 running
_update_task_status_sync(user_id, task_id, "running")
try:
with get_sync_db() as db:
# 获取段落
stmt = select(Segment).where(Segment.id.in_(segment_ids))
result = db.execute(stmt)
segments = result.scalars().all()
stmt = (
select(Segment)
.where(Segment.id.in_(segment_ids))
.order_by(Segment.created_at.asc(), Segment.id.asc())
)
rows = db.execute(stmt).scalars().all()
segments = [s for s in rows if not s.narrated]
if not segments:
logger.warning(f"未找到段落: {segment_ids}")
logger.warning("event=memoir_phase1_no_segments ids={}", segment_ids)
_update_task_status_sync(
user_id,
task_id,
"success",
{"processed": 0, "categories": []},
)
return {"status": "no_segments"}
# Memory ingest 先于回忆录流水线 commit保证后续 retrieve_evidence_sync 可见本批 chunk
# (见 api/docs/memory-retrieval.md
conv_id = getattr(segments[0], "conversation_id", None) or ""
transcript = "\n\n".join(seg.user_input_text or "" for seg in segments)
if transcript.strip():
@@ -322,169 +582,78 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
"event=llm_fast_tier_used pipeline=memoir_prepare_batches model={}",
settings.llm_fast_model,
)
image_settings = MemoirImageSettings.from_env()
user_obj = db.get(User, user_id)
user_profile = ""
user_birth_year = None
background_voice = "default"
user_occupation = ""
if user_obj:
user_birth_year = user_obj.birth_year
user_profile = format_user_profile_context(
birth_year=user_obj.birth_year,
birth_place=user_obj.birth_place,
grew_up_place=user_obj.grew_up_place,
occupation=user_obj.occupation,
)
background_voice = infer_background_voice(user_obj.occupation)
user_occupation = user_obj.occupation or ""
story_dispatch_ids: Set[str] = set()
memoir_orchestrator = MemoirOrchestrator()
prepared = memoir_orchestrator.prepare_batches(
segments=list(segments),
llm=llm,
llm_fast=llm_fast,
get_or_create_state=lambda: _get_or_create_state_sync(user_id, db),
update_slot=lambda stage, slot_name, snippet, seg_ids: (
_update_slot_sync(user_id, stage, slot_name, snippet, seg_ids, db)
get_or_create_state=lambda: get_or_create_state_sync(user_id, db),
update_slot=lambda stage, slot_name, snippet, seg_ids: update_slot_sync(
user_id,
stage,
slot_name,
snippet,
seg_ids,
db,
memoir_batch=True,
),
)
chapters_to_enqueue: Set[str] = set()
affected_chapter_ids: Set[str] = set()
for (
chapter_category,
category_segments,
) in prepared.category_to_segments.items():
lock_handle = _acquire_chapter_lock(
user_id, chapter_category, ttl_seconds=_chapter_lock_ttl()
skip_ids = prepared.segment_skip_story_ids
missing_cat = [
seg.id
for seg in segments
if not prepared.segment_chapter_category.get(str(seg.id))
]
if missing_cat:
logger.error(
"event=memoir_phase1_missing_category abort segment_ids={}",
missing_cat,
)
raise RuntimeError(
f"memoir_phase1_missing_category: {len(missing_cat)} segments"
)
if lock_handle is None:
logger.warning(
"章节锁竞争: category={}, 延迟重试",
chapter_category,
)
raise self.retry(countdown=10)
try:
batch_ids = {str(s.id) for s in category_segments}
skip_ids = prepared.segment_skip_story_ids
in_skip = batch_ids & skip_ids
if in_skip:
logger.info(
"event=memoir_skip_story_signal chapter_category={} "
"segment_ids_in_skip_set={}",
chapter_category,
sorted(in_skip),
)
if batch_ids and batch_ids <= skip_ids:
logger.info(
"event=story_pipeline_skipped reason=no_substantive_after_none "
"chapter_category={} segment_ids={}",
chapter_category,
sorted(batch_ids),
)
continue
chapter, needs_cover, disp = run_story_pipeline_for_category_batch(
db,
user_id=user_id,
chapter_category=chapter_category,
category_segments=category_segments,
state=prepared.state,
user_profile=user_profile,
user_birth_year=user_birth_year,
llm=llm,
background_voice=background_voice,
occupation=user_occupation,
)
story_dispatch_ids |= disp
db.flush()
db.refresh(chapter)
affected_chapter_ids.add(chapter.id)
needs_cover_enqueue = (
image_settings.enabled and chapter_needs_cover_enqueue(chapter)
)
stmt_book = (
select(Book)
.where(Book.user_id == user_id)
.order_by(Book.updated_at.desc())
)
result_book = db.execute(stmt_book)
book = result_book.scalar_one_or_none()
if not book:
book = Book(
id=str(uuid.uuid4()),
user_id=user_id,
title="我的回忆录",
total_pages=0,
total_words=0,
cover_image_url=None,
)
db.add(book)
book.has_update = True
book.last_update_chapter_id = chapter.id
if chapter and needs_cover_enqueue:
chapters_to_enqueue.add(chapter.id)
finally:
_release_chapter_lock(lock_handle)
# 标记段落为已处理
for seg in segments:
seg.processed = True
cat = prepared.segment_chapter_category[str(seg.id)]
seg.topic_category = cat
is_skip = str(seg.id) in skip_ids
seg.skip_narrative = is_skip
seg.narrated = False
if is_skip:
seg.processed = True
db.flush()
categories_for_phase2: Set[str] = set()
phase2_immediate: list[str] = []
phase2_timeout: list[str] = []
for chapter_category, cat_segments in prepared.category_to_segments.items():
batch_non_skip = [
s
for s in cat_segments
if str(s.id) not in prepared.segment_skip_story_ids
]
if not batch_non_skip:
continue
max_chars = max(
len((s.user_input_text or "").strip()) for s in batch_non_skip
)
categories_for_phase2.add(chapter_category)
if _should_trigger_phase2(db, user_id, chapter_category, max_chars):
phase2_immediate.append(chapter_category)
else:
phase2_timeout.append(chapter_category)
db.commit()
from app.features.story.post_commit import enqueue_story_post_commit_effects
pc = enqueue_story_post_commit_effects(
user_id=user_id,
story_ids=set(story_dispatch_ids),
chapter_ids=affected_chapter_ids,
trigger_source="pipeline",
need_compaction=True,
compaction_extra={
"pipeline_run_id": str(task_id),
"story_dispatch_ids": sorted(story_dispatch_ids),
"chapters_to_enqueue": sorted(chapters_to_enqueue),
},
)
logger.info(
"event=story_post_commit user_id={} trigger=pipeline "
"enqueued_story_image_count={} enqueued_chapter_recompose_count={} "
"compaction_scheduled={} errors={}",
user_id,
pc.enqueued_story_image_count,
pc.enqueued_chapter_recompose_count,
pc.compaction_scheduled,
pc.errors,
)
from app.tasks.chapter_cover_enqueue import (
try_enqueue_generate_chapter_cover,
)
for chapter_id in sorted(chapters_to_enqueue):
if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"):
logger.info(f"派发章节封面任务: chapter={chapter_id}")
for cc in phase2_immediate:
_dispatch_phase2_immediate(user_id, cc)
for cc in phase2_timeout:
_schedule_phase2_timeout(user_id, cc)
categories_processed = sorted(prepared.category_to_segments.keys())
logger.info(
"回忆录处理完成: user_id={} task_id={} segment_count={} "
"categories_processed={}",
user_id,
task_id,
len(segments),
categories_processed,
)
# 更新任务状态为成功
_update_task_status_sync(
user_id,
task_id,
@@ -492,25 +661,35 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
{
"processed": len(segments),
"categories_processed": categories_processed,
"phase2_watch_categories": sorted(categories_for_phase2),
},
)
logger.info(
"event=memoir_phase1_done user_id={} task_id={} segment_count={} "
"categories={}",
user_id,
task_id,
len(segments),
categories_processed,
)
return {
"status": "success",
"processed": len(segments),
"categories_processed": categories_processed,
}
except Retry:
raise
except Exception as e:
logger.error(f"回忆录处理失败: {e}")
# 更新任务状态为失败
logger.error("event=memoir_phase1_failed user_id={} exc={}", user_id, e)
_update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
# 重试
raise self.retry(exc=e) from e
# 兼容旧 Celery/文档入口名
process_memoir_segments = process_memoir_phase1
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
"""
@@ -521,6 +700,7 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
stage: 阶段
new_content: 新内容
"""
stage = normalize_chapter_category(stage, fallback="summary")
logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")
try:
@@ -547,7 +727,7 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
self.id = str(uuid.uuid4())
self.user_input_text = text
state = _get_or_create_state_sync(user_id, db)
state = get_or_create_state_sync(user_id, db)
chapter, _, dispatch_ids = run_story_pipeline_for_category_batch(
db,
user_id=user_id,
@@ -560,14 +740,24 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
background_voice=background_voice,
occupation=user_occupation,
)
db.flush()
if chapter is None:
logger.error(
"event=generate_chapter_content_no_chapter user_id={} stage={}",
user_id,
stage,
)
db.rollback()
raise self.retry(
exc=RuntimeError("story_pipeline returned no chapter"),
countdown=30,
)
db.commit()
db.refresh(chapter)
from app.features.story.post_commit import enqueue_story_post_commit_effects
ch_ids: set[str] = set()
if chapter is not None:
ch_ids.add(str(chapter.id))
ch_ids: set[str] = {str(chapter.id)}
pc = enqueue_story_post_commit_effects(
user_id=user_id,
story_ids=set(dispatch_ids),
@@ -599,6 +789,8 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
try_enqueue_generate_chapter_cover(chapter.id, source="pipeline")
return {"status": "success"}
except Retry:
raise
except Exception as e:
logger.error(f"章节生成失败: {e}")
raise self.retry(exc=e) from e