diff --git a/api/README.md b/api/README.md index 775fa6c..54f7ea3 100644 --- a/api/README.md +++ b/api/README.md @@ -105,18 +105,44 @@ from database import init_db init_db() ``` -## 运行服务 +## 快速启动 -### 开发模式 +### 本地开发 ```bash +cd api + +# 1. 启动 Redis +docker-compose -f docker-compose.dev.yml up -d + +# 2. 安装依赖 +pip install -r requirements.txt + +# 3. 启动 API(终端 1) uvicorn main:app --reload --host 0.0.0.0 --port 8000 + +# 4. 启动 Celery Worker(终端 2) +# macOS 使用 solo 池避免 fork 崩溃问题 +celery -A tasks.celery_app worker --loglevel=info --pool=solo + +# Linux/生产环境可以使用 prefork 池 +# celery -A tasks.celery_app worker --loglevel=info --concurrency=4 ``` -### 生产模式 +### 生产部署(一键) ```bash -uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4 +cd api + +# 创建生产配置 +cp .env .env.prod +# 编辑 .env.prod + +# 启动所有服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f ``` 服务启动后,访问: diff --git a/api/agents/conversation_agent.py b/api/agents/conversation_agent.py index 7b664b2..f890b5a 100644 --- a/api/agents/conversation_agent.py +++ b/api/agents/conversation_agent.py @@ -1,37 +1,54 @@ """ 对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应 +支持异步调用和 Redis 会话存储 """ -from typing import List, Optional +import logging +from typing import List, Optional, Dict, Any -from langchain.chains import ConversationChain -from langchain.memory import ConversationBufferMemory -from langchain.prompts import PromptTemplate +from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from services.llm_service import llm_service +from services.redis_service import redis_service from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt from .state_schema import MemoirStateSchema +logger = logging.getLogger(__name__) + class ConversationAgent: - """对话 Agent""" + """对话 Agent(支持异步和 Redis 存储)""" def __init__(self): # 使用 LLM 服务获取 LLM 实例 self.llm = llm_service.get_llm() - - # 对话记忆 - self.memories: dict[str, ConversationBufferMemory] = {} - def _get_memory(self, conversation_id: str) -> ConversationBufferMemory: - """获取或创建对话记忆""" - if conversation_id not in self.memories: - self.memories[conversation_id] = ConversationBufferMemory( - return_messages=True, - memory_key="history" - ) - return self.memories[conversation_id] + async def _get_history_messages(self, conversation_id: str) -> List[Any]: + """从 Redis 获取对话历史并转换为 LangChain 消息格式""" + history = await redis_service.get_conversation_history(conversation_id) + messages = [] + for msg in history: + if msg["role"] == "human": + messages.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "ai": + messages.append(AIMessage(content=msg["content"])) + return messages - def generate_response( + async def _save_message(self, conversation_id: str, role: str, content: str): + """保存消息到 Redis""" + await redis_service.add_message(conversation_id, role, content) + + def _format_history_string(self, messages: List[Any]) -> str: + """将消息列表格式化为字符串(用于 prompt)""" + history_parts = [] + for msg in messages: + if isinstance(msg, HumanMessage): + history_parts.append(f"Human: {msg.content}") + elif isinstance(msg, AIMessage): + history_parts.append(f"Assistant: {msg.content}") + return "\n\n".join(history_parts) + + async def generate_response( self, conversation_id: str, user_message: str, @@ -39,7 +56,7 @@ class ConversationAgent: covered_topics: Optional[List[str]] = None ) -> str: """ - 生成 Agent 回应 + 异步生成 Agent 回应 Args: conversation_id: 对话 ID @@ -60,38 +77,39 @@ class ConversationAgent: if not self.llm: return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。" - # 获取系统提示词 - system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message) - - # 获取对话记忆 - memory = self._get_memory(conversation_id) - - # 创建对话链 - prompt_template = PromptTemplate( - input_variables=["history", "input"], - template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:" - ) - - chain = ConversationChain( - llm=self.llm, - prompt=prompt_template, - memory=memory, - verbose=False - ) - - # 生成回应 - response = chain.predict(input=user_message) - - return response + try: + # 获取系统提示词 + system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message) + + # 从 Redis 获取对话历史 + history_messages = await self._get_history_messages(conversation_id) + history_string = self._format_history_string(history_messages) + + # 构建完整 prompt + full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + + # 异步调用 LLM + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, 'content') else str(response) + + # 保存对话到 Redis + await self._save_message(conversation_id, "human", user_message) + await self._save_message(conversation_id, "ai", response_text) + + return response_text + + except Exception as e: + logger.error(f"生成回应失败: {e}") + return f"抱歉,生成回应时出现错误: {str(e)}" - def generate_response_with_state( + async def generate_response_with_state( self, conversation_id: str, user_message: str, memoir_state: MemoirStateSchema ) -> List[str]: """ - 基于共享状态生成引导式回复 + 基于共享状态异步生成引导式回复 Args: conversation_id: 对话 ID @@ -104,39 +122,44 @@ class ConversationAgent: if not self.llm: return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"] - empty_slots = memoir_state.empty_slots_for_current_stage() - filled_slots = { - key: value.snippet - for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items() - if value.snippet - } + try: + empty_slots = memoir_state.empty_slots_for_current_stage() + filled_slots = { + key: value.snippet + for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items() + if value.snippet + } - system_prompt = get_guided_conversation_prompt( - current_stage=memoir_state.current_stage, - empty_slots=empty_slots, - filled_slots=filled_slots, - user_message=user_message, - ) + system_prompt = get_guided_conversation_prompt( + current_stage=memoir_state.current_stage, + empty_slots=empty_slots, + filled_slots=filled_slots, + user_message=user_message, + ) - memory = self._get_memory(conversation_id) - prompt_template = PromptTemplate( - input_variables=["history", "input"], - template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:" - ) - - chain = ConversationChain( - llm=self.llm, - prompt=prompt_template, - memory=memory, - verbose=False - ) - - response = chain.predict(input=user_message) - - # 支持多条消息,用 [SPLIT] 分隔 - messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()] - # 最多返回 3 条 - return messages[:3] if messages else [response] + # 从 Redis 获取对话历史 + history_messages = await self._get_history_messages(conversation_id) + history_string = self._format_history_string(history_messages) + + # 构建完整 prompt + full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + + # 异步调用 LLM + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, 'content') else str(response) + + # 保存对话到 Redis + await self._save_message(conversation_id, "human", user_message) + await self._save_message(conversation_id, "ai", response_text) + + # 支持多条消息,用 [SPLIT] 分隔 + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + # 最多返回 3 条 + return messages[:3] if messages else [response_text] + + except Exception as e: + logger.error(f"生成回应失败: {e}") + return [f"抱歉,生成回应时出现错误: {str(e)}"] def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage: """ @@ -168,8 +191,7 @@ class ConversationAgent: # 默认返回当前阶段或童年阶段 return ConversationStage.CHILDHOOD - def clear_memory(self, conversation_id: str): - """清除对话记忆""" - if conversation_id in self.memories: - del self.memories[conversation_id] + async def clear_memory(self, conversation_id: str): + """清除对话记忆(从 Redis)""" + await redis_service.clear_conversation_history(conversation_id) diff --git a/api/agents/memoir_processor.py b/api/agents/memoir_processor.py index ecee339..00900c6 100644 --- a/api/agents/memoir_processor.py +++ b/api/agents/memoir_processor.py @@ -6,34 +6,25 @@ - 更新回忆录状态(slots) - 生成/更新章节内容 - 创建创意章节标题 + +使用 Celery 进行后台任务处理,支持可靠的任务队列和重试机制 """ from __future__ import annotations -import asyncio import json -import uuid +import logging from dataclasses import dataclass -from typing import Dict, List, Optional - -from sqlalchemy import select +from typing import Dict, List from agents.state_schema import MemoirStateSchema -from database.database import AsyncSessionLocal -from database.models import Book, Chapter, Segment from services.llm_service import llm_service -from services.memoir_state_service import ( - get_or_create_state, - get_empty_slots, - mark_stage_complete, - switch_stage, - update_slot, -) from .prompts.memory_prompts import ( get_creative_title_prompt, get_narrative_prompt, get_state_extraction_prompt, ) +logger = logging.getLogger(__name__) STAGE_KEYWORDS = { "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], @@ -53,7 +44,7 @@ class AnalysisResult: class ContentAnalyzer: - """对话内容分析""" + """对话内容分析(支持异步)""" def __init__(self) -> None: self.llm = llm_service.get_llm() @@ -79,17 +70,15 @@ class ContentAnalyzer: is_new_chapter = False if self.llm: - prompt = get_state_extraction_prompt( - user_message=user_message, - current_stage=current_state.current_stage, - stage_slots=current_state.slots.get(detected_stage, {}), - ) - # 使用异步调用避免阻塞 - response = await asyncio.get_event_loop().run_in_executor( - None, lambda: self.llm.invoke(prompt) - ) - content = response.content.strip() try: + prompt = get_state_extraction_prompt( + user_message=user_message, + current_stage=current_state.current_stage, + stage_slots=current_state.slots.get(detected_stage, {}), + ) + # 使用异步调用 + response = await self.llm.ainvoke(prompt) + content = response.content.strip() parsed = json.loads(content) detected_stage = parsed.get("detected_stage", detected_stage) extracted_slots = parsed.get("slots", {}) or {} @@ -97,6 +86,9 @@ class ContentAnalyzer: is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter)) except json.JSONDecodeError: extracted_slots = self._fallback_slots(current_state, detected_stage, user_message) + except Exception as e: + logger.error(f"分析消息失败: {e}") + extracted_slots = self._fallback_slots(current_state, detected_stage, user_message) else: extracted_slots = self._fallback_slots(current_state, detected_stage, user_message) @@ -109,7 +101,7 @@ class ContentAnalyzer: class MemoirGenerator: - """回忆录生成与更新""" + """回忆录生成与更新(支持异步)""" def __init__(self) -> None: self.llm = llm_service.get_llm() @@ -117,142 +109,87 @@ class MemoirGenerator: async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str: if not self.llm: return f"{stage} 回忆" - prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots) - # 使用异步调用避免阻塞 - response = await asyncio.get_event_loop().run_in_executor( - None, lambda: self.llm.invoke(prompt) - ) - return response.content.strip().strip('"') + try: + prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots) + # 使用异步调用 + response = await self.llm.ainvoke(prompt) + return response.content.strip().strip('"') + except Exception as e: + logger.error(f"生成标题失败: {e}") + return f"{stage} 回忆" async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str: if not self.llm: if existing_content: return f"{existing_content}\n\n{new_content}" return new_content - prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content) - # 使用异步调用避免阻塞 - response = await asyncio.get_event_loop().run_in_executor( - None, lambda: self.llm.invoke(prompt) - ) - return response.content.strip() + try: + prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content) + # 使用异步调用 + response = await self.llm.ainvoke(prompt) + return response.content.strip() + except Exception as e: + logger.error(f"生成叙事失败: {e}") + if existing_content: + return f"{existing_content}\n\n{new_content}" + return new_content class BackgroundTaskRunner: - """后台任务调度(去抖)""" + """后台任务调度器(使用 Celery)""" def __init__(self, debounce_seconds: int = 5) -> None: self.debounce_seconds = debounce_seconds - self.pending_tasks: Dict[str, List[str]] = {} - self._scheduled: Dict[str, asyncio.Task] = {} + # 内存中的待处理任务(用于去抖) + self._pending: Dict[str, List[str]] = {} + self._timers: Dict[str, object] = {} self.analyzer = ContentAnalyzer() self.generator = MemoirGenerator() async def queue_message(self, user_id: str, segment_id: str) -> None: - self.pending_tasks.setdefault(user_id, []).append(segment_id) - if user_id in self._scheduled: - self._scheduled[user_id].cancel() - self._scheduled[user_id] = asyncio.create_task(self._debounced_process(user_id)) + """ + 将消息加入处理队列 + + 使用 Celery 延迟任务实现去抖效果 + """ + import asyncio + + # 收集待处理的 segment_ids + self._pending.setdefault(user_id, []).append(segment_id) + + # 取消之前的定时器 + if user_id in self._timers: + self._timers[user_id].cancel() + + # 创建新的定时器 + async def delayed_submit(): + try: + await asyncio.sleep(self.debounce_seconds) + segment_ids = self._pending.pop(user_id, []) + if segment_ids: + # 提交到 Celery + from tasks.memoir_tasks import process_memoir_segments + process_memoir_segments.delay(user_id, segment_ids) + logger.info(f"已提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}") + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"提交 Celery 任务失败: {e}") + + self._timers[user_id] = asyncio.create_task(delayed_submit()) - async def _debounced_process(self, user_id: str) -> None: - try: - await asyncio.sleep(self.debounce_seconds) - except asyncio.CancelledError: - return - async with AsyncSessionLocal() as db: - await self.process_pending(user_id, db) - - async def process_pending(self, user_id: str, db) -> None: - segment_ids = self.pending_tasks.pop(user_id, []) - if not segment_ids: - return - - stmt = select(Segment).where(Segment.id.in_(segment_ids)) - result = await db.execute(stmt) - segments = result.scalars().all() - if not segments: - return - - state = await get_or_create_state(user_id, db) - stage_to_segments: Dict[str, List[Segment]] = {} - for segment in segments: - analysis = await self.analyzer.analyze_message(segment.transcript_text, state) - detected_stage = analysis.detected_stage - if detected_stage != state.current_stage: - state = await switch_stage(user_id, detected_stage, db) - - for slot_name, snippet in analysis.extracted_slots.items(): - state = await update_slot( - user_id=user_id, - stage=detected_stage, - slot_name=slot_name, - snippet=snippet, - segment_ids=[segment.id], - db=db, - ) - - stage_to_segments.setdefault(detected_stage, []).append(segment) - - for stage, stage_segments in stage_to_segments.items(): - segment_texts = [seg.transcript_text for seg in stage_segments] - combined_text = "\n\n".join(segment_texts) - source_ids = [seg.id for seg in stage_segments] - - stmt_chapter = select(Chapter).where( - Chapter.user_id == user_id, - Chapter.category == stage, - ) - result_chapter = await db.execute(stmt_chapter) - chapter = result_chapter.scalar_one_or_none() - slot_snippets = { - key: value.snippet for key, value in (state.slots.get(stage, {}) or {}).items() if value.snippet - } - title = chapter.title if chapter else await self.generator.generate_chapter_title(stage, slot_snippets, "neutral") - existing_content = chapter.content if chapter else "" - narrative = await self.generator.generate_narrative(stage, slot_snippets, combined_text, existing_content) - - if chapter: - chapter.content = narrative - chapter.title = title - chapter.is_new = True - chapter.source_segments = list({*(chapter.source_segments or []), *source_ids}) - else: - chapter = Chapter( - id=str(uuid.uuid4()), - user_id=user_id, - title=title, - content=narrative, - order_index=999, - status="completed", - category=stage, - images=[], - is_new=True, - source_segments=source_ids, - ) - db.add(chapter) - - await db.flush() - - stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) - result_book = await db.execute(stmt_book) - book = result_book.scalar_one_or_none() - if not book: - book = Book( - id=str(uuid.uuid4()), - user_id=user_id, - title="我的回忆录", - total_pages=0, - total_words=0, - cover_image_url=None, - ) - db.add(book) - book.has_update = True - book.last_update_chapter_id = chapter.id - - empty_slots = await get_empty_slots(user_id, db) - if not empty_slots: - await mark_stage_complete(user_id, state.current_stage, db) - - for seg in segments: - seg.processed = True - - await db.commit() + async def flush_pending(self, user_id: str) -> None: + """ + 立即提交用户的待处理任务(用于对话结束时) + """ + # 取消定时器 + if user_id in self._timers: + self._timers[user_id].cancel() + del self._timers[user_id] + + # 提交待处理任务 + segment_ids = self._pending.pop(user_id, []) + if segment_ids: + from tasks.memoir_tasks import process_memoir_segments + process_memoir_segments.delay(user_id, segment_ids) + logger.info(f"立即提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}") diff --git a/api/agents/memory_agent.py b/api/agents/memory_agent.py index 018c918..47fb55b 100644 --- a/api/agents/memory_agent.py +++ b/api/agents/memory_agent.py @@ -1,7 +1,9 @@ """ 回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节 +支持异步调用 """ import json +import logging from typing import List, Dict, Optional from services.llm_service import llm_service @@ -14,17 +16,19 @@ from .prompts import ( CHAPTER_ORDER ) +logger = logging.getLogger(__name__) + class MemoryAgent: - """回忆录整理 Agent""" + """回忆录整理 Agent(支持异步)""" def __init__(self): # 使用 LLM 服务获取 LLM 实例 self.llm = llm_service.get_llm() - def classify_chapter(self, segments_text: str) -> str: + async def classify_chapter(self, segments_text: str) -> str: """ - 分类章节 + 异步分类章节 Args: segments_text: 对话段落文本 @@ -36,28 +40,34 @@ class MemoryAgent: # 如果没有配置 LLM,返回默认类别 return "childhood" - prompt = get_chapter_classification_prompt(segments_text) - - response = self.llm.invoke(prompt) - - # 提取类别 - category = response.content.strip().lower() - - # 验证类别是否有效 - if category in CHAPTER_CATEGORIES: - return category + try: + prompt = get_chapter_classification_prompt(segments_text) + + # 异步调用 LLM + response = await self.llm.ainvoke(prompt) + + # 提取类别 + content = response.content if hasattr(response, 'content') else str(response) + category = content.strip().lower() + + # 验证类别是否有效 + if category in CHAPTER_CATEGORIES: + return category + + except Exception as e: + logger.error(f"分类章节失败: {e}") # 默认返回 childhood return "childhood" - def rewrite_to_literary( + async def rewrite_to_literary( self, segments_text: str, chapter_category: str, existing_content: Optional[str] = None ) -> Dict: """ - 将口语改写为书面语 + 异步将口语改写为书面语 Args: segments_text: 对话段落文本 @@ -76,14 +86,16 @@ class MemoryAgent: "image_suggestions": [] } - prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "") - - response = self.llm.invoke(prompt) - - # 尝试解析 JSON try: - # 提取 JSON 部分 - content = response.content.strip() + prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "") + + # 异步调用 LLM + response = await self.llm.ainvoke(prompt) + + # 尝试解析 JSON + content = response.content if hasattr(response, 'content') else str(response) + content = content.strip() + # 移除可能的 markdown 代码块标记 if content.startswith("```json"): content = content[7:] @@ -95,22 +107,31 @@ class MemoryAgent: result = json.loads(content) return result + except json.JSONDecodeError: # 如果解析失败,返回基本结构 return { "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"), - "content": response.content, + "content": response.content if hasattr(response, 'content') else str(response), + "summary": "", + "image_suggestions": [] + } + except Exception as e: + logger.error(f"改写文本失败: {e}") + return { + "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"), + "content": segments_text, "summary": "", "image_suggestions": [] } - def process_segments( + async def process_segments( self, segments: List[Dict], existing_chapters: Optional[Dict[str, Dict]] = None ) -> Dict[str, Dict]: """ - 处理对话段落,生成或更新章节 + 异步处理对话段落,生成或更新章节 Args: segments: 对话段落列表,每个包含 transcript_text @@ -130,8 +151,8 @@ class MemoryAgent: if not text: continue - # 分类 - category = self.classify_chapter(text) + # 异步分类 + category = await self.classify_chapter(text) if category not in segments_by_category: segments_by_category[category] = [] @@ -145,8 +166,8 @@ class MemoryAgent: combined_text = "\n\n".join(texts) existing_content = existing_chapters.get(category, {}).get("content", "") - # 改写为书面语 - result = self.rewrite_to_literary(combined_text, category, existing_content) + # 异步改写为书面语 + result = await self.rewrite_to_literary(combined_text, category, existing_content) # 更新章节 updated_chapters[category] = { diff --git a/api/docker-compose.dev.yml b/api/docker-compose.dev.yml new file mode 100644 index 0000000..02ad24f --- /dev/null +++ b/api/docker-compose.dev.yml @@ -0,0 +1,29 @@ +version: '3.8' + +# 开发环境 Docker Compose +# 使用方法: docker-compose -f docker-compose.dev.yml up -d + +services: + # Redis 服务 + redis: + image: redis:7-alpine + container_name: life-echo-redis-dev + ports: + - "6379:6379" + volumes: + - redis_data_dev:/data + command: redis-server --appendonly yes + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + +networks: + default: + name: life-echo-dev + +volumes: + redis_data_dev: + driver: local diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 25eda1e..17cb3ce 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,6 +1,30 @@ version: '3.8' services: + # Redis 服务(用于会话存储和 Celery 消息队列) + redis: + image: redis:7-alpine + container_name: life-echo-redis + ports: + - "6379:6379" + volumes: + - redis_data:/data + command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru + restart: always + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - life-echo-network + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + + # FastAPI 应用 api: build: context: . @@ -9,19 +33,16 @@ services: container_name: life-echo-api-prod ports: - "8000:8000" - # 环境变量文件(优先级:env_file > 镜像中的 .env) - # 如果 .env.prod 不存在,容器会使用构建时打包的 .env 文件 env_file: - .env.prod + environment: + - REDIS_URL=redis://redis:6379/0 volumes: - # 持久化数据库文件(确保数据库文件在容器重启后保留) - # 数据库文件默认在 /app/life_echo.db,挂载到宿主机以持久化 - ./life_echo.db:/app/life_echo.db - # 如果需要将数据库存储在 data 目录,可以在 .env 中设置 DATABASE_URL=sqlite+aiosqlite:///./data/life_echo.db - # 然后取消下面的注释并注释掉上面的挂载 - # - ./data:/app/data restart: always - # 健康检查(使用 Python 内置库,无需额外依赖) + depends_on: + redis: + condition: service_healthy healthcheck: test: ["CMD", "python", "-c", "import http.client; conn = http.client.HTTPConnection('localhost', 8000); conn.request('GET', '/health'); r = conn.getresponse(); exit(0 if r.status == 200 else 1)"] interval: 30s @@ -30,16 +51,84 @@ services: start_period: 10s networks: - life-echo-network - # 日志配置(生产环境推荐) logging: driver: "json-file" options: max-size: "10m" max-file: "3" - # 注意:deploy 配置仅在 Docker Swarm 模式下有效 - # 如果使用 docker-compose up(非 swarm 模式),资源限制需要使用 ulimits 或其他方式 - # 如需资源限制,建议使用 docker run 的 --memory 和 --cpus 参数,或使用 docker-compose 的 ulimits + + # Celery Worker(后台任务处理) + celery-worker: + build: + context: . + dockerfile: Dockerfile + image: life-echo-api:latest + container_name: life-echo-celery-worker + command: celery -A tasks.celery_app worker --loglevel=info --concurrency=4 + env_file: + - .env.prod + environment: + - REDIS_URL=redis://redis:6379/0 + volumes: + - ./life_echo.db:/app/life_echo.db + restart: always + depends_on: + redis: + condition: service_healthy + api: + condition: service_healthy + networks: + - life-echo-network + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + + # Celery Beat(定时任务调度,可选) + # celery-beat: + # build: + # context: . + # dockerfile: Dockerfile + # image: life-echo-api:latest + # container_name: life-echo-celery-beat + # command: celery -A tasks.celery_app beat --loglevel=info + # env_file: + # - .env.prod + # environment: + # - REDIS_URL=redis://redis:6379/0 + # restart: always + # depends_on: + # redis: + # condition: service_healthy + # networks: + # - life-echo-network + + # Flower(Celery 监控面板,可选) + # flower: + # build: + # context: . + # dockerfile: Dockerfile + # image: life-echo-api:latest + # container_name: life-echo-flower + # command: celery -A tasks.celery_app flower --port=5555 + # ports: + # - "5555:5555" + # env_file: + # - .env.prod + # environment: + # - REDIS_URL=redis://redis:6379/0 + # restart: always + # depends_on: + # redis: + # condition: service_healthy + # networks: + # - life-echo-network networks: life-echo-network: driver: bridge + +volumes: + redis_data: + driver: local diff --git a/api/docs/本地开发环境配置.md b/api/docs/本地开发环境配置.md new file mode 100644 index 0000000..1c37879 --- /dev/null +++ b/api/docs/本地开发环境配置.md @@ -0,0 +1,250 @@ +# Life Echo 本地开发环境配置 + +本文档介绍如何在本地配置和运行 Life Echo 服务,支持异步 LLM 调用、Redis 会话存储和 Celery 后台任务。 + +## 架构概述 + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ FastAPI API │────▶│ Redis │◀────│ Celery Worker │ +│ (WebSocket) │ │ (会话 + 队列) │ │ (后台任务) │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ + │ ┌─────────────────┐ │ + └─────────────▶│ SQLite/DB │◀─────────────┘ + │ (持久化存储) │ + └─────────────────┘ +``` + +## 前置要求 + +- Python 3.10+ +- Docker 和 Docker Compose(用于 Redis) +- 有效的 LLM API Key(DeepSeek 或兼容 OpenAI 的服务) + +## 快速开始 + +### 1. 启动 Redis + +使用 Docker Compose 启动 Redis: + +```bash +cd api +docker-compose -f docker-compose.dev.yml up -d +``` + +验证 Redis 是否运行: + +```bash +docker exec life-echo-redis-dev redis-cli ping +# 应该返回 PONG +``` + +### 2. 配置环境变量 + +创建或编辑 `.env` 文件: + +```bash +cp .env.example .env # 如果有示例文件 +``` + +配置以下环境变量: + +```env +# LLM 配置(DeepSeek) +DEEPSEEK_API_KEY=your_api_key_here +DEEPSEEK_MODEL=deepseek-chat +DEEPSEEK_BASE_URL=https://api.deepseek.com + +# 或者使用通用 LLM 配置 +# LLM_API_KEY=your_api_key +# LLM_MODEL=gpt-4 +# LLM_BASE_URL=https://api.openai.com + +# Redis 配置 +REDIS_URL=redis://localhost:6379/0 +REDIS_SESSION_TTL=86400 # 会话过期时间(秒),默认 24 小时 + +# 数据库配置 +DATABASE_URL=sqlite+aiosqlite:///./life_echo.db + +# JWT 配置 +SECRET_KEY=your-secret-key-change-in-production +ALGORITHM=HS256 +ACCESS_TOKEN_EXPIRE_MINUTES=120 +``` + +### 3. 安装依赖 + +```bash +cd api +pip install -r requirements.txt +``` + +### 4. 启动 FastAPI 服务 + +```bash +cd api +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +### 5. 启动 Celery Worker + +在另一个终端窗口: + +```bash +cd api +celery -A tasks.celery_app worker --loglevel=info --concurrency=2 +``` + +## 服务说明 + +### FastAPI API (端口 8000) + +- 主 API 服务,处理 HTTP 和 WebSocket 请求 +- 对话的实时响应通过异步 LLM 调用生成 +- 会话历史存储在 Redis 中 + +### Redis (端口 6379) + +- 存储对话会话历史(支持多实例部署) +- 作为 Celery 的消息队列 +- 会话数据自动过期(默认 24 小时) + +### Celery Worker + +- 处理回忆录生成等后台任务 +- 支持任务重试和失败恢复 +- 可以水平扩展 + +## 生产环境部署 + +### 使用 Docker Compose 一键部署 + +```bash +cd api + +# 创建生产环境配置 +cp .env .env.prod +# 编辑 .env.prod 配置生产环境变量 + +# 启动所有服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f + +# 停止服务 +docker-compose down +``` + +### 服务扩展 + +扩展 Celery Worker 以处理更多并发任务: + +```bash +# 启动额外的 worker +docker-compose up -d --scale celery-worker=3 +``` + +### 监控(可选) + +启用 Flower 监控面板: + +1. 编辑 `docker-compose.yml`,取消 `flower` 服务的注释 +2. 重启服务:`docker-compose up -d` +3. 访问 http://localhost:5555 查看 Celery 任务监控 + +## 常见问题 + +### Redis 连接失败 + +``` +Redis 连接失败: Error connecting to redis://localhost:6379/0 +``` + +**解决方法**: +1. 确认 Redis 容器正在运行:`docker ps | grep redis` +2. 检查 `REDIS_URL` 环境变量是否正确 +3. 如果在 Docker 内运行 API,使用 `redis://redis:6379/0` + +### Celery 任务不执行 + +**解决方法**: +1. 确认 Celery Worker 正在运行 +2. 检查 Redis 连接是否正常 +3. 查看 Celery 日志:`celery -A tasks.celery_app worker --loglevel=debug` + +### LLM 调用超时 + +**解决方法**: +1. 检查网络连接 +2. 确认 API Key 有效 +3. 考虑增加超时时间或切换到更快的模型 + +## 性能优化建议 + +### 支持几百用户 + +1. **Redis 集群**:对于高并发场景,考虑使用 Redis 集群 +2. **数据库**:从 SQLite 迁移到 PostgreSQL +3. **Celery Worker**:根据负载增加 Worker 数量 +4. **API 实例**:使用负载均衡器部署多个 API 实例 + +### 配置建议 + +```env +# Redis 最大内存(防止 OOM) +# 在 docker-compose.yml 中配置:--maxmemory 512mb + +# Celery 并发数(根据 CPU 核心数调整) +# 在启动命令中配置:--concurrency=4 + +# 会话 TTL(根据业务需求调整) +REDIS_SESSION_TTL=86400 +``` + +## API 端点 + +- `GET /` - API 信息 +- `GET /health` - 健康检查 +- `WS /ws/conversation/{conversation_id}?token=xxx` - WebSocket 对话 +- `POST /api/auth/register` - 用户注册 +- `POST /api/auth/login` - 用户登录 +- `GET /api/conversations/{id}` - 获取对话详情 +- `GET /api/chapters` - 获取章节列表 +- `GET /api/books` - 获取书籍列表 + +## 开发提示 + +### 测试 WebSocket 连接 + +```python +import asyncio +import websockets +import json + +async def test(): + uri = "ws://localhost:8000/ws/conversation/test-123?token=YOUR_TOKEN" + async with websockets.connect(uri) as ws: + # 发送消息 + await ws.send(json.dumps({ + "type": "text", + "data": {"text": "你好,我想聊聊我的童年"} + })) + # 接收响应 + response = await ws.recv() + print(response) + +asyncio.run(test()) +``` + +### 手动触发 Celery 任务 + +```python +from tasks.memoir_tasks import process_memoir_segments + +# 同步调用(测试) +result = process_memoir_segments.delay("user_id", ["segment_id_1", "segment_id_2"]) +print(result.get(timeout=60)) +``` diff --git a/api/main.py b/api/main.py index 5f64929..1185f69 100644 --- a/api/main.py +++ b/api/main.py @@ -87,12 +87,28 @@ async def startup_event(): logger.info("=" * 50) logger.info("Life Echo API 正在启动...") logger.info("=" * 50) + + # 初始化 Redis 连接 + try: + from services.redis_service import redis_service + await redis_service.get_client() + logger.info("Redis 连接已建立") + except Exception as e: + logger.warning(f"Redis 连接失败(会话存储将不可用): {e}") @app.on_event("shutdown") async def shutdown_event(): """应用关闭事件""" logger.info("Life Echo API 正在关闭...") + + # 关闭 Redis 连接 + try: + from services.redis_service import redis_service + await redis_service.close() + logger.info("Redis 连接已关闭") + except Exception as e: + logger.warning(f"关闭 Redis 连接失败: {e}") # CORS 配置 diff --git a/api/requirements.txt b/api/requirements.txt index f7d5d34..7aa63ce 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -13,6 +13,13 @@ sqlalchemy==2.0.36 aiosqlite==0.20.0 greenlet>=3.3.0 +# Redis for session storage +redis>=5.0.0 +aioredis>=2.0.0 + +# Celery for background tasks +celery[redis]>=5.3.0 + # PDF Generation reportlab==4.2.2 weasyprint==62.3 @@ -25,7 +32,7 @@ httpx==0.27.0 # Authentication python-jose[cryptography]==3.3.0 -passlib[bcrypt]==1.7.4 +bcrypt>=4.0.0 # Audio Processing (optional, for future ASR/TTS integration) # pydub==0.25.1 diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 8c7c2be..212a3f3 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -1,6 +1,8 @@ """ WebSocket 路由:实时对话通信 +支持异步 Agent 调用和 Redis 会话存储 """ +import logging import uuid from datetime import datetime, timezone from enum import Enum @@ -19,6 +21,8 @@ from services.auth_service import verify_token from services.memoir_state_service import get_or_create_state from fastapi import HTTPException, status +logger = logging.getLogger(__name__) + class MessageType(str, Enum): """WebSocket 消息类型""" @@ -39,7 +43,8 @@ class ConnectionManager: def __init__(self): self.active_connections: Dict[str, WebSocket] = {} - self.conversation_agents: Dict[str, ConversationAgent] = {} + # ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用 + self.conversation_agent = ConversationAgent() self.memory_agent = MemoryAgent() self.background_runner = BackgroundTaskRunner() @@ -47,15 +52,13 @@ class ConnectionManager: """建立连接""" await websocket.accept() self.active_connections[conversation_id] = websocket - self.conversation_agents[conversation_id] = ConversationAgent() - def disconnect(self, conversation_id: str): + async def disconnect(self, conversation_id: str): """断开连接""" if conversation_id in self.active_connections: del self.active_connections[conversation_id] - if conversation_id in self.conversation_agents: - self.conversation_agents[conversation_id].clear_memory(conversation_id) - del self.conversation_agents[conversation_id] + # 清除 Redis 中的会话记忆(可选,也可以保留用于恢复) + # await self.conversation_agent.clear_memory(conversation_id) async def send_message(self, conversation_id: str, message: dict): """发送消息""" @@ -198,10 +201,10 @@ async def websocket_endpoint( }) except WebSocketDisconnect: - manager.disconnect(conversation_id) + await manager.disconnect(conversation_id) break except Exception as e: - manager.disconnect(conversation_id) + await manager.disconnect(conversation_id) raise @@ -214,7 +217,7 @@ async def process_user_message( manager: ConnectionManager ) -> None: """ - 处理用户消息,生成Agent回应 + 处理用户消息,生成Agent回应(异步版本) Args: conversation_id: 对话ID @@ -227,24 +230,26 @@ async def process_user_message( Returns: 更新后的对话阶段 """ - agent = manager.conversation_agents.get(conversation_id) - if agent: - state = await get_or_create_state(conversation.user_id, db) + import asyncio as _asyncio + + agent = manager.conversation_agent + state = await get_or_create_state(conversation.user_id, db) - if conversation.conversation_stage != state.current_stage: - conversation.conversation_stage = state.current_stage - await db.commit() + if conversation.conversation_stage != state.current_stage: + conversation.conversation_stage = state.current_stage + await db.commit() - # 获取已聊话题(保留老逻辑用于提示) - stmt_segments = select(Segment).where( - Segment.conversation_id == conversation_id - ).order_by(Segment.created_at) - result_segments = await db.execute(stmt_segments) - previous_segments = result_segments.scalars().all() - covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] - - # 生成回应(可能是多条消息) - responses = agent.generate_response_with_state( + # 获取已聊话题(保留老逻辑用于提示) + stmt_segments = select(Segment).where( + Segment.conversation_id == conversation_id + ).order_by(Segment.created_at) + result_segments = await db.execute(stmt_segments) + previous_segments = result_segments.scalars().all() + covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] + + try: + # 异步生成回应(可能是多条消息) + responses = await agent.generate_response_with_state( conversation_id=conversation_id, user_message=user_message, memoir_state=state @@ -255,7 +260,6 @@ async def process_user_message( await db.commit() # 发送 Agent 回应(支持多条消息) - import asyncio as _asyncio for i, response_text in enumerate(responses): await manager.send_message(conversation_id, { "type": MessageType.AGENT_RESPONSE, @@ -266,6 +270,14 @@ async def process_user_message( # 多条消息之间稍作间隔,模拟打字效果 if i < len(responses) - 1: await _asyncio.sleep(0.5) + + except Exception as e: + logger.error(f"处理用户消息失败: {e}") + await manager.send_message(conversation_id, { + "type": MessageType.ERROR, + "data": {"message": f"生成回应失败: {str(e)}"}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) return @@ -274,8 +286,8 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession): """ 处理对话段落,生成章节(对话结束时调用) - 注意:大部分处理已通过 BackgroundTaskRunner 增量完成 - 这里只处理可能遗漏的最后几条消息 + 注意:大部分处理已通过 Celery 任务增量完成 + 这里立即提交所有待处理的段落到 Celery Args: conversation_id: 对话 ID @@ -295,9 +307,19 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession): segments = result.scalars().all() if not segments: + # 没有未处理的段落,直接 flush 待处理任务 + await manager.background_runner.flush_pending(conversation.user_id) return - # 将未处理的段落加入后台任务队列(不等待完成,避免阻塞) - for seg in segments: - await manager.background_runner.queue_message(conversation.user_id, seg.id) + # 将未处理的段落直接提交到 Celery(不通过去抖) + segment_ids = [seg.id for seg in segments] + try: + from tasks.memoir_tasks import process_memoir_segments + process_memoir_segments.delay(conversation.user_id, segment_ids) + logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}") + except Exception as e: + logger.error(f"提交 Celery 任务失败: {e}") + + # 同时 flush 任何待处理的任务 + await manager.background_runner.flush_pending(conversation.user_id) diff --git a/api/services/__init__.py b/api/services/__init__.py index 304ce60..ece21e3 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -4,10 +4,12 @@ from .asr_service import asr_service from .tts_service import tts_service from .llm_service import llm_service +from .redis_service import redis_service __all__ = [ "asr_service", "tts_service", "llm_service", + "redis_service", ] diff --git a/api/services/auth_service.py b/api/services/auth_service.py index 5631b9b..0887bce 100644 --- a/api/services/auth_service.py +++ b/api/services/auth_service.py @@ -6,11 +6,8 @@ import secrets from datetime import datetime, timedelta, timezone from typing import Optional, Dict +import bcrypt from jose import JWTError, jwt -from passlib.context import CryptContext - -# 密码加密上下文 -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # JWT配置 SECRET_KEY = os.getenv("SECRET_KEY", secrets.token_urlsafe(32)) @@ -29,7 +26,11 @@ def hash_password(password: str) -> str: Returns: 哈希后的密码 """ - return pwd_context.hash(password) + # bcrypt 要求 bytes 输入 + password_bytes = password.encode('utf-8') + salt = bcrypt.gensalt() + hashed = bcrypt.hashpw(password_bytes, salt) + return hashed.decode('utf-8') def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -43,7 +44,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: Returns: 是否匹配 """ - return pwd_context.verify(plain_password, hashed_password) + try: + password_bytes = plain_password.encode('utf-8') + hashed_bytes = hashed_password.encode('utf-8') + return bcrypt.checkpw(password_bytes, hashed_bytes) + except Exception: + return False def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: diff --git a/api/services/redis_service.py b/api/services/redis_service.py new file mode 100644 index 0000000..ba8d2b7 --- /dev/null +++ b/api/services/redis_service.py @@ -0,0 +1,193 @@ +""" +Redis 服务模块:用于会话状态存储和缓存 +""" +import os +import json +import logging +from typing import Optional, List, Dict, Any + +import redis.asyncio as aioredis + +logger = logging.getLogger(__name__) + + +class RedisService: + """Redis 服务,用于存储对话历史和状态""" + + def __init__(self): + """初始化 Redis 连接""" + self.redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + self._client: Optional[aioredis.Redis] = None + # 会话过期时间(默认 24 小时) + self.session_ttl = int(os.getenv("REDIS_SESSION_TTL", "86400")) + + async def get_client(self) -> aioredis.Redis: + """获取 Redis 客户端(延迟初始化)""" + if self._client is None: + try: + self._client = await aioredis.from_url( + self.redis_url, + encoding="utf-8", + decode_responses=True + ) + # 测试连接 + await self._client.ping() + logger.info(f"Redis 连接成功: {self.redis_url}") + except Exception as e: + logger.error(f"Redis 连接失败: {e}") + raise + return self._client + + async def close(self): + """关闭 Redis 连接""" + if self._client: + await self._client.close() + self._client = None + + # ==================== 对话历史管理 ==================== + + def _conversation_key(self, conversation_id: str) -> str: + """生成对话历史的 Redis key""" + return f"conversation:history:{conversation_id}" + + async def get_conversation_history(self, conversation_id: str) -> List[Dict[str, Any]]: + """ + 获取对话历史 + + Args: + conversation_id: 对话 ID + + Returns: + 消息列表 [{"role": "human/ai", "content": "..."}] + """ + try: + client = await self.get_client() + key = self._conversation_key(conversation_id) + data = await client.get(key) + if data: + return json.loads(data) + return [] + except Exception as e: + logger.error(f"获取对话历史失败: {e}") + return [] + + async def add_message( + self, + conversation_id: str, + role: str, + content: str + ) -> bool: + """ + 添加消息到对话历史 + + Args: + conversation_id: 对话 ID + role: 角色 ("human" 或 "ai") + content: 消息内容 + + Returns: + 是否成功 + """ + try: + client = await self.get_client() + key = self._conversation_key(conversation_id) + + # 获取现有历史 + history = await self.get_conversation_history(conversation_id) + + # 添加新消息 + history.append({"role": role, "content": content}) + + # 保存回 Redis(带过期时间) + await client.setex(key, self.session_ttl, json.dumps(history, ensure_ascii=False)) + return True + except Exception as e: + logger.error(f"添加消息失败: {e}") + return False + + async def clear_conversation_history(self, conversation_id: str) -> bool: + """ + 清除对话历史 + + Args: + conversation_id: 对话 ID + + Returns: + 是否成功 + """ + try: + client = await self.get_client() + key = self._conversation_key(conversation_id) + await client.delete(key) + return True + except Exception as e: + logger.error(f"清除对话历史失败: {e}") + return False + + async def extend_session_ttl(self, conversation_id: str) -> bool: + """ + 延长会话过期时间 + + Args: + conversation_id: 对话 ID + + Returns: + 是否成功 + """ + try: + client = await self.get_client() + key = self._conversation_key(conversation_id) + await client.expire(key, self.session_ttl) + return True + except Exception as e: + logger.error(f"延长会话TTL失败: {e}") + return False + + # ==================== 通用缓存方法 ==================== + + async def set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: + """设置缓存""" + try: + client = await self.get_client() + data = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value + if ttl: + await client.setex(key, ttl, data) + else: + await client.set(key, data) + return True + except Exception as e: + logger.error(f"设置缓存失败: {e}") + return False + + async def get_cache(self, key: str) -> Optional[Any]: + """获取缓存""" + try: + client = await self.get_client() + data = await client.get(key) + if data: + try: + return json.loads(data) + except json.JSONDecodeError: + return data + return None + except Exception as e: + logger.error(f"获取缓存失败: {e}") + return None + + async def delete_cache(self, key: str) -> bool: + """删除缓存""" + try: + client = await self.get_client() + await client.delete(key) + return True + except Exception as e: + logger.error(f"删除缓存失败: {e}") + return False + + def is_available(self) -> bool: + """检查 Redis 是否可用""" + return self._client is not None + + +# 创建全局实例 +redis_service = RedisService() diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 0000000..c3a53ec --- /dev/null +++ b/api/tasks/__init__.py @@ -0,0 +1,7 @@ +""" +Celery 任务模块 +""" +from .celery_app import celery_app +from .memoir_tasks import process_memoir_segments + +__all__ = ["celery_app", "process_memoir_segments"] diff --git a/api/tasks/celery_app.py b/api/tasks/celery_app.py new file mode 100644 index 0000000..4a5a24d --- /dev/null +++ b/api/tasks/celery_app.py @@ -0,0 +1,58 @@ +""" +Celery 应用配置 +""" +import os +from celery import Celery +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv() + +# Redis URL +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + +# 创建 Celery 应用 +celery_app = Celery( + "life_echo", + broker=REDIS_URL, + backend=REDIS_URL, + include=["tasks.memoir_tasks"] +) + +# Celery 配置 +celery_app.conf.update( + # 任务序列化 + task_serializer="json", + accept_content=["json"], + result_serializer="json", + + # 时区 + timezone="UTC", + enable_utc=True, + + # 任务结果过期时间(1小时) + result_expires=3600, + + # 任务执行设置 + task_soft_time_limit=300, # 5分钟软超时 + task_time_limit=600, # 10分钟硬超时 + + # 并发设置 + worker_prefetch_multiplier=1, # 每次只预取一个任务 + worker_concurrency=4, # 并发 worker 数量 + + # 任务重试设置 + task_acks_late=True, # 任务完成后再确认 + task_reject_on_worker_lost=True, # worker 丢失时拒绝任务 + + # 不设置自定义队列路由,使用 Celery 默认队列 +) + +# 定时任务配置(如果需要) +celery_app.conf.beat_schedule = { + # 示例:每小时清理过期会话 + # "cleanup-expired-sessions": { + # "task": "tasks.cleanup.cleanup_sessions", + # "schedule": 3600.0, + # }, +} diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py new file mode 100644 index 0000000..51f799c --- /dev/null +++ b/api/tasks/memoir_tasks.py @@ -0,0 +1,345 @@ +""" +回忆录处理 Celery 任务 +""" +import json +import logging +import uuid +from typing import Dict, List + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import Session + +from database.database import SessionLocal +from database.models import Book, Chapter, Segment, MemoirState +from services.llm_service import llm_service +from agents.state_schema import MemoirStateSchema, SlotData, default_state +from agents.prompts.memory_prompts import ( + get_creative_title_prompt, + get_narrative_prompt, + get_state_extraction_prompt, +) + +logger = logging.getLogger(__name__) + +STAGE_KEYWORDS = { + "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], + "education": ["上学", "学校", "老师", "同学", "教育", "大学"], + "career": ["工作", "职业", "事业", "公司", "同事", "创业"], + "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], + "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], +} + + +def _detect_stage(user_message: str, fallback_stage: str) -> str: + """检测消息所属阶段""" + message = user_message.lower() + for stage, keywords in STAGE_KEYWORDS.items(): + if any(word in message for word in keywords): + return stage + return fallback_stage + + +def _coerce_state(model: MemoirState) -> MemoirStateSchema: + """将数据库模型转换为 Schema""" + return MemoirStateSchema.model_validate( + { + "stage_order": model.stage_order or default_state().stage_order, + "current_stage": model.current_stage, + "covered_stages": model.covered_stages or [], + "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, + } + ) + + +def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: + """同步获取或创建状态""" + stmt = select(MemoirState).where(MemoirState.user_id == user_id) + result = db.execute(stmt) + state = result.scalar_one_or_none() + if state: + return _coerce_state(state) + + default = default_state() + state = MemoirState( + id=str(uuid.uuid4()), + user_id=user_id, + stage_order=default.stage_order, + current_stage=default.current_stage, + covered_stages=default.covered_stages, + slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()}, + ) + db.add(state) + db.commit() + db.refresh(state) + return _coerce_state(state) + + +def _update_slot_sync( + user_id: str, + stage: str, + slot_name: str, + snippet: str, + segment_ids: List[str], + db: Session, +) -> MemoirStateSchema: + """同步更新 slot""" + stmt = select(MemoirState).where(MemoirState.user_id == user_id) + result = db.execute(stmt) + state = result.scalar_one_or_none() + if not state: + _get_or_create_state_sync(user_id, db) + result = db.execute(stmt) + state = result.scalar_one() + + slots: Dict[str, Dict] = state.slots or {} + stage_slots = slots.get(stage, {}) + existing = stage_slots.get(slot_name, {}) + + merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) + stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump() + slots[stage] = stage_slots + state.slots = slots + state.current_stage = state.current_stage or stage + db.commit() + db.refresh(state) + return _coerce_state(state) + + +@shared_task(bind=True, max_retries=3, default_retry_delay=60) +def process_memoir_segments(self, user_id: str, segment_ids: List[str]): + """ + 处理回忆录段落的 Celery 任务 + + Args: + user_id: 用户 ID + segment_ids: 段落 ID 列表 + """ + logger.info(f"开始处理回忆录段落: user_id={user_id}, segments={len(segment_ids)}") + + try: + db = SessionLocal() + try: + # 获取段落 + stmt = select(Segment).where(Segment.id.in_(segment_ids)) + result = db.execute(stmt) + segments = result.scalars().all() + + if not segments: + logger.warning(f"未找到段落: {segment_ids}") + return {"status": "no_segments"} + + # 获取用户状态 + state = _get_or_create_state_sync(user_id, db) + llm = llm_service.get_llm() + + # 按阶段分组处理 + stage_to_segments: Dict[str, List[Segment]] = {} + + for segment in segments: + text = segment.transcript_text + detected_stage = _detect_stage(text, state.current_stage) + + # 尝试使用 LLM 提取信息 + extracted_slots = {} + if llm: + try: + prompt = get_state_extraction_prompt( + user_message=text, + current_stage=state.current_stage, + stage_slots=state.slots.get(detected_stage, {}), + ) + response = llm.invoke(prompt) + content = response.content.strip() + parsed = json.loads(content) + detected_stage = parsed.get("detected_stage", detected_stage) + extracted_slots = parsed.get("slots", {}) or {} + except (json.JSONDecodeError, Exception) as e: + logger.warning(f"LLM 解析失败: {e}") + + # 更新 slots + for slot_name, snippet in extracted_slots.items(): + state = _update_slot_sync( + user_id=user_id, + stage=detected_stage, + slot_name=slot_name, + snippet=snippet, + segment_ids=[segment.id], + db=db, + ) + + stage_to_segments.setdefault(detected_stage, []).append(segment) + + # 生成章节内容 + for stage, stage_segments in stage_to_segments.items(): + segment_texts = [seg.transcript_text for seg in stage_segments] + combined_text = "\n\n".join(segment_texts) + source_ids = [seg.id for seg in stage_segments] + + # 查找或创建章节 + stmt_chapter = select(Chapter).where( + Chapter.user_id == user_id, + Chapter.category == stage, + ) + result_chapter = db.execute(stmt_chapter) + chapter = result_chapter.scalar_one_or_none() + + # 获取 slot snippets + slot_snippets = { + key: value.snippet + for key, value in (state.slots.get(stage, {}) or {}).items() + if value.snippet + } + + # 生成标题和内容 + title = chapter.title if chapter else f"{stage} 回忆" + existing_content = chapter.content if chapter else "" + narrative = combined_text + + if llm: + try: + if not chapter: + title_prompt = get_creative_title_prompt( + stage=stage, + emotion="neutral", + slots=slot_snippets + ) + title_response = llm.invoke(title_prompt) + title = title_response.content.strip().strip('"') + + narrative_prompt = get_narrative_prompt( + stage=stage, + slots=slot_snippets, + new_content=combined_text, + existing_content=existing_content, + ) + narrative_response = llm.invoke(narrative_prompt) + narrative = narrative_response.content.strip() + except Exception as e: + logger.warning(f"LLM 生成失败: {e}") + if existing_content: + narrative = f"{existing_content}\n\n{combined_text}" + + # 更新或创建章节 + if chapter: + chapter.content = narrative + chapter.title = title + chapter.is_new = True + chapter.source_segments = list({*(chapter.source_segments or []), *source_ids}) + else: + chapter = Chapter( + id=str(uuid.uuid4()), + user_id=user_id, + title=title, + content=narrative, + order_index=999, + status="completed", + category=stage, + images=[], + is_new=True, + source_segments=source_ids, + ) + db.add(chapter) + + db.flush() + + # 更新 Book + stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) + result_book = db.execute(stmt_book) + book = result_book.scalar_one_or_none() + if not book: + book = Book( + id=str(uuid.uuid4()), + user_id=user_id, + title="我的回忆录", + total_pages=0, + total_words=0, + cover_image_url=None, + ) + db.add(book) + book.has_update = True + book.last_update_chapter_id = chapter.id + + # 标记段落为已处理 + for seg in segments: + seg.processed = True + + db.commit() + logger.info(f"回忆录处理完成: user_id={user_id}") + return {"status": "success", "processed": len(segments)} + + finally: + db.close() + + except Exception as e: + logger.error(f"回忆录处理失败: {e}") + # 重试 + raise self.retry(exc=e) + + +@shared_task(bind=True, max_retries=3, default_retry_delay=30) +def generate_chapter_content(self, user_id: str, stage: str, new_content: str): + """ + 单独生成章节内容的任务(用于实时更新) + + Args: + user_id: 用户 ID + stage: 阶段 + new_content: 新内容 + """ + logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") + + try: + db = SessionLocal() + try: + llm = llm_service.get_llm() + + # 查找章节 + stmt = select(Chapter).where( + Chapter.user_id == user_id, + Chapter.category == stage, + ) + result = db.execute(stmt) + chapter = result.scalar_one_or_none() + + existing_content = chapter.content if chapter else "" + + if llm: + prompt = get_narrative_prompt( + stage=stage, + slots={}, + new_content=new_content, + existing_content=existing_content, + ) + response = llm.invoke(prompt) + narrative = response.content.strip() + else: + narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content + + if chapter: + chapter.content = narrative + chapter.is_new = True + else: + chapter = Chapter( + id=str(uuid.uuid4()), + user_id=user_id, + title=f"{stage} 回忆", + content=narrative, + order_index=999, + status="completed", + category=stage, + images=[], + is_new=True, + source_segments=[], + ) + db.add(chapter) + + db.commit() + return {"status": "success"} + + finally: + db.close() + + except Exception as e: + logger.error(f"章节生成失败: {e}") + raise self.retry(exc=e)