""" 回忆录处理 Celery 任务 """ import json import logging import os import uuid from typing import Dict, List from datetime import datetime, timezone import redis from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import Session from database.database import SessionLocal from database.models import Book, Chapter, Segment, MemoirState, User from services.llm_service import llm_service from agents.state_schema import MemoirStateSchema, SlotData, default_state from agents.prompts.memory_prompts import ( get_creative_title_prompt, get_narrative_prompt, get_state_extraction_prompt, get_chapter_classification_prompt, STAGE_TO_ORDER, CHAPTER_CATEGORIES, ) from agents.prompts.profile_prompts import format_user_profile_context import hashlib from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders from services.memoir_images.prompting import MemoirImagePromptService from services.memoir_images.provider import LiblibImageProvider from services.memoir_images.settings import MemoirImageSettings from services.memoir_images.storage import TencentCosStorageService logger = logging.getLogger(__name__) def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool: """获取章节分布式锁,防止并发写入同一章节""" r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0")) lock_key = f"lock:chapter:{user_id}:{stage}" return r.set(lock_key, "1", nx=True, ex=timeout) def _release_chapter_lock(user_id: str, stage: str): """释放章节分布式锁""" r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0")) lock_key = f"lock:chapter:{user_id}:{stage}" r.delete(lock_key) def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None): """同步更新任务状态(在 Celery 任务中使用)""" try: redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") r = redis.from_url(redis_url, decode_responses=True) key = f"task:user:{user_id}:tasks" # 获取现有任务信息 data = r.hget(key, task_id) if data: task_info = json.loads(data) else: task_info = {"task_id": task_id} task_info["status"] = status task_info["updated_at"] = datetime.now(timezone.utc).isoformat() if result is not None: task_info["result"] = result r.hset(key, task_id, json.dumps(task_info)) r.expire(key, 3600) # 1小时过期 logger.info(f"任务状态已更新: task_id={task_id}, status={status}") except Exception as e: logger.error(f"更新任务状态失败: {e}") def _merge_chapter_image_assets( existing_images: list[dict] | None, placeholders: list[dict], provider: str, style: str, size: str, now_iso: str, ) -> list[dict]: existing_by_placeholder = { item.get("placeholder"): dict(item) for item in (existing_images or []) if item.get("placeholder") } merged_assets: list[dict] = [] for item in placeholders: existing = existing_by_placeholder.get(item["placeholder"]) if existing: merged_item = dict(existing) merged_item["index"] = item["index"] merged_item["placeholder"] = item["placeholder"] merged_item["description"] = item["description"] merged_item["provider"] = merged_item.get("provider") or provider merged_item["style"] = merged_item.get("style") or style merged_item["size"] = merged_item.get("size") or size merged_item["created_at"] = merged_item.get("created_at") or now_iso merged_item["updated_at"] = merged_item.get("updated_at") or now_iso if merged_item.get("status") == "completed" and not ( merged_item.get("storage_key") or merged_item.get("url") ): merged_item["status"] = "failed" merged_item["error"] = merged_item.get("error") or "missing image url" else: merged_item = build_initial_image_assets( placeholders=[item], provider=provider, style=style, size=size, now_iso=now_iso, )[0] merged_assets.append(merged_item) return merged_assets def chapter_has_images_to_generate(images: list[dict] | None) -> bool: return any(item.get("status") in {"pending", "failed"} for item in (images or [])) def initialize_chapter_images(chapter) -> list[dict]: """Parse IMAGE placeholders from chapter content and build pending image assets.""" settings = MemoirImageSettings.from_env() if not settings.enabled: chapter.images = [] logger.info(f"章节图片初始化跳过: chapter={chapter.id}, enabled=false") return chapter.images prompt_service = MemoirImagePromptService(llm=None, settings=settings) placeholders = parse_image_placeholders(chapter.content, settings.max_per_chapter) style = prompt_service.CATEGORY_STYLE_MAP.get(chapter.category, settings.default_style) chapter.images = _merge_chapter_image_assets( existing_images=chapter.images, placeholders=placeholders, provider=settings.provider, style=style, size=settings.default_size, now_iso=datetime.now(timezone.utc).isoformat(), ) logger.info( "章节图片初始化完成: chapter=%s, placeholders=%d, images=%d, statuses=%s", chapter.id, len(placeholders), len(chapter.images or []), [item.get("status") for item in (chapter.images or [])], ) return chapter.images STAGE_KEYWORDS = { "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], "education": ["上学", "学校", "老师", "同学", "教育", "大学"], "career": ["工作", "职业", "事业", "公司", "同事", "创业"], "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], } # 5-stage → 默认 8-category 映射(LLM 分类失败时的兜底) _STAGE_TO_DEFAULT_CATEGORY = { "childhood": "childhood", "education": "education", "career": "career_early", "family": "family", "belief": "beliefs", } def _detect_stage(user_message: str, fallback_stage: str) -> str: """检测消息所属的 5-stage 阶段(用于状态跟踪)""" message = user_message.lower() for stage, keywords in STAGE_KEYWORDS.items(): if any(word in message for word in keywords): return stage return fallback_stage def _classify_chapter_category(text: str, fallback_stage: str, llm=None) -> str | None: """ 将内容分类到 8 个章节类别之一。 优先使用 LLM,失败则按 5-stage 关键词映射到默认类别。 如果 LLM 判定内容无实质回忆录价值,返回 None。 """ if llm: try: prompt = get_chapter_classification_prompt(text) response = llm.invoke(prompt) category = response.content.strip().lower() if category == "none": logger.info(f"LLM 判定内容无回忆录价值,跳过: {text[:80]}...") return None if category in CHAPTER_CATEGORIES: return category except Exception as e: logger.warning(f"LLM 章节分类失败: {e}") stage = _detect_stage(text, fallback_stage) return _STAGE_TO_DEFAULT_CATEGORY.get(stage, _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood")) def _coerce_state(model: MemoirState) -> MemoirStateSchema: """将数据库模型转换为 Schema""" return MemoirStateSchema.model_validate( { "stage_order": model.stage_order or default_state().stage_order, "current_stage": model.current_stage, "covered_stages": model.covered_stages or [], "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, } ) def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: """同步获取或创建状态""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if state: return _coerce_state(state) default = default_state() state = MemoirState( id=str(uuid.uuid4()), user_id=user_id, stage_order=default.stage_order, current_stage=default.current_stage, covered_stages=default.covered_stages, slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()}, ) db.add(state) db.commit() db.refresh(state) return _coerce_state(state) def _update_slot_sync( user_id: str, stage: str, slot_name: str, snippet: str, segment_ids: List[str], db: Session, ) -> MemoirStateSchema: """同步更新 slot""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if not state: _get_or_create_state_sync(user_id, db) result = db.execute(stmt) state = result.scalar_one() slots: Dict[str, Dict] = state.slots or {} stage_slots = slots.get(stage, {}) existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump() slots[stage] = stage_slots state.slots = slots state.current_stage = state.current_stage or stage db.commit() db.refresh(state) return _coerce_state(state) @shared_task(bind=True, max_retries=3, default_retry_delay=60) def process_memoir_segments(self, user_id: str, segment_ids: List[str]): """ 处理回忆录段落的 Celery 任务 Args: user_id: 用户 ID segment_ids: 段落 ID 列表 """ task_id = self.request.id logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}") # 更新任务状态为 running _update_task_status_sync(user_id, task_id, "running") try: db = SessionLocal() try: chapters_to_enqueue: set[str] = set() # 获取段落 stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) segments = result.scalars().all() if not segments: logger.warning(f"未找到段落: {segment_ids}") return {"status": "no_segments"} # 获取用户状态和资料 state = _get_or_create_state_sync(user_id, db) llm = llm_service.get_llm() user_obj = db.get(User, user_id) user_profile = "" user_birth_year = None if user_obj: user_birth_year = user_obj.birth_year user_profile = format_user_profile_context( birth_year=user_obj.birth_year, birth_place=user_obj.birth_place, grew_up_place=user_obj.grew_up_place, occupation=user_obj.occupation, ) # 分两步处理: # 1) 5-stage 状态跟踪(slots) # 2) 8-category 章节分类(chapter creation) category_to_segments: Dict[str, List[Segment]] = {} for segment in segments: text = segment.transcript_text detected_stage = _detect_stage(text, state.current_stage) # 提取 slots(5-stage 状态跟踪) extracted_slots = {} if llm: try: prompt = get_state_extraction_prompt( user_message=text, current_stage=state.current_stage, stage_slots=state.slots.get(detected_stage, {}), ) response = llm.invoke(prompt) content = response.content.strip() parsed = json.loads(content) detected_stage = parsed.get("detected_stage", detected_stage) extracted_slots = parsed.get("slots", {}) or {} except (json.JSONDecodeError, Exception) as e: logger.warning(f"LLM 解析失败: {e}") for slot_name, snippet in extracted_slots.items(): state = _update_slot_sync( user_id=user_id, stage=detected_stage, slot_name=slot_name, snippet=snippet, segment_ids=[segment.id], db=db, ) # 8-category 章节分类 chapter_category = _classify_chapter_category(text, detected_stage, llm) if chapter_category is None: logger.info(f"段落无回忆录价值,跳过: segment_id={segment.id}") continue category_to_segments.setdefault(chapter_category, []).append(segment) # 按 8 分类生成章节内容 for stage, stage_segments in category_to_segments.items(): if not _acquire_chapter_lock(user_id, stage): logger.warning(f"章节锁竞争: user={user_id}, stage={stage}, 延迟重试") raise self.retry(countdown=10) try: segment_texts = [seg.transcript_text for seg in stage_segments] combined_text = "\n\n".join(segment_texts) source_ids = [seg.id for seg in stage_segments] # 查找 active 章节(被清除的章节不继续更新,而是创建新的) stmt_chapter = select(Chapter).where( Chapter.user_id == user_id, Chapter.category == stage, Chapter.is_active == True, ) result_chapter = db.execute(stmt_chapter) chapter = result_chapter.scalar_one_or_none() # 获取 slot snippets slot_snippets = { key: value.snippet for key, value in (state.slots.get(stage, {}) or {}).items() if value.snippet } # 生成标题和内容 title = chapter.title if chapter else f"{stage} 回忆" existing_content = chapter.content if chapter else "" narrative = combined_text if llm: try: if not chapter: title_prompt = get_creative_title_prompt( stage=stage, emotion="neutral", slots=slot_snippets, user_profile=user_profile, birth_year=user_birth_year, ) title_response = llm.invoke(title_prompt) title = title_response.content.strip().strip('"') narrative_prompt = get_narrative_prompt( stage=stage, slots=slot_snippets, new_content=combined_text, existing_content=existing_content, user_profile=user_profile, birth_year=user_birth_year, ) narrative_response = llm.invoke(narrative_prompt) new_narrative = narrative_response.content.strip() # 追加而非替换 if existing_content: narrative = f"{existing_content}\n\n{new_narrative}" else: narrative = new_narrative except Exception as e: logger.warning(f"LLM 生成失败: {e}") if existing_content: narrative = f"{existing_content}\n\n{combined_text}" # 安全检查:新内容不应比旧内容短 if existing_content and len(narrative) < len(existing_content) * 0.8: logger.warning( f"内容长度异常: existing={len(existing_content)}, " f"new={len(narrative)}, stage={stage}. 回退为追加模式" ) narrative = f"{existing_content}\n\n{combined_text}" # 更新或创建章节 if chapter: chapter.content = narrative chapter.title = title chapter.is_new = True chapter.source_segments = list({*(chapter.source_segments or []), *source_ids}) else: # 根据 stage 计算正确的排序索引 calculated_order_index = STAGE_TO_ORDER.get(stage, 999) chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, title=title, content=narrative, order_index=calculated_order_index, status="completed", category=stage, images=[], is_new=True, source_segments=source_ids, ) db.add(chapter) db.flush() initialize_chapter_images(chapter) if chapter_has_images_to_generate(chapter.images): chapters_to_enqueue.add(chapter.id) # 更新 Book stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) result_book = db.execute(stmt_book) book = result_book.scalar_one_or_none() if not book: book = Book( id=str(uuid.uuid4()), user_id=user_id, title="我的回忆录", total_pages=0, total_words=0, cover_image_url=None, ) db.add(book) book.has_update = True book.last_update_chapter_id = chapter.id finally: _release_chapter_lock(user_id, stage) # 标记段落为已处理 for seg in segments: seg.processed = True db.commit() for chapter_id in sorted(chapters_to_enqueue): try: logger.info(f"派发章节补图任务: chapter={chapter_id}") generate_chapter_images.delay(chapter_id) except Exception as exc: logger.warning(f"补图任务派发失败: chapter={chapter_id}, error={exc}") logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") # 更新任务状态为成功 _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)}) return {"status": "success", "processed": len(segments)} finally: db.close() except Exception as e: logger.error(f"回忆录处理失败: {e}") # 更新任务状态为失败 _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) # 重试 raise self.retry(exc=e) @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) Args: user_id: 用户 ID stage: 阶段 new_content: 新内容 """ logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") try: db = SessionLocal() try: llm = llm_service.get_llm() # 查找 active 章节(被清除的章节不继续更新,而是创建新的) stmt = select(Chapter).where( Chapter.user_id == user_id, Chapter.category == stage, Chapter.is_active == True, ) result = db.execute(stmt) chapter = result.scalar_one_or_none() existing_content = chapter.content if chapter else "" if llm: prompt = get_narrative_prompt( stage=stage, slots={}, new_content=new_content, existing_content=existing_content, ) response = llm.invoke(prompt) new_narrative = response.content.strip() # 追加而非替换 if existing_content: narrative = f"{existing_content}\n\n{new_narrative}" else: narrative = new_narrative else: narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content # 安全检查:新内容不应比旧内容短 if existing_content and len(narrative) < len(existing_content) * 0.8: logger.warning( f"内容长度异常: existing={len(existing_content)}, " f"new={len(narrative)}, stage={stage}. 回退为追加模式" ) narrative = f"{existing_content}\n\n{new_content}" if chapter: chapter.content = narrative chapter.is_new = True else: # 根据 stage 计算正确的排序索引 calculated_order_index = STAGE_TO_ORDER.get(stage, 999) chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, title=f"{stage} 回忆", content=narrative, order_index=calculated_order_index, status="completed", category=stage, images=[], is_new=True, source_segments=[], ) db.add(chapter) db.commit() return {"status": "success"} finally: db.close() except Exception as e: logger.error(f"章节生成失败: {e}") raise self.retry(exc=e) def build_cos_key(user_id: str, chapter_id: str, index: int, prompt: str) -> str: short_hash = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10] return f"memoirs/{user_id}/{chapter_id}/{index}-{short_hash}.png" @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_images(self, chapter_id: str): """Async task to generate images for a chapter's pending image assets.""" db = SessionLocal() try: chapter = db.get(Chapter, chapter_id) if not chapter or not chapter.images: logger.info(f"章节补图跳过: chapter={chapter_id}, reason=no_images") return {"status": "no_images"} settings = MemoirImageSettings.from_env() prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings) provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid) storage = TencentCosStorageService.from_env() images = [dict(item) for item in (chapter.images or [])] pending_count = sum(1 for item in images if item.get("status") in {"pending", "failed"}) logger.info( "章节补图开始: chapter=%s, total_images=%d, pending_images=%d", chapter_id, len(images), pending_count, ) for index, item in enumerate(images): if item.get("status") == "completed" and (item.get("storage_key") or item.get("url")): continue if item.get("status") not in {"pending", "failed"}: continue current_item = dict(item) current_item["status"] = "processing" current_item["updated_at"] = datetime.now(timezone.utc).isoformat() images[index] = current_item chapter.images = images db.commit() try: context_lines = (chapter.content or "").split("\n") context_excerpt = " ".join(context_lines[:5])[:200] prompt_data = prompt_service.build_prompt( chapter_title=chapter.title, chapter_category=chapter.category or "", description=item.get("description", ""), context_excerpt=context_excerpt, ) job = provider.submit_generation( prompt=prompt_data["prompt"], size=prompt_data["size"], style=prompt_data["style"], ) if job["status"] != "completed": job = provider.poll_until_complete( job, poll_interval_seconds=settings.poll_interval_seconds, max_attempts=settings.max_attempts, ) image_bytes = provider.download_image(job) key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"]) current_item["storage_key"] = key current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png") current_item["prompt"] = prompt_data["prompt"] current_item["style"] = prompt_data["style"] current_item["size"] = prompt_data["size"] current_item["status"] = "completed" current_item["error"] = None logger.info( "章节补图成功: chapter=%s, index=%s, url=%s", chapter_id, current_item.get("index"), current_item["url"], ) except Exception as exc: current_item["status"] = "failed" current_item["error"] = str(exc) logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}") current_item["updated_at"] = datetime.now(timezone.utc).isoformat() images[index] = current_item chapter.images = images db.commit() return {"status": "success"} finally: db.close()