Files
life-echo/api/services/task_tracker.py
penghanyuan 3f899aa16c feat: 新增任务状态管理和API支持
- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表
- 实现任务追踪服务,使用Redis存储任务状态
- 更新回忆录处理逻辑,集成Celery任务提交和状态更新
- 增强测试用例,支持任务状态的获取和清除功能
- 优化代码结构,提升可读性和维护性
2026-01-21 23:37:00 +01:00

203 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
任务追踪服务:追踪 Celery 任务状态
"""
import json
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime, timezone
from services.redis_service import redis_service
logger = logging.getLogger(__name__)
class TaskTracker:
"""任务追踪器,使用 Redis 存储任务状态"""
# Redis key 前缀
KEY_PREFIX = "task:user:"
# 任务记录过期时间1小时
TASK_TTL = 3600
async def add_task(self, user_id: str, task_id: str, task_type: str = "memoir") -> bool:
"""
记录新任务
Args:
user_id: 用户 ID
task_id: Celery 任务 ID
task_type: 任务类型
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
task_info = {
"task_id": task_id,
"task_type": task_type,
"status": "pending",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 使用 hash 存储任务信息
await client.hset(key, task_id, json.dumps(task_info))
await client.expire(key, self.TASK_TTL)
logger.info(f"任务已记录: user_id={user_id}, task_id={task_id}")
return True
except Exception as e:
logger.error(f"记录任务失败: {e}")
return False
async def update_task_status(self, user_id: str, task_id: str, status: str, result: Any = None) -> bool:
"""
更新任务状态
Args:
user_id: 用户 ID
task_id: Celery 任务 ID
status: 新状态 (pending, running, success, failure)
result: 任务结果
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
# 获取现有任务信息
data = await client.hget(key, task_id)
if data:
task_info = json.loads(data)
else:
task_info = {"task_id": task_id}
task_info["status"] = status
task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
if result is not None:
task_info["result"] = result
await client.hset(key, task_id, json.dumps(task_info))
return True
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
return False
async def get_user_tasks(self, user_id: str) -> List[Dict]:
"""
获取用户的所有任务
Args:
user_id: 用户 ID
Returns:
任务列表
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
tasks_data = await client.hgetall(key)
tasks = []
for task_id, data in tasks_data.items():
task_info = json.loads(data)
tasks.append(task_info)
return tasks
except Exception as e:
logger.error(f"获取用户任务失败: {e}")
return []
async def get_pending_tasks(self, user_id: str) -> List[Dict]:
"""
获取用户的待处理任务
Args:
user_id: 用户 ID
Returns:
待处理任务列表
"""
tasks = await self.get_user_tasks(user_id)
return [t for t in tasks if t.get("status") in ("pending", "running")]
async def check_tasks_status(self, user_id: str) -> Dict:
"""
检查用户任务状态汇总
Args:
user_id: 用户 ID
Returns:
状态汇总 {total, pending, running, success, failure, all_completed}
"""
tasks = await self.get_user_tasks(user_id)
status_counts = {
"total": len(tasks),
"pending": 0,
"running": 0,
"success": 0,
"failure": 0,
}
for task in tasks:
status = task.get("status", "pending")
if status in status_counts:
status_counts[status] += 1
status_counts["all_completed"] = (
status_counts["total"] > 0 and
status_counts["pending"] == 0 and
status_counts["running"] == 0
)
return status_counts
async def clear_user_tasks(self, user_id: str) -> bool:
"""
清除用户的所有任务记录
Args:
user_id: 用户 ID
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
await client.delete(key)
return True
except Exception as e:
logger.error(f"清除用户任务失败: {e}")
return False
async def remove_task(self, user_id: str, task_id: str) -> bool:
"""
移除单个任务记录
Args:
user_id: 用户 ID
task_id: 任务 ID
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
await client.hdel(key, task_id)
return True
except Exception as e:
logger.error(f"移除任务失败: {e}")
return False
# 全局实例
task_tracker = TaskTracker()