"""
Task tracking service: records the status of Celery tasks in Redis.
"""
import json
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime, timezone

from services.redis_service import redis_service

logger = logging.getLogger(__name__)


class TaskTracker:
    """Track per-user task status in Redis.

    All tasks of one user are stored in a single Redis hash at
    ``{KEY_PREFIX}{user_id}:tasks``. Each hash field is a task ID and each
    value is a JSON-encoded dict describing that task.
    """

    # Redis key prefix
    KEY_PREFIX = "task:user:"
    # Expiry of a user's task hash, in seconds (1 hour). The TTL applies to the
    # whole hash (all of the user's tasks at once), not to individual tasks,
    # and is refreshed whenever a task is added or updated.
    TASK_TTL = 3600

    def _tasks_key(self, user_id: str) -> str:
        """Return the Redis hash key that holds *user_id*'s task records."""
        return f"{self.KEY_PREFIX}{user_id}:tasks"

    async def add_task(self, user_id: str, task_id: str, task_type: str = "memoir") -> bool:
        """
        Record a new task with status "pending".

        Overwrites any existing record with the same task ID and refreshes the
        TTL of the user's task hash.

        Args:
            user_id: User ID
            task_id: Celery task ID
            task_type: Task type label stored with the record

        Returns:
            True on success; False if any error occurred (the error is logged).
        """
        try:
            client = await redis_service.get_client()
            key = self._tasks_key(user_id)

            task_info = {
                "task_id": task_id,
                "task_type": task_type,
                "status": "pending",
                "created_at": datetime.now(timezone.utc).isoformat(),
            }

            # Store each task as one field of the user's hash.
            await client.hset(key, task_id, json.dumps(task_info))
            await client.expire(key, self.TASK_TTL)

            logger.info("任务已记录: user_id=%s, task_id=%s", user_id, task_id)
            return True
        except Exception as e:
            logger.exception("记录任务失败: %s", e)
            return False

    async def update_task_status(self, user_id: str, task_id: str, status: str, result: Any = None) -> bool:
        """
        Update the status (and optionally the result) of a task.

        If no record exists for *task_id*, a minimal record containing only
        ``task_id``, ``status``, ``updated_at`` (and ``result``) is created.

        Args:
            user_id: User ID
            task_id: Celery task ID
            status: New status (pending, running, success, failure)
            result: Task result; stored only when not None. Must be
                JSON-serializable, otherwise the update fails and False is
                returned.

        Returns:
            True on success; False if any error occurred (the error is logged).
        """
        try:
            client = await redis_service.get_client()
            key = self._tasks_key(user_id)

            # Merge into the existing record when present; otherwise start a
            # minimal one (e.g. the hash expired or add_task was never called).
            data = await client.hget(key, task_id)
            task_info = json.loads(data) if data else {"task_id": task_id}

            task_info["status"] = status
            task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
            if result is not None:
                task_info["result"] = result

            await client.hset(key, task_id, json.dumps(task_info))
            # HSET on a missing key creates it with no expiry; re-apply the TTL
            # so the hash can never outlive TASK_TTL and leak in Redis.
            await client.expire(key, self.TASK_TTL)
            return True
        except Exception as e:
            logger.exception("更新任务状态失败: %s", e)
            return False

    async def get_user_tasks(self, user_id: str) -> List[Dict]:
        """
        Return all task records of a user.

        Entries whose stored value cannot be decoded into a JSON object are
        skipped (with a warning) so that one bad entry does not hide the rest.

        Args:
            user_id: User ID

        Returns:
            List of task dicts; an empty list if the Redis call fails.
        """
        try:
            client = await redis_service.get_client()
            tasks_data = await client.hgetall(self._tasks_key(user_id))

            tasks: List[Dict] = []
            for task_id, data in tasks_data.items():
                try:
                    task_info = json.loads(data)
                except (TypeError, ValueError):
                    task_info = None
                if not isinstance(task_info, dict):
                    logger.warning("跳过无法解析的任务记录: user_id=%s, task_id=%s", user_id, task_id)
                    continue
                tasks.append(task_info)

            return tasks
        except Exception as e:
            logger.exception("获取用户任务失败: %s", e)
            return []

    async def get_pending_tasks(self, user_id: str) -> List[Dict]:
        """
        Return the user's tasks that have not finished yet.

        Args:
            user_id: User ID

        Returns:
            Tasks whose status is "pending" or "running".
        """
        tasks = await self.get_user_tasks(user_id)
        return [t for t in tasks if t.get("status") in ("pending", "running")]

    async def check_tasks_status(self, user_id: str) -> Dict:
        """
        Summarize the status of a user's tasks.

        A task with no "status" field is counted as "pending". Tasks with a
        status outside the four known values count toward "total" only.

        Args:
            user_id: User ID

        Returns:
            Dict with keys total, pending, running, success, failure (counts)
            and all_completed (True when there is at least one task and none
            is pending or running).
        """
        tasks = await self.get_user_tasks(user_id)

        status_counts = {
            "total": len(tasks),
            "pending": 0,
            "running": 0,
            "success": 0,
            "failure": 0,
        }

        for task in tasks:
            status = task.get("status", "pending")
            if status in status_counts:
                status_counts[status] += 1

        status_counts["all_completed"] = (
            status_counts["total"] > 0
            and status_counts["pending"] == 0
            and status_counts["running"] == 0
        )

        return status_counts

    async def clear_user_tasks(self, user_id: str) -> bool:
        """
        Delete all task records of a user.

        Args:
            user_id: User ID

        Returns:
            True on success; False if any error occurred (the error is logged).
        """
        try:
            client = await redis_service.get_client()
            await client.delete(self._tasks_key(user_id))
            return True
        except Exception as e:
            logger.exception("清除用户任务失败: %s", e)
            return False

    async def remove_task(self, user_id: str, task_id: str) -> bool:
        """
        Delete a single task record.

        Args:
            user_id: User ID
            task_id: Task ID

        Returns:
            True on success (including when the task did not exist); False if
            any error occurred (the error is logged).
        """
        try:
            client = await redis_service.get_client()
            await client.hdel(self._tasks_key(user_id), task_id)
            return True
        except Exception as e:
            logger.exception("移除任务失败: %s", e)
            return False


# Global instance
task_tracker = TaskTracker()