""" 回忆录处理 Celery 任务 """ import json import uuid from datetime import datetime, timezone from typing import Dict, List, Set import redis from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import Session from app.agents.chat.prompts_profile import format_user_profile_context from app.agents.memoir import MemoirOrchestrator from app.agents.state_schema import MemoirStateSchema, SlotData, default_state from app.core.db import get_sync_db from app.core.dependencies import get_llm_provider from app.core.logging import get_logger from app.features.conversation.models import Segment from app.features.memoir.cover_eligibility import ( chapter_needs_cover_enqueue, ) from app.features.memoir.memoir_images.parser import ( build_initial_image_assets, ) from app.features.memoir.memoir_images.schema import ( IMAGE_STATUS_COMPLETED, IMAGE_STATUS_FAILED, IMAGE_STATUS_PENDING, normalize_image_assets, ) from app.features.memoir.memoir_images.serializers import ( image_dict_to_row_kwargs, ) from app.features.memoir.memoir_images.settings import MemoirImageSettings from app.features.memoir.models import ( Book, MemoirImage, MemoirState, ) from app.features.memoir.story_pipeline_sync import ( run_story_pipeline_for_category_batch, ) from app.features.user.models import User logger = get_logger(__name__) _REDIS_CLIENTS: dict[bool, redis.Redis] = {} def _get_llm(): """Celery 任务内获取 LangChain LLM(通过 port)""" try: return getattr(get_llm_provider(), "langchain_llm", None) except Exception: return None def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis: from app.core.config import settings client = _REDIS_CLIENTS.get(decode_responses) if client is None: client = redis.from_url( settings.redis_url, decode_responses=decode_responses, ) _REDIS_CLIENTS[decode_responses] = client return client def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool: """获取章节分布式锁,防止并发写入同一章节""" r = _get_redis_client() lock_key = f"lock:chapter:{user_id}:{stage}" return r.set(lock_key, "1", nx=True, ex=timeout) def _release_chapter_lock(user_id: str, stage: str): """释放章节分布式锁""" r = _get_redis_client() lock_key = f"lock:chapter:{user_id}:{stage}" r.delete(lock_key) def _update_task_status_sync( user_id: str, task_id: str, status: str, result: Dict = None ): """同步更新任务状态(在 Celery 任务中使用)""" try: r = _get_redis_client(decode_responses=True) key = f"task:user:{user_id}:tasks" # 获取现有任务信息 data = r.hget(key, task_id) if data: task_info = json.loads(data) else: task_info = {"task_id": task_id} task_info["status"] = status task_info["updated_at"] = datetime.now(timezone.utc).isoformat() if result is not None: task_info["result"] = result r.hset(key, task_id, json.dumps(task_info)) r.expire(key, 3600) # 1小时过期 logger.debug("任务状态已更新: task_id={} status={}", task_id, status) except Exception as e: logger.error(f"更新任务状态失败: {e}") def _merge_chapter_image_assets( existing_images: list[dict] | None, placeholders: list[dict], provider: str, style: str, size: str, now_iso: str, ) -> list[dict]: normalized_existing_images = normalize_image_assets(existing_images) existing_by_placeholder = { item.get("placeholder"): dict(item) for item in normalized_existing_images if item.get("placeholder") } merged_assets: list[dict] = [] for item in placeholders: existing = existing_by_placeholder.get(item["placeholder"]) if existing: merged_item = dict(existing) merged_item["index"] = item["index"] merged_item["placeholder"] = item["placeholder"] merged_item["description"] = item["description"] merged_item["provider"] = merged_item.get("provider") or provider merged_item["style"] = merged_item.get("style") or style merged_item["size"] = merged_item.get("size") or size merged_item["created_at"] = merged_item.get("created_at") or now_iso merged_item["updated_at"] = merged_item.get("updated_at") or now_iso if merged_item.get("status") == IMAGE_STATUS_COMPLETED and not ( merged_item.get("storage_key") or merged_item.get("url") ): merged_item["status"] = IMAGE_STATUS_FAILED merged_item["error"] = merged_item.get("error") or "missing image url" else: merged_item = build_initial_image_assets( placeholders=[item], provider=provider, style=style, size=size, now_iso=now_iso, )[0] merged_assets.append(merged_item) return merged_assets def chapter_has_images_to_generate(images: list[dict] | None) -> bool: return any( item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED} for item in normalize_image_assets(images) ) def _memoir_image_from_asset( chapter_id: str, order_index: int, image_asset: dict, ) -> MemoirImage: """从单条图片 dict 构建 MemoirImage 行(用于写入 memoir_images 表)。""" kwargs = image_dict_to_row_kwargs(image_asset) return MemoirImage( id=str(uuid.uuid4()).replace("-", "")[:32], chapter_id=chapter_id, order_index=order_index, **kwargs, ) def _coerce_state(model: MemoirState) -> MemoirStateSchema: """将数据库模型转换为 Schema""" return MemoirStateSchema.model_validate( { "stage_order": model.stage_order or default_state().stage_order, "current_stage": model.current_stage, "covered_stages": model.covered_stages or [], "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, } ) def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: """同步获取或创建状态""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if state: return _coerce_state(state) default = default_state() state = MemoirState( id=str(uuid.uuid4()), user_id=user_id, stage_order=default.stage_order, current_stage=default.current_stage, covered_stages=default.covered_stages, slots={ k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items() }, ) db.add(state) db.commit() db.refresh(state) return _coerce_state(state) def _update_slot_sync( user_id: str, stage: str, slot_name: str, snippet: str, segment_ids: List[str], db: Session, ) -> MemoirStateSchema: """同步更新 slot""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if not state: _get_or_create_state_sync(user_id, db) result = db.execute(stmt) state = result.scalar_one() slots: Dict[str, Dict] = state.slots or {} stage_slots = slots.get(stage, {}) existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) stage_slots[slot_name] = SlotData( snippet=snippet, segment_ids=merged_segment_ids ).model_dump() slots[stage] = stage_slots state.slots = slots state.current_stage = stage db.commit() db.refresh(state) return _coerce_state(state) @shared_task(bind=True, max_retries=3, default_retry_delay=60) def process_memoir_segments(self, user_id: str, segment_ids: List[str]): """ 处理回忆录段落的 Celery 任务 Args: user_id: 用户 ID segment_ids: 段落 ID 列表 """ task_id = self.request.id logger.info( f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}" ) # 更新任务状态为 running _update_task_status_sync(user_id, task_id, "running") try: with get_sync_db() as db: # 获取段落 stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) segments = result.scalars().all() if not segments: logger.warning(f"未找到段落: {segment_ids}") return {"status": "no_segments"} # Memory ingest 先于回忆录流水线 commit,保证后续 retrieve_evidence_sync 可见本批 chunk # (见 api/docs/memory-retrieval.md) conv_id = getattr(segments[0], "conversation_id", None) or "" transcript = "\n\n".join(seg.user_input_text or "" for seg in segments) if transcript.strip(): try: from app.features.memory.service import ingest_transcript_sync source_id = ingest_transcript_sync(db, user_id, conv_id, transcript) logger.info( "event=memory_transcript_ingested user_id={} task_id={} " "source_id={} conversation_id={} transcript_chars={} " "segment_count={}", user_id, task_id, source_id, conv_id, len(transcript), len(segments), ) except Exception as e: logger.warning( "Memory ingest 跳过: {} exc_type={}", e, type(e).__name__, ) llm = _get_llm() image_settings = MemoirImageSettings.from_env() user_obj = db.get(User, user_id) user_profile = "" user_birth_year = None if user_obj: user_birth_year = user_obj.birth_year user_profile = format_user_profile_context( birth_year=user_obj.birth_year, birth_place=user_obj.birth_place, grew_up_place=user_obj.grew_up_place, occupation=user_obj.occupation, ) story_dispatch_ids: Set[str] = set() memoir_orchestrator = MemoirOrchestrator() prepared = memoir_orchestrator.prepare_batches( segments=list(segments), llm=llm, get_or_create_state=lambda: _get_or_create_state_sync(user_id, db), update_slot=lambda stage, slot_name, snippet, seg_ids: ( _update_slot_sync(user_id, stage, slot_name, snippet, seg_ids, db) ), ) chapters_to_enqueue: Set[str] = set() for ( chapter_category, category_segments, ) in prepared.category_to_segments.items(): if not _acquire_chapter_lock(user_id, chapter_category): logger.warning( "章节锁竞争: category={}, 延迟重试", chapter_category, ) raise self.retry(countdown=10) try: chapter, needs_cover, disp = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=chapter_category, category_segments=category_segments, state=prepared.state, user_profile=user_profile, user_birth_year=user_birth_year, llm=llm, ) story_dispatch_ids |= disp db.flush() db.refresh(chapter) needs_cover_enqueue = ( image_settings.enabled and chapter_needs_cover_enqueue(chapter) ) stmt_book = ( select(Book) .where(Book.user_id == user_id) .order_by(Book.updated_at.desc()) ) result_book = db.execute(stmt_book) book = result_book.scalar_one_or_none() if not book: book = Book( id=str(uuid.uuid4()), user_id=user_id, title="我的回忆录", total_pages=0, total_words=0, cover_image_url=None, ) db.add(book) book.has_update = True book.last_update_chapter_id = chapter.id if chapter and needs_cover_enqueue: chapters_to_enqueue.add(chapter.id) finally: _release_chapter_lock(user_id, chapter_category) # 标记段落为已处理 for seg in segments: seg.processed = True db.commit() from app.tasks.chapter_compose_tasks import recompose_chapters_for_story from app.tasks.story_image_tasks import generate_story_image for sid in story_dispatch_ids: try: generate_story_image.delay(sid) except Exception as exc: logger.warning("generate_story_image delay: {}", exc) try: recompose_chapters_for_story.delay(sid) except Exception as exc: logger.warning("recompose_chapters_for_story delay: {}", exc) from app.tasks.chapter_cover_enqueue import ( try_enqueue_generate_chapter_cover, ) for chapter_id in sorted(chapters_to_enqueue): if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"): logger.info(f"派发章节封面任务: chapter={chapter_id}") categories_processed = sorted(prepared.category_to_segments.keys()) logger.info( "回忆录处理完成: user_id={} task_id={} segment_count={} " "categories_processed={}", user_id, task_id, len(segments), categories_processed, ) # 更新任务状态为成功 _update_task_status_sync( user_id, task_id, "success", { "processed": len(segments), "categories_processed": categories_processed, }, ) return { "status": "success", "processed": len(segments), "categories_processed": categories_processed, } except Exception as e: logger.error(f"回忆录处理失败: {e}") # 更新任务状态为失败 _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) # 重试 raise self.retry(exc=e) from e @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) Args: user_id: 用户 ID stage: 阶段 new_content: 新内容 """ logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") try: with get_sync_db() as db: llm = _get_llm() user_obj = db.get(User, user_id) user_profile = "" user_birth_year = None if user_obj: user_birth_year = user_obj.birth_year user_profile = format_user_profile_context( birth_year=user_obj.birth_year, birth_place=user_obj.birth_place, grew_up_place=user_obj.grew_up_place, occupation=user_obj.occupation, ) class _Seg: def __init__(self, text: str): self.id = str(uuid.uuid4()) self.user_input_text = text state = _get_or_create_state_sync(user_id, db) chapter, _, dispatch_ids = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=stage, category_segments=[_Seg(new_content)], state=state, user_profile=user_profile, user_birth_year=user_birth_year, llm=llm, ) db.commit() db.refresh(chapter) from app.tasks.chapter_compose_tasks import recompose_chapters_for_story from app.tasks.story_image_tasks import generate_story_image for sid in dispatch_ids: try: generate_story_image.delay(sid) except Exception as exc: logger.warning("generate_story_image delay: {}", exc) try: recompose_chapters_for_story.delay(sid) except Exception as exc: logger.warning("recompose_chapters_for_story delay: {}", exc) image_settings = MemoirImageSettings.from_env() if ( image_settings.enabled and chapter and chapter_needs_cover_enqueue(chapter) ): from app.tasks.chapter_cover_enqueue import ( try_enqueue_generate_chapter_cover, ) try_enqueue_generate_chapter_cover(chapter.id, source="pipeline") return {"status": "success"} except Exception as e: logger.error(f"章节生成失败: {e}") raise self.retry(exc=e) from e