""" 回忆录处理 Celery 任务 """ import json import uuid from datetime import datetime, timezone from typing import Dict, List, Set import redis from celery import shared_task from celery.exceptions import Retry from celery.result import AsyncResult from sqlalchemy import func, select from sqlalchemy.orm import Session from app.agents.chat.background_voice import infer_background_voice from app.agents.chat.prompts_profile import format_user_profile_context from app.agents.memoir import MemoirOrchestrator from app.agents.stage_constants import normalize_chapter_category from app.core.chapter_pipeline_lock import ( acquire_chapter_pipeline_lock as _acquire_chapter_lock, ) from app.core.chapter_pipeline_lock import ( release_chapter_pipeline_lock as _release_chapter_lock, ) from app.core.config import settings from app.core.db import get_sync_db from app.core.dependencies import get_llm_provider, get_llm_provider_fast from app.core.logging import get_logger from app.core.memoir_pipeline_trace import ( effective_correlation_id, new_memoir_correlation_id, ) from app.features.conversation.models import Conversation, Segment from app.features.memoir.cover_eligibility import ( chapter_needs_cover_enqueue, ) from app.features.memoir.memoir_images.parser import ( build_initial_image_assets, ) from app.features.memoir.memoir_images.schema import ( IMAGE_STATUS_COMPLETED, IMAGE_STATUS_FAILED, IMAGE_STATUS_PENDING, normalize_image_assets, ) from app.features.memoir.memoir_images.serializers import ( image_dict_to_row_kwargs, ) from app.features.memoir.memoir_images.settings import MemoirImageSettings from app.features.memoir.models import ( Book, MemoirImage, ) from app.features.memoir.state_service import ( get_or_create_state_sync, update_slot_sync, ) from app.features.memoir.story_pipeline_sync import ( run_story_pipeline_for_category_batch, ) from app.features.user.models import User from app.tasks.celery_app import celery_app logger = get_logger(__name__) _REDIS_CLIENTS: dict[bool, redis.Redis] = {} def _get_llm(): """Celery 任务内获取 LangChain LLM(通过 port)""" try: return getattr(get_llm_provider(), "langchain_llm", None) except Exception: return None def _get_llm_fast(): """分类 / 抽取等快档位任务(与叙事、路由默认模型可分离)。""" try: return getattr(get_llm_provider_fast(), "langchain_llm", None) except Exception: return None def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis: from app.core.config import settings client = _REDIS_CLIENTS.get(decode_responses) if client is None: client = redis.from_url( settings.redis_url, decode_responses=decode_responses, ) _REDIS_CLIENTS[decode_responses] = client return client def _chapter_lock_ttl() -> int: return int(settings.chapter_pipeline_lock_ttl_seconds) def _update_task_status_sync( user_id: str, task_id: str, status: str, result: Dict = None ): """同步更新任务状态(在 Celery 任务中使用)""" try: r = _get_redis_client(decode_responses=True) key = f"task:user:{user_id}:tasks" # 获取现有任务信息 data = r.hget(key, task_id) if data: task_info = json.loads(data) else: task_info = {"task_id": task_id} task_info["status"] = status task_info["updated_at"] = datetime.now(timezone.utc).isoformat() if result is not None: task_info["result"] = result r.hset(key, task_id, json.dumps(task_info)) r.expire(key, 3600) # 1小时过期 logger.debug("任务状态已更新: task_id={} status={}", task_id, status) except Exception as e: logger.error(f"更新任务状态失败: {e}") def _merge_chapter_image_assets( existing_images: list[dict] | None, placeholders: list[dict], provider: str, style: str, size: str, now_iso: str, ) -> list[dict]: normalized_existing_images = normalize_image_assets(existing_images) existing_by_placeholder = { item.get("placeholder"): dict(item) for item in normalized_existing_images if item.get("placeholder") } merged_assets: list[dict] = [] for item in placeholders: existing = existing_by_placeholder.get(item["placeholder"]) if existing: merged_item = dict(existing) merged_item["index"] = item["index"] merged_item["placeholder"] = item["placeholder"] merged_item["description"] = item["description"] merged_item["provider"] = merged_item.get("provider") or provider merged_item["style"] = merged_item.get("style") or style merged_item["size"] = merged_item.get("size") or size merged_item["created_at"] = merged_item.get("created_at") or now_iso merged_item["updated_at"] = merged_item.get("updated_at") or now_iso if merged_item.get("status") == IMAGE_STATUS_COMPLETED and not ( merged_item.get("storage_key") or merged_item.get("url") ): merged_item["status"] = IMAGE_STATUS_FAILED merged_item["error"] = merged_item.get("error") or "missing image url" else: merged_item = build_initial_image_assets( placeholders=[item], provider=provider, style=style, size=size, now_iso=now_iso, )[0] merged_assets.append(merged_item) return merged_assets def chapter_has_images_to_generate(images: list[dict] | None) -> bool: return any( item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED} for item in normalize_image_assets(images) ) def _memoir_image_from_asset( chapter_id: str, order_index: int, image_asset: dict, ) -> MemoirImage: """从单条图片 dict 构建 MemoirImage 行(用于写入 memoir_images 表)。""" kwargs = image_dict_to_row_kwargs(image_asset) return MemoirImage( id=str(uuid.uuid4()).replace("-", "")[:32], chapter_id=chapter_id, order_index=order_index, **kwargs, ) def _phase2_timeout_task_id(user_id: str, chapter_category: str) -> str: return f"phase2-timeout-{user_id}-{chapter_category}" def _revoke_phase2_timeout(user_id: str, chapter_category: str) -> None: tid = _phase2_timeout_task_id(user_id, chapter_category) try: AsyncResult(tid, app=celery_app).revoke(terminate=False) except Exception as e: logger.debug( "event=phase2_timeout_revoke_skipped task_id={} exc={}", tid, e, ) def _should_trigger_phase2( db: Session, user_id: str, chapter_category: str, current_segment_chars: int, ) -> bool: if current_segment_chars >= int(settings.memoir_narrative_immediate_char_threshold): return True user_convs = select(Conversation.id).where( Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) stmt = select( func.count(Segment.id), func.coalesce(func.sum(func.length(Segment.user_input_text)), 0), ).where( Segment.conversation_id.in_(user_convs), Segment.topic_category == chapter_category, Segment.narrated.is_(False), Segment.skip_narrative.is_(False), ) row = db.execute(stmt).one() count, total_chars = int(row[0] or 0), int(row[1] or 0) if count >= int(settings.memoir_narrative_batch_min_segments): return True if total_chars >= int(settings.memoir_narrative_batch_min_chars): return True return False def _phase2_immediate_task_id(user_id: str, chapter_category: str) -> str: return f"phase2-immediate-{user_id}-{chapter_category}" def _schedule_phase2_timeout( user_id: str, chapter_category: str, memoir_correlation_id: str | None = None ) -> None: """Reset countdown for Phase 2 narrative for one category.""" _revoke_phase2_timeout(user_id, chapter_category) countdown = float(max(1.0, settings.memoir_narrative_batch_max_wait_seconds)) p2_kwargs: dict = {} if memoir_correlation_id: p2_kwargs["memoir_correlation_id"] = memoir_correlation_id celery_app.send_task( "app.tasks.memoir_tasks.process_memoir_phase2", args=[user_id, chapter_category], kwargs=p2_kwargs, countdown=countdown, task_id=_phase2_timeout_task_id(user_id, chapter_category), ) logger.info( "event=phase2_timeout_scheduled user_id={} chapter_category={} countdown={} " "memoir_correlation_id={}", user_id, chapter_category, countdown, memoir_correlation_id or "", ) def _dispatch_phase2_immediate( user_id: str, chapter_category: str, memoir_correlation_id: str | None = None ) -> None: _revoke_phase2_timeout(user_id, chapter_category) p2_kwargs: dict = {} if memoir_correlation_id: p2_kwargs["memoir_correlation_id"] = memoir_correlation_id send_kw: dict = { "args": [user_id, chapter_category], "kwargs": p2_kwargs, } if settings.memoir_phase2_singleflight_immediate: send_kw["task_id"] = _phase2_immediate_task_id(user_id, chapter_category) celery_app.send_task("app.tasks.memoir_tasks.process_memoir_phase2", **send_kw) logger.info( "event=phase2_dispatched_immediate user_id={} chapter_category={} " "memoir_correlation_id={} task_id_mode={}", user_id, chapter_category, memoir_correlation_id or "", "singleflight" if settings.memoir_phase2_singleflight_immediate else "unique", ) def dispatch_pending_memoir_phase2_for_user(user_id: str) -> None: """会话结束等场景:为该用户所有待叙事类目各发一条 Phase2(幂等)。""" try: with get_sync_db() as db: user_convs = select(Conversation.id).where( Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) stmt = ( select(Segment.topic_category) .where( Segment.conversation_id.in_(user_convs), Segment.narrated.is_(False), Segment.skip_narrative.is_(False), Segment.topic_category.isnot(None), ) .distinct() ) cats = [r[0] for r in db.execute(stmt).all() if r[0]] for chapter_category in cats: _revoke_phase2_timeout(user_id, chapter_category) flush_cid = new_memoir_correlation_id() celery_app.send_task( "app.tasks.memoir_tasks.process_memoir_phase2", args=[user_id, chapter_category], kwargs={"memoir_correlation_id": flush_cid}, ) logger.info( "event=phase2_dispatched_flush user_id={} chapter_category={} " "memoir_correlation_id={}", user_id, chapter_category, flush_cid, ) except Exception as e: logger.error( "event=phase2_flush_failed user_id={} exc_type={} exc={}", user_id, type(e).__name__, e, ) @shared_task(bind=True, max_retries=3, default_retry_delay=30) def process_memoir_phase2( self, user_id: str, chapter_category: str, memoir_correlation_id: str | None = None, ): """Phase 2:叙事 / 路由 / 忠实度 / 标题;按类目加锁,消费未叙事且非 skip 的 segments。""" task_id = self.request.id cid = effective_correlation_id( explicit=memoir_correlation_id, celery_task_id=str(task_id) ) logger.info( "event=memoir_phase2_start user_id={} task_id={} chapter_category={} " "memoir_correlation_id={}", user_id, task_id, chapter_category, cid, ) try: with get_sync_db() as db: user_convs = select(Conversation.id).where( Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) stmt = ( select(Segment) .where( Segment.conversation_id.in_(user_convs), Segment.topic_category == chapter_category, Segment.narrated.is_(False), Segment.skip_narrative.is_(False), ) .order_by(Segment.created_at) ) category_segments = list(db.execute(stmt).scalars().all()) if not category_segments: logger.info( "event=memoir_phase2_noop user_id={} chapter_category={}", user_id, chapter_category, ) return {"status": "noop"} llm = _get_llm() user_obj = db.get(User, user_id) user_profile = "" user_birth_year = None background_voice = "default" user_occupation = "" if user_obj: user_birth_year = user_obj.birth_year user_profile = format_user_profile_context( birth_year=user_obj.birth_year, birth_place=user_obj.birth_place, grew_up_place=user_obj.grew_up_place, occupation=user_obj.occupation, ) background_voice = infer_background_voice(user_obj.occupation) user_occupation = user_obj.occupation or "" image_settings = MemoirImageSettings.from_env() story_dispatch_ids: Set[str] = set() chapters_to_enqueue: Set[str] = set() affected_chapter_ids: Set[str] = set() lock_handle = _acquire_chapter_lock( user_id, chapter_category, ttl_seconds=_chapter_lock_ttl() ) if lock_handle is None: logger.warning( "event=memoir_phase2_lock_busy user_id={} chapter_category={}", user_id, chapter_category, ) raise self.retry(countdown=10) try: # 锁内再查一次,避免等待锁期间状态已变 category_segments = list(db.execute(stmt).scalars().all()) if not category_segments: return {"status": "noop"} state = get_or_create_state_sync(user_id, db) chapter, needs_cover, disp = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=chapter_category, category_segments=category_segments, state=state, user_profile=user_profile, user_birth_year=user_birth_year, llm=llm, background_voice=background_voice, occupation=user_occupation, memoir_correlation_id=cid, ) story_dispatch_ids |= disp db.flush() if chapter is None: logger.error( "event=memoir_phase2_no_chapter user_id={} chapter_category={}", user_id, chapter_category, ) db.rollback() raise self.retry( exc=RuntimeError("story_pipeline returned no chapter"), countdown=30, ) db.refresh(chapter) affected_chapter_ids.add(chapter.id) needs_cover_enqueue = ( image_settings.enabled and chapter_needs_cover_enqueue(chapter) ) stmt_book = ( select(Book) .where(Book.user_id == user_id) .order_by(Book.updated_at.desc()) ) result_book = db.execute(stmt_book) book = result_book.scalar_one_or_none() if not book: book = Book( id=str(uuid.uuid4()), user_id=user_id, title="我的回忆录", total_pages=0, total_words=0, cover_image_url=None, ) db.add(book) book.has_update = True book.last_update_chapter_id = chapter.id if needs_cover_enqueue: chapters_to_enqueue.add(chapter.id) for seg in category_segments: seg.narrated = True seg.processed = True db.commit() from app.features.story.post_commit import ( enqueue_story_post_commit_effects, ) pc = enqueue_story_post_commit_effects( user_id=user_id, story_ids=set(story_dispatch_ids), chapter_ids=affected_chapter_ids, trigger_source="pipeline_phase2", need_compaction=True, compaction_extra={ "pipeline_run_id": str(task_id), "memoir_correlation_id": cid, "story_dispatch_ids": sorted(story_dispatch_ids), "chapters_to_enqueue": sorted(chapters_to_enqueue), "chapter_category": chapter_category, }, ) logger.info( "event=story_post_commit user_id={} trigger=pipeline_phase2 " "enqueued_story_image_count={} enqueued_chapter_recompose_count={} " "compaction_scheduled={} errors={}", user_id, pc.enqueued_story_image_count, pc.enqueued_chapter_recompose_count, pc.compaction_scheduled, pc.errors, ) from app.tasks.chapter_cover_enqueue import ( try_enqueue_generate_chapter_cover, ) for chapter_id in sorted(chapters_to_enqueue): if try_enqueue_generate_chapter_cover( chapter_id, source="pipeline_phase2" ): logger.info(f"派发章节封面任务: chapter={chapter_id}") logger.info( "event=memoir_phase2_done user_id={} task_id={} chapter_category={} " "segment_count={} memoir_correlation_id={}", user_id, task_id, chapter_category, len(category_segments), cid, ) return { "status": "success", "chapter_category": chapter_category, "segments": len(category_segments), } finally: _release_chapter_lock(lock_handle) except Retry: raise except Exception as e: logger.error( "event=memoir_phase2_failed user_id={} chapter_category={} exc={}", user_id, chapter_category, e, ) raise self.retry(exc=e) from e @shared_task(bind=True, max_retries=3, default_retry_delay=60) def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): """ Phase 1:记忆 ingest + 抽取/分类;持久化 topic_category / skip_narrative; 按需派发 Phase 2(阈值或延迟兜底)。 """ task_id = self.request.id memoir_correlation_id = new_memoir_correlation_id() logger.info( "event=memoir_phase1_start user_id={} task_id={} segments={} " "memoir_correlation_id={}", user_id, task_id, len(segment_ids), memoir_correlation_id, ) _update_task_status_sync(user_id, task_id, "running") try: with get_sync_db() as db: stmt = ( select(Segment) .where(Segment.id.in_(segment_ids)) .order_by(Segment.created_at.asc(), Segment.id.asc()) ) rows = db.execute(stmt).scalars().all() segments = [s for s in rows if not s.narrated] if not segments: logger.warning("event=memoir_phase1_no_segments ids={}", segment_ids) _update_task_status_sync( user_id, task_id, "success", {"processed": 0, "categories": []}, ) return {"status": "no_segments"} for seg in segments: conv_id = getattr(seg, "conversation_id", None) or "" text = (seg.user_input_text or "").strip() if not text: continue try: from app.features.memory.service import ingest_transcript_sync ln = getattr(seg, "lineage_json", None) lineage_payload = ln if isinstance(ln, dict) else None source_id = ingest_transcript_sync( db, user_id, conv_id, text, lineage_json=lineage_payload, ) logger.info( "event=memory_transcript_ingested user_id={} task_id={} " "source_id={} conversation_id={} segment_id={} transcript_chars={}", user_id, task_id, source_id, conv_id, seg.id, len(text), ) except Exception as e: logger.warning( "Memory ingest 跳过 segment_id={}: {} exc_type={}", getattr(seg, "id", ""), e, type(e).__name__, ) llm = _get_llm() llm_fast = _get_llm_fast() if (settings.llm_fast_model or "").strip(): logger.info( "event=llm_fast_tier_used pipeline=memoir_prepare_batches model={}", settings.llm_fast_model, ) memoir_orchestrator = MemoirOrchestrator() prepared = memoir_orchestrator.prepare_batches( segments=list(segments), llm=llm, llm_fast=llm_fast, get_or_create_state=lambda: get_or_create_state_sync(user_id, db), update_slot=lambda stage, slot_name, snippet, seg_ids: update_slot_sync( user_id, stage, slot_name, snippet, seg_ids, db, memoir_batch=True, ), ) skip_ids = prepared.segment_skip_story_ids missing_cat = [ seg.id for seg in segments if not prepared.segment_chapter_category.get(str(seg.id)) ] if missing_cat: logger.error( "event=memoir_phase1_missing_category abort segment_ids={}", missing_cat, ) raise RuntimeError( f"memoir_phase1_missing_category: {len(missing_cat)} segments" ) for seg in segments: cat = prepared.segment_chapter_category[str(seg.id)] seg.topic_category = cat is_skip = str(seg.id) in skip_ids seg.skip_narrative = is_skip seg.narrated = False if is_skip: seg.processed = True db.flush() categories_for_phase2: Set[str] = set() phase2_immediate: list[str] = [] phase2_timeout: list[str] = [] for chapter_category, cat_segments in prepared.category_to_segments.items(): batch_non_skip = [ s for s in cat_segments if str(s.id) not in prepared.segment_skip_story_ids ] if not batch_non_skip: continue max_chars = max( len((s.user_input_text or "").strip()) for s in batch_non_skip ) categories_for_phase2.add(chapter_category) if _should_trigger_phase2(db, user_id, chapter_category, max_chars): phase2_immediate.append(chapter_category) else: phase2_timeout.append(chapter_category) db.commit() for cc in phase2_immediate: _dispatch_phase2_immediate(user_id, cc, memoir_correlation_id) for cc in phase2_timeout: _schedule_phase2_timeout(user_id, cc, memoir_correlation_id) categories_processed = sorted(prepared.category_to_segments.keys()) _update_task_status_sync( user_id, task_id, "success", { "processed": len(segments), "categories_processed": categories_processed, "phase2_watch_categories": sorted(categories_for_phase2), }, ) logger.info( "event=memoir_phase1_done user_id={} task_id={} segment_count={} " "categories={} memoir_correlation_id={}", user_id, task_id, len(segments), categories_processed, memoir_correlation_id, ) return { "status": "success", "processed": len(segments), "categories_processed": categories_processed, } except Retry: raise except Exception as e: logger.error("event=memoir_phase1_failed user_id={} exc={}", user_id, e) _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) raise self.retry(exc=e) from e # 兼容旧 Celery/文档入口名 process_memoir_segments = process_memoir_phase1 @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) Args: user_id: 用户 ID stage: 阶段 new_content: 新内容 """ stage = normalize_chapter_category(stage, fallback="summary") cid = effective_correlation_id(explicit=None, celery_task_id=str(self.request.id)) logger.info( "event=generate_chapter_content_start user_id={} stage={} memoir_correlation_id={}", user_id, stage, cid, ) try: with get_sync_db() as db: llm = _get_llm() user_obj = db.get(User, user_id) user_profile = "" user_birth_year = None background_voice = "default" user_occupation = "" if user_obj: user_birth_year = user_obj.birth_year user_profile = format_user_profile_context( birth_year=user_obj.birth_year, birth_place=user_obj.birth_place, grew_up_place=user_obj.grew_up_place, occupation=user_obj.occupation, ) background_voice = infer_background_voice(user_obj.occupation) user_occupation = user_obj.occupation or "" class _Seg: def __init__(self, text: str): self.id = str(uuid.uuid4()) self.user_input_text = text state = get_or_create_state_sync(user_id, db) chapter, _, dispatch_ids = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=stage, category_segments=[_Seg(new_content)], state=state, user_profile=user_profile, user_birth_year=user_birth_year, llm=llm, background_voice=background_voice, occupation=user_occupation, memoir_correlation_id=cid, ) db.flush() if chapter is None: logger.error( "event=generate_chapter_content_no_chapter user_id={} stage={}", user_id, stage, ) db.rollback() raise self.retry( exc=RuntimeError("story_pipeline returned no chapter"), countdown=30, ) db.commit() db.refresh(chapter) from app.features.story.post_commit import enqueue_story_post_commit_effects ch_ids: set[str] = {str(chapter.id)} pc = enqueue_story_post_commit_effects( user_id=user_id, story_ids=set(dispatch_ids), chapter_ids=ch_ids, trigger_source="pipeline", need_compaction=False, ) logger.info( "event=story_post_commit user_id={} trigger=pipeline_generate_chapter " "enqueued_story_image_count={} enqueued_chapter_recompose_count={} " "compaction_scheduled={} errors={}", user_id, pc.enqueued_story_image_count, pc.enqueued_chapter_recompose_count, pc.compaction_scheduled, pc.errors, ) image_settings = MemoirImageSettings.from_env() if ( image_settings.enabled and chapter and chapter_needs_cover_enqueue(chapter) ): from app.tasks.chapter_cover_enqueue import ( try_enqueue_generate_chapter_cover, ) try_enqueue_generate_chapter_cover(chapter.id, source="pipeline") return {"status": "success"} except Retry: raise except Exception as e: logger.error(f"章节生成失败: {e}") raise self.retry(exc=e) from e