""" 回忆录处理 Celery 任务 """ import json import logging import os import uuid from typing import Dict, List from datetime import datetime, timezone import redis from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import Session from database.database import SessionLocal from database.models import Book, Chapter, Segment, MemoirState from services.llm_service import llm_service from agents.state_schema import MemoirStateSchema, SlotData, default_state from agents.prompts.memory_prompts import ( get_creative_title_prompt, get_narrative_prompt, get_state_extraction_prompt, STAGE_TO_ORDER, ) logger = logging.getLogger(__name__) def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None): """同步更新任务状态(在 Celery 任务中使用)""" try: redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") r = redis.from_url(redis_url, decode_responses=True) key = f"task:user:{user_id}:tasks" # 获取现有任务信息 data = r.hget(key, task_id) if data: task_info = json.loads(data) else: task_info = {"task_id": task_id} task_info["status"] = status task_info["updated_at"] = datetime.now(timezone.utc).isoformat() if result is not None: task_info["result"] = result r.hset(key, task_id, json.dumps(task_info)) r.expire(key, 3600) # 1小时过期 logger.info(f"任务状态已更新: task_id={task_id}, status={status}") except Exception as e: logger.error(f"更新任务状态失败: {e}") STAGE_KEYWORDS = { "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], "education": ["上学", "学校", "老师", "同学", "教育", "大学"], "career": ["工作", "职业", "事业", "公司", "同事", "创业"], "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], } def _detect_stage(user_message: str, fallback_stage: str) -> str: """检测消息所属阶段""" message = user_message.lower() for stage, keywords in STAGE_KEYWORDS.items(): if any(word in message for word in keywords): return stage return fallback_stage def _coerce_state(model: MemoirState) -> MemoirStateSchema: """将数据库模型转换为 Schema""" return MemoirStateSchema.model_validate( { "stage_order": model.stage_order or default_state().stage_order, "current_stage": model.current_stage, "covered_stages": model.covered_stages or [], "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, } ) def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: """同步获取或创建状态""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if state: return _coerce_state(state) default = default_state() state = MemoirState( id=str(uuid.uuid4()), user_id=user_id, stage_order=default.stage_order, current_stage=default.current_stage, covered_stages=default.covered_stages, slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()}, ) db.add(state) db.commit() db.refresh(state) return _coerce_state(state) def _update_slot_sync( user_id: str, stage: str, slot_name: str, snippet: str, segment_ids: List[str], db: Session, ) -> MemoirStateSchema: """同步更新 slot""" stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = db.execute(stmt) state = result.scalar_one_or_none() if not state: _get_or_create_state_sync(user_id, db) result = db.execute(stmt) state = result.scalar_one() slots: Dict[str, Dict] = state.slots or {} stage_slots = slots.get(stage, {}) existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump() slots[stage] = stage_slots state.slots = slots state.current_stage = state.current_stage or stage db.commit() db.refresh(state) return _coerce_state(state) @shared_task(bind=True, max_retries=3, default_retry_delay=60) def process_memoir_segments(self, user_id: str, segment_ids: List[str]): """ 处理回忆录段落的 Celery 任务 Args: user_id: 用户 ID segment_ids: 段落 ID 列表 """ task_id = self.request.id logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}") # 更新任务状态为 running _update_task_status_sync(user_id, task_id, "running") try: db = SessionLocal() try: # 获取段落 stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) segments = result.scalars().all() if not segments: logger.warning(f"未找到段落: {segment_ids}") return {"status": "no_segments"} # 获取用户状态 state = _get_or_create_state_sync(user_id, db) llm = llm_service.get_llm() # 按阶段分组处理 stage_to_segments: Dict[str, List[Segment]] = {} for segment in segments: text = segment.transcript_text detected_stage = _detect_stage(text, state.current_stage) # 尝试使用 LLM 提取信息 extracted_slots = {} if llm: try: prompt = get_state_extraction_prompt( user_message=text, current_stage=state.current_stage, stage_slots=state.slots.get(detected_stage, {}), ) response = llm.invoke(prompt) content = response.content.strip() parsed = json.loads(content) detected_stage = parsed.get("detected_stage", detected_stage) extracted_slots = parsed.get("slots", {}) or {} except (json.JSONDecodeError, Exception) as e: logger.warning(f"LLM 解析失败: {e}") # 更新 slots for slot_name, snippet in extracted_slots.items(): state = _update_slot_sync( user_id=user_id, stage=detected_stage, slot_name=slot_name, snippet=snippet, segment_ids=[segment.id], db=db, ) stage_to_segments.setdefault(detected_stage, []).append(segment) # 生成章节内容 for stage, stage_segments in stage_to_segments.items(): segment_texts = [seg.transcript_text for seg in stage_segments] combined_text = "\n\n".join(segment_texts) source_ids = [seg.id for seg in stage_segments] # 查找或创建章节 stmt_chapter = select(Chapter).where( Chapter.user_id == user_id, Chapter.category == stage, ) result_chapter = db.execute(stmt_chapter) chapter = result_chapter.scalar_one_or_none() # 获取 slot snippets slot_snippets = { key: value.snippet for key, value in (state.slots.get(stage, {}) or {}).items() if value.snippet } # 生成标题和内容 title = chapter.title if chapter else f"{stage} 回忆" existing_content = chapter.content if chapter else "" narrative = combined_text if llm: try: if not chapter: title_prompt = get_creative_title_prompt( stage=stage, emotion="neutral", slots=slot_snippets ) title_response = llm.invoke(title_prompt) title = title_response.content.strip().strip('"') narrative_prompt = get_narrative_prompt( stage=stage, slots=slot_snippets, new_content=combined_text, existing_content=existing_content, ) narrative_response = llm.invoke(narrative_prompt) narrative = narrative_response.content.strip() except Exception as e: logger.warning(f"LLM 生成失败: {e}") if existing_content: narrative = f"{existing_content}\n\n{combined_text}" # 更新或创建章节 if chapter: chapter.content = narrative chapter.title = title chapter.is_new = True chapter.source_segments = list({*(chapter.source_segments or []), *source_ids}) else: # 根据 stage 计算正确的排序索引 calculated_order_index = STAGE_TO_ORDER.get(stage, 999) chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, title=title, content=narrative, order_index=calculated_order_index, status="completed", category=stage, images=[], is_new=True, source_segments=source_ids, ) db.add(chapter) db.flush() # 更新 Book stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) result_book = db.execute(stmt_book) book = result_book.scalar_one_or_none() if not book: book = Book( id=str(uuid.uuid4()), user_id=user_id, title="我的回忆录", total_pages=0, total_words=0, cover_image_url=None, ) db.add(book) book.has_update = True book.last_update_chapter_id = chapter.id # 标记段落为已处理 for seg in segments: seg.processed = True db.commit() logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") # 更新任务状态为成功 _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)}) return {"status": "success", "processed": len(segments)} finally: db.close() except Exception as e: logger.error(f"回忆录处理失败: {e}") # 更新任务状态为失败 _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) # 重试 raise self.retry(exc=e) @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) Args: user_id: 用户 ID stage: 阶段 new_content: 新内容 """ logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") try: db = SessionLocal() try: llm = llm_service.get_llm() # 查找章节 stmt = select(Chapter).where( Chapter.user_id == user_id, Chapter.category == stage, ) result = db.execute(stmt) chapter = result.scalar_one_or_none() existing_content = chapter.content if chapter else "" if llm: prompt = get_narrative_prompt( stage=stage, slots={}, new_content=new_content, existing_content=existing_content, ) response = llm.invoke(prompt) narrative = response.content.strip() else: narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content if chapter: chapter.content = narrative chapter.is_new = True else: # 根据 stage 计算正确的排序索引 calculated_order_index = STAGE_TO_ORDER.get(stage, 999) chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, title=f"{stage} 回忆", content=narrative, order_index=calculated_order_index, status="completed", category=stage, images=[], is_new=True, source_segments=[], ) db.add(chapter) db.commit() return {"status": "success"} finally: db.close() except Exception as e: logger.error(f"章节生成失败: {e}") raise self.retry(exc=e)