diff --git a/api/agents/memoir_processor.py b/api/agents/memoir_processor.py index 00900c6..df8695a 100644 --- a/api/agents/memoir_processor.py +++ b/api/agents/memoir_processor.py @@ -146,6 +146,30 @@ class BackgroundTaskRunner: self.analyzer = ContentAnalyzer() self.generator = MemoirGenerator() + async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None: + """ + 提交 Celery 任务并记录 + + Returns: + 任务 ID,失败返回 None + """ + try: + from tasks.memoir_tasks import process_memoir_segments + from services.task_tracker import task_tracker + + # 提交到 Celery + result = process_memoir_segments.delay(user_id, segment_ids) + task_id = result.id + + # 记录任务 + await task_tracker.add_task(user_id, task_id, "memoir") + + logger.info(f"已提交 Celery 任务: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}") + return task_id + except Exception as e: + logger.error(f"提交 Celery 任务失败: {e}") + return None + async def queue_message(self, user_id: str, segment_id: str) -> None: """ 将消息加入处理队列 @@ -167,20 +191,20 @@ class BackgroundTaskRunner: await asyncio.sleep(self.debounce_seconds) segment_ids = self._pending.pop(user_id, []) if segment_ids: - # 提交到 Celery - from tasks.memoir_tasks import process_memoir_segments - process_memoir_segments.delay(user_id, segment_ids) - logger.info(f"已提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}") + await self._submit_task(user_id, segment_ids) except asyncio.CancelledError: pass except Exception as e: - logger.error(f"提交 Celery 任务失败: {e}") + logger.error(f"延迟提交任务失败: {e}") self._timers[user_id] = asyncio.create_task(delayed_submit()) - async def flush_pending(self, user_id: str) -> None: + async def flush_pending(self, user_id: str) -> str | None: """ 立即提交用户的待处理任务(用于对话结束时) + + Returns: + 任务 ID,无任务或失败返回 None """ # 取消定时器 if user_id in self._timers: @@ -190,6 +214,5 @@ class BackgroundTaskRunner: # 提交待处理任务 segment_ids = self._pending.pop(user_id, []) if segment_ids: - from tasks.memoir_tasks import process_memoir_segments - process_memoir_segments.delay(user_id, segment_ids) - logger.info(f"立即提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}") + return await self._submit_task(user_id, segment_ids) + return None diff --git a/api/main.py b/api/main.py index a08e10b..d9203b3 100644 --- a/api/main.py +++ b/api/main.py @@ -52,7 +52,7 @@ else: load_dotenv() from database import init_db -from routers import websocket, chapters, books, conversations, auth, memoir_state +from routers import websocket, chapters, books, conversations, auth, memoir_state, tasks # 初始化数据库 logger.info("正在初始化数据库...") @@ -134,6 +134,7 @@ app.include_router(conversations.router) app.include_router(chapters.router) app.include_router(books.router) app.include_router(memoir_state.router) +app.include_router(tasks.router) # 任务状态路由 @app.get("/") diff --git a/api/routers/tasks.py b/api/routers/tasks.py new file mode 100644 index 0000000..12532e1 --- /dev/null +++ b/api/routers/tasks.py @@ -0,0 +1,78 @@ +""" +任务状态 API 路由 +""" +from typing import List, Dict +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from database.models import User as UserModel +from middleware.auth import get_current_user +from services.task_tracker import task_tracker + +router = APIRouter(prefix="/api/tasks", tags=["tasks"]) + + +class TaskInfo(BaseModel): + """任务信息""" + task_id: str + task_type: str = "memoir" + status: str # pending, running, success, failure + created_at: str = None + updated_at: str = None + result: Dict = None + + +class TasksStatusResponse(BaseModel): + """任务状态汇总响应""" + total: int + pending: int + running: int + success: int + failure: int + all_completed: bool + tasks: List[TaskInfo] + + +@router.get("/status", response_model=TasksStatusResponse) +async def get_tasks_status( + current_user: UserModel = Depends(get_current_user) +): + """ + 获取当前用户的任务状态汇总 + + 用于检查后台任务是否全部完成 + """ + status = await task_tracker.check_tasks_status(current_user.id) + tasks = await task_tracker.get_user_tasks(current_user.id) + + return TasksStatusResponse( + total=status["total"], + pending=status["pending"], + running=status["running"], + success=status["success"], + failure=status["failure"], + all_completed=status["all_completed"], + tasks=[TaskInfo(**t) for t in tasks] + ) + + +@router.get("/pending") +async def get_pending_tasks( + current_user: UserModel = Depends(get_current_user) +): + """ + 获取当前用户的待处理任务列表 + """ + tasks = await task_tracker.get_pending_tasks(current_user.id) + return {"pending_tasks": tasks} + + +@router.delete("/clear") +async def clear_tasks( + current_user: UserModel = Depends(get_current_user) +): + """ + 清除当前用户的所有任务记录 + """ + await task_tracker.clear_user_tasks(current_user.id) + return {"message": "任务记录已清除"} diff --git a/api/services/task_tracker.py b/api/services/task_tracker.py new file mode 100644 index 0000000..fd8dfff --- /dev/null +++ b/api/services/task_tracker.py @@ -0,0 +1,202 @@ +""" +任务追踪服务:追踪 Celery 任务状态 +""" +import json +import logging +from typing import List, Dict, Optional, Any +from datetime import datetime, timezone + +from services.redis_service import redis_service + +logger = logging.getLogger(__name__) + + +class TaskTracker: + """任务追踪器,使用 Redis 存储任务状态""" + + # Redis key 前缀 + KEY_PREFIX = "task:user:" + # 任务记录过期时间(1小时) + TASK_TTL = 3600 + + async def add_task(self, user_id: str, task_id: str, task_type: str = "memoir") -> bool: + """ + 记录新任务 + + Args: + user_id: 用户 ID + task_id: Celery 任务 ID + task_type: 任务类型 + + Returns: + 是否成功 + """ + try: + client = await redis_service.get_client() + key = f"{self.KEY_PREFIX}{user_id}:tasks" + + task_info = { + "task_id": task_id, + "task_type": task_type, + "status": "pending", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + # 使用 hash 存储任务信息 + await client.hset(key, task_id, json.dumps(task_info)) + await client.expire(key, self.TASK_TTL) + + logger.info(f"任务已记录: user_id={user_id}, task_id={task_id}") + return True + except Exception as e: + logger.error(f"记录任务失败: {e}") + return False + + async def update_task_status(self, user_id: str, task_id: str, status: str, result: Any = None) -> bool: + """ + 更新任务状态 + + Args: + user_id: 用户 ID + task_id: Celery 任务 ID + status: 新状态 (pending, running, success, failure) + result: 任务结果 + + Returns: + 是否成功 + """ + try: + client = await redis_service.get_client() + key = f"{self.KEY_PREFIX}{user_id}:tasks" + + # 获取现有任务信息 + data = await client.hget(key, task_id) + if data: + task_info = json.loads(data) + else: + task_info = {"task_id": task_id} + + task_info["status"] = status + task_info["updated_at"] = datetime.now(timezone.utc).isoformat() + if result is not None: + task_info["result"] = result + + await client.hset(key, task_id, json.dumps(task_info)) + return True + except Exception as e: + logger.error(f"更新任务状态失败: {e}") + return False + + async def get_user_tasks(self, user_id: str) -> List[Dict]: + """ + 获取用户的所有任务 + + Args: + user_id: 用户 ID + + Returns: + 任务列表 + """ + try: + client = await redis_service.get_client() + key = f"{self.KEY_PREFIX}{user_id}:tasks" + + tasks_data = await client.hgetall(key) + tasks = [] + for task_id, data in tasks_data.items(): + task_info = json.loads(data) + tasks.append(task_info) + + return tasks + except Exception as e: + logger.error(f"获取用户任务失败: {e}") + return [] + + async def get_pending_tasks(self, user_id: str) -> List[Dict]: + """ + 获取用户的待处理任务 + + Args: + user_id: 用户 ID + + Returns: + 待处理任务列表 + """ + tasks = await self.get_user_tasks(user_id) + return [t for t in tasks if t.get("status") in ("pending", "running")] + + async def check_tasks_status(self, user_id: str) -> Dict: + """ + 检查用户任务状态汇总 + + Args: + user_id: 用户 ID + + Returns: + 状态汇总 {total, pending, running, success, failure, all_completed} + """ + tasks = await self.get_user_tasks(user_id) + + status_counts = { + "total": len(tasks), + "pending": 0, + "running": 0, + "success": 0, + "failure": 0, + } + + for task in tasks: + status = task.get("status", "pending") + if status in status_counts: + status_counts[status] += 1 + + status_counts["all_completed"] = ( + status_counts["total"] > 0 and + status_counts["pending"] == 0 and + status_counts["running"] == 0 + ) + + return status_counts + + async def clear_user_tasks(self, user_id: str) -> bool: + """ + 清除用户的所有任务记录 + + Args: + user_id: 用户 ID + + Returns: + 是否成功 + """ + try: + client = await redis_service.get_client() + key = f"{self.KEY_PREFIX}{user_id}:tasks" + await client.delete(key) + return True + except Exception as e: + logger.error(f"清除用户任务失败: {e}") + return False + + async def remove_task(self, user_id: str, task_id: str) -> bool: + """ + 移除单个任务记录 + + Args: + user_id: 用户 ID + task_id: 任务 ID + + Returns: + 是否成功 + """ + try: + client = await redis_service.get_client() + key = f"{self.KEY_PREFIX}{user_id}:tasks" + await client.hdel(key, task_id) + return True + except Exception as e: + logger.error(f"移除任务失败: {e}") + return False + + +# 全局实例 +task_tracker = TaskTracker() diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index 51f799c..dd81923 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -3,9 +3,12 @@ """ import json import logging +import os import uuid from typing import Dict, List +from datetime import datetime, timezone +import redis from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import Session @@ -22,6 +25,34 @@ from agents.prompts.memory_prompts import ( logger = logging.getLogger(__name__) + +def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None): + """同步更新任务状态(在 Celery 任务中使用)""" + try: + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + r = redis.from_url(redis_url, decode_responses=True) + + key = f"task:user:{user_id}:tasks" + + # 获取现有任务信息 + data = r.hget(key, task_id) + if data: + task_info = json.loads(data) + else: + task_info = {"task_id": task_id} + + task_info["status"] = status + task_info["updated_at"] = datetime.now(timezone.utc).isoformat() + if result is not None: + task_info["result"] = result + + r.hset(key, task_id, json.dumps(task_info)) + r.expire(key, 3600) # 1小时过期 + + logger.info(f"任务状态已更新: task_id={task_id}, status={status}") + except Exception as e: + logger.error(f"更新任务状态失败: {e}") + STAGE_KEYWORDS = { "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], "education": ["上学", "学校", "老师", "同学", "教育", "大学"], @@ -115,7 +146,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): user_id: 用户 ID segment_ids: 段落 ID 列表 """ - logger.info(f"开始处理回忆录段落: user_id={user_id}, segments={len(segment_ids)}") + task_id = self.request.id + logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}") + + # 更新任务状态为 running + _update_task_status_sync(user_id, task_id, "running") try: db = SessionLocal() @@ -265,7 +300,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): seg.processed = True db.commit() - logger.info(f"回忆录处理完成: user_id={user_id}") + logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") + + # 更新任务状态为成功 + _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)}) + return {"status": "success", "processed": len(segments)} finally: @@ -273,6 +312,10 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): except Exception as e: logger.error(f"回忆录处理失败: {e}") + + # 更新任务状态为失败 + _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) + # 重试 raise self.retry(exc=e) diff --git a/api/test_conversation.py b/api/test_conversation.py index 89abf2f..1ef1a09 100644 --- a/api/test_conversation.py +++ b/api/test_conversation.py @@ -117,6 +117,25 @@ class ConversationTester: return {"message": "获取失败"} return resp.json() + async def get_tasks_status(self): + """获取任务状态""" + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.get( + f"{BASE_URL}/api/tasks/status", + headers={"Authorization": f"Bearer {self.token}"} + ) + if resp.status_code != 200: + return {"total": 0, "all_completed": True, "tasks": []} + return resp.json() + + async def clear_tasks(self): + """清除任务记录""" + async with httpx.AsyncClient(timeout=60.0) as client: + await client.delete( + f"{BASE_URL}/api/tasks/clear", + headers={"Authorization": f"Bearer {self.token}"} + ) + async def run_conversation(self): """运行多轮对话""" print(f"\n🔗 连接 WebSocket: {self.conversation_id}") @@ -187,18 +206,75 @@ class ConversationTester: except asyncio.TimeoutError: print("⏰ 等待结束确认超时(但后台处理可能仍在进行)") + async def wait_for_processing(self, max_wait_seconds: int = 300, check_interval: int = 3): + """ + 等待后台处理完成 + 通过查询 Celery 任务状态来判断处理是否完成 + + Args: + max_wait_seconds: 最大等待时间(秒),默认 5 分钟 + check_interval: 检查间隔(秒) + + Returns: + 是否在超时前完成 + """ + print(f"\n⏳ 等待后台任务完成(最多 {max_wait_seconds} 秒)...") + print(" 提示: 通过 Celery 任务状态 API 追踪任务进度") + + start_time = asyncio.get_event_loop().time() + + while True: + elapsed = asyncio.get_event_loop().time() - start_time + + if elapsed >= max_wait_seconds: + print(f"\n⚠️ 已等待 {max_wait_seconds} 秒,超时退出") + return False + + # 检查任务状态 + tasks_status = await self.get_tasks_status() + total = tasks_status.get("total", 0) + pending = tasks_status.get("pending", 0) + running = tasks_status.get("running", 0) + success = tasks_status.get("success", 0) + failure = tasks_status.get("failure", 0) + all_completed = tasks_status.get("all_completed", False) + + # 同时检查章节内容 + chapters = await self.get_chapters() + chapter_count = len(chapters) + total_content_length = sum(len(ch.get('content', '')) for ch in chapters) + + status_str = f"📊 总:{total} 等待:{pending} 运行:{running} 成功:{success} 失败:{failure}" + content_str = f"📚 章节:{chapter_count} 内容:{total_content_length}字符" + print(f" [{int(elapsed):3d}s] {status_str} | {content_str}") + + # 判断是否完成: + # 1. 有任务且全部完成 + # 2. 或者没有任务但有章节内容(兼容旧逻辑) + if total > 0 and all_completed: + print(f"\n✅ 所有任务已完成!共 {total} 个任务,等待 {int(elapsed)} 秒") + return True + + # 如果没有任务记录,等待一会儿任务提交 + if total == 0 and elapsed < 15: + await asyncio.sleep(check_interval) + continue + + # 如果长时间没有任务但有内容,也认为完成 + if total == 0 and chapter_count > 0 and elapsed > 30: + print(f"\n✅ 无待处理任务,已有 {chapter_count} 个章节。等待 {int(elapsed)} 秒") + return True + + await asyncio.sleep(check_interval) + async def check_results(self): """检查回忆录生成结果""" print(f"\n{'='*60}") print("📊 检查结果") print(f"{'='*60}") - # 等待后台处理完成(后台任务有5秒debounce + 多次LLM调用) - # 每条消息可能触发2-3次LLM调用,7条消息 = 约20次调用,每次约3秒 - print("\n⏳ 等待后台处理完成...") - for i in range(6): # 60秒,每10秒检查一次 - print(f" 等待中... {(i+1)*10}秒") - await asyncio.sleep(10) + # 等待后台处理完成(使用智能轮询) + await self.wait_for_processing(max_wait_seconds=180, check_interval=5) # 获取回忆录状态 print("\n📋 回忆录状态:") @@ -223,7 +299,8 @@ class ConversationTester: if chapters: for ch in chapters: is_new = "🆕" if ch.get("is_new") else "" - print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')}") + content_len = len(ch.get('content', '')) + print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')} ({content_len} 字符)") else: print(" (暂无章节)") @@ -248,7 +325,7 @@ class ConversationTester: content = ch.get('content', '') print(f"\n{'─'*60}") - print(f"【{title}】") + print(f"【{title}】({category})") print(f"{'─'*60}") if content: print(content) @@ -270,15 +347,19 @@ async def main(): # 1. 注册/登录 await tester.register_or_login() - # 2. 查看初始状态 + # 2. 清除旧的任务记录 + await tester.clear_tasks() + print("\n🧹 已清除旧的任务记录") + + # 3. 查看初始状态 print("\n📋 初始回忆录状态:") state = await tester.get_memoir_state() print(f" 当前阶段: {state.get('current_stage', 'N/A')}") - # 3. 运行多轮对话 + # 4. 运行多轮对话 await tester.run_conversation() - # 4. 检查结果 + # 5. 检查结果 await tester.check_results() except Exception as e: