feat: 新增任务状态管理和API支持
- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表 - 实现任务追踪服务,使用Redis存储任务状态 - 更新回忆录处理逻辑,集成Celery任务提交和状态更新 - 增强测试用例,支持任务状态的获取和清除功能 - 优化代码结构,提升可读性和维护性
This commit is contained in:
@@ -3,9 +3,12 @@
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,6 +25,34 @@ from agents.prompts.memory_prompts import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
|
||||
"""同步更新任务状态(在 Celery 任务中使用)"""
|
||||
try:
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
|
||||
key = f"task:user:{user_id}:tasks"
|
||||
|
||||
# 获取现有任务信息
|
||||
data = r.hget(key, task_id)
|
||||
if data:
|
||||
task_info = json.loads(data)
|
||||
else:
|
||||
task_info = {"task_id": task_id}
|
||||
|
||||
task_info["status"] = status
|
||||
task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
if result is not None:
|
||||
task_info["result"] = result
|
||||
|
||||
r.hset(key, task_id, json.dumps(task_info))
|
||||
r.expire(key, 3600) # 1小时过期
|
||||
|
||||
logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
|
||||
except Exception as e:
|
||||
logger.error(f"更新任务状态失败: {e}")
|
||||
|
||||
STAGE_KEYWORDS = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
||||
@@ -115,7 +146,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
user_id: 用户 ID
|
||||
segment_ids: 段落 ID 列表
|
||||
"""
|
||||
logger.info(f"开始处理回忆录段落: user_id={user_id}, segments={len(segment_ids)}")
|
||||
task_id = self.request.id
|
||||
logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
|
||||
|
||||
# 更新任务状态为 running
|
||||
_update_task_status_sync(user_id, task_id, "running")
|
||||
|
||||
try:
|
||||
db = SessionLocal()
|
||||
@@ -265,7 +300,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
seg.processed = True
|
||||
|
||||
db.commit()
|
||||
logger.info(f"回忆录处理完成: user_id={user_id}")
|
||||
logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
|
||||
|
||||
# 更新任务状态为成功
|
||||
_update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
|
||||
|
||||
return {"status": "success", "processed": len(segments)}
|
||||
|
||||
finally:
|
||||
@@ -273,6 +312,10 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回忆录处理失败: {e}")
|
||||
|
||||
# 更新任务状态为失败
|
||||
_update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
|
||||
|
||||
# 重试
|
||||
raise self.retry(exc=e)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user