feat: 新增任务状态管理和API支持

- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表
- 实现任务追踪服务,使用Redis存储任务状态
- 更新回忆录处理逻辑,集成Celery任务提交和状态更新
- 增强测试用例,支持任务状态的获取和清除功能
- 优化代码结构,提升可读性和维护性
This commit is contained in:
penghanyuan
2026-01-21 23:37:00 +01:00
parent 0591e9d7c1
commit 3f899aa16c
6 changed files with 451 additions and 23 deletions

View File

@@ -146,6 +146,30 @@ class BackgroundTaskRunner:
self.analyzer = ContentAnalyzer()
self.generator = MemoirGenerator()
async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
"""
提交 Celery 任务并记录
Returns:
任务 ID失败返回 None
"""
try:
from tasks.memoir_tasks import process_memoir_segments
from services.task_tracker import task_tracker
# 提交到 Celery
result = process_memoir_segments.delay(user_id, segment_ids)
task_id = result.id
# 记录任务
await task_tracker.add_task(user_id, task_id, "memoir")
logger.info(f"已提交 Celery 任务: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
return task_id
except Exception as e:
logger.error(f"提交 Celery 任务失败: {e}")
return None
async def queue_message(self, user_id: str, segment_id: str) -> None:
"""
将消息加入处理队列
@@ -167,20 +191,20 @@ class BackgroundTaskRunner:
await asyncio.sleep(self.debounce_seconds)
segment_ids = self._pending.pop(user_id, [])
if segment_ids:
# 提交到 Celery
from tasks.memoir_tasks import process_memoir_segments
process_memoir_segments.delay(user_id, segment_ids)
logger.info(f"已提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")
await self._submit_task(user_id, segment_ids)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"提交 Celery 任务失败: {e}")
logger.error(f"延迟提交任务失败: {e}")
self._timers[user_id] = asyncio.create_task(delayed_submit())
async def flush_pending(self, user_id: str) -> None:
async def flush_pending(self, user_id: str) -> str | None:
"""
立即提交用户的待处理任务(用于对话结束时)
Returns:
任务 ID无任务或失败返回 None
"""
# 取消定时器
if user_id in self._timers:
@@ -190,6 +214,5 @@ class BackgroundTaskRunner:
# 提交待处理任务
segment_ids = self._pending.pop(user_id, [])
if segment_ids:
from tasks.memoir_tasks import process_memoir_segments
process_memoir_segments.delay(user_id, segment_ids)
logger.info(f"立即提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")
return await self._submit_task(user_id, segment_ids)
return None

View File

@@ -52,7 +52,7 @@ else:
load_dotenv()
from database import init_db
from routers import websocket, chapters, books, conversations, auth, memoir_state
from routers import websocket, chapters, books, conversations, auth, memoir_state, tasks
# 初始化数据库
logger.info("正在初始化数据库...")
@@ -134,6 +134,7 @@ app.include_router(conversations.router)
app.include_router(chapters.router)
app.include_router(books.router)
app.include_router(memoir_state.router)
app.include_router(tasks.router) # 任务状态路由
@app.get("/")

78
api/routers/tasks.py Normal file
View File

@@ -0,0 +1,78 @@
"""
任务状态 API 路由
"""
from typing import List, Dict
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from database.models import User as UserModel
from middleware.auth import get_current_user
from services.task_tracker import task_tracker
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
class TaskInfo(BaseModel):
"""任务信息"""
task_id: str
task_type: str = "memoir"
status: str # pending, running, success, failure
created_at: str = None
updated_at: str = None
result: Dict = None
class TasksStatusResponse(BaseModel):
"""任务状态汇总响应"""
total: int
pending: int
running: int
success: int
failure: int
all_completed: bool
tasks: List[TaskInfo]
@router.get("/status", response_model=TasksStatusResponse)
async def get_tasks_status(
current_user: UserModel = Depends(get_current_user)
):
"""
获取当前用户的任务状态汇总
用于检查后台任务是否全部完成
"""
status = await task_tracker.check_tasks_status(current_user.id)
tasks = await task_tracker.get_user_tasks(current_user.id)
return TasksStatusResponse(
total=status["total"],
pending=status["pending"],
running=status["running"],
success=status["success"],
failure=status["failure"],
all_completed=status["all_completed"],
tasks=[TaskInfo(**t) for t in tasks]
)
@router.get("/pending")
async def get_pending_tasks(
current_user: UserModel = Depends(get_current_user)
):
"""
获取当前用户的待处理任务列表
"""
tasks = await task_tracker.get_pending_tasks(current_user.id)
return {"pending_tasks": tasks}
@router.delete("/clear")
async def clear_tasks(
current_user: UserModel = Depends(get_current_user)
):
"""
清除当前用户的所有任务记录
"""
await task_tracker.clear_user_tasks(current_user.id)
return {"message": "任务记录已清除"}

View File

@@ -0,0 +1,202 @@
"""
任务追踪服务:追踪 Celery 任务状态
"""
import json
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime, timezone
from services.redis_service import redis_service
logger = logging.getLogger(__name__)
class TaskTracker:
"""任务追踪器,使用 Redis 存储任务状态"""
# Redis key 前缀
KEY_PREFIX = "task:user:"
# 任务记录过期时间1小时
TASK_TTL = 3600
async def add_task(self, user_id: str, task_id: str, task_type: str = "memoir") -> bool:
"""
记录新任务
Args:
user_id: 用户 ID
task_id: Celery 任务 ID
task_type: 任务类型
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
task_info = {
"task_id": task_id,
"task_type": task_type,
"status": "pending",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 使用 hash 存储任务信息
await client.hset(key, task_id, json.dumps(task_info))
await client.expire(key, self.TASK_TTL)
logger.info(f"任务已记录: user_id={user_id}, task_id={task_id}")
return True
except Exception as e:
logger.error(f"记录任务失败: {e}")
return False
async def update_task_status(self, user_id: str, task_id: str, status: str, result: Any = None) -> bool:
"""
更新任务状态
Args:
user_id: 用户 ID
task_id: Celery 任务 ID
status: 新状态 (pending, running, success, failure)
result: 任务结果
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
# 获取现有任务信息
data = await client.hget(key, task_id)
if data:
task_info = json.loads(data)
else:
task_info = {"task_id": task_id}
task_info["status"] = status
task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
if result is not None:
task_info["result"] = result
await client.hset(key, task_id, json.dumps(task_info))
return True
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
return False
async def get_user_tasks(self, user_id: str) -> List[Dict]:
"""
获取用户的所有任务
Args:
user_id: 用户 ID
Returns:
任务列表
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
tasks_data = await client.hgetall(key)
tasks = []
for task_id, data in tasks_data.items():
task_info = json.loads(data)
tasks.append(task_info)
return tasks
except Exception as e:
logger.error(f"获取用户任务失败: {e}")
return []
async def get_pending_tasks(self, user_id: str) -> List[Dict]:
"""
获取用户的待处理任务
Args:
user_id: 用户 ID
Returns:
待处理任务列表
"""
tasks = await self.get_user_tasks(user_id)
return [t for t in tasks if t.get("status") in ("pending", "running")]
async def check_tasks_status(self, user_id: str) -> Dict:
"""
检查用户任务状态汇总
Args:
user_id: 用户 ID
Returns:
状态汇总 {total, pending, running, success, failure, all_completed}
"""
tasks = await self.get_user_tasks(user_id)
status_counts = {
"total": len(tasks),
"pending": 0,
"running": 0,
"success": 0,
"failure": 0,
}
for task in tasks:
status = task.get("status", "pending")
if status in status_counts:
status_counts[status] += 1
status_counts["all_completed"] = (
status_counts["total"] > 0 and
status_counts["pending"] == 0 and
status_counts["running"] == 0
)
return status_counts
async def clear_user_tasks(self, user_id: str) -> bool:
"""
清除用户的所有任务记录
Args:
user_id: 用户 ID
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
await client.delete(key)
return True
except Exception as e:
logger.error(f"清除用户任务失败: {e}")
return False
async def remove_task(self, user_id: str, task_id: str) -> bool:
"""
移除单个任务记录
Args:
user_id: 用户 ID
task_id: 任务 ID
Returns:
是否成功
"""
try:
client = await redis_service.get_client()
key = f"{self.KEY_PREFIX}{user_id}:tasks"
await client.hdel(key, task_id)
return True
except Exception as e:
logger.error(f"移除任务失败: {e}")
return False
# 全局实例
task_tracker = TaskTracker()

View File

@@ -3,9 +3,12 @@
"""
import json
import logging
import os
import uuid
from typing import Dict, List
from datetime import datetime, timezone
import redis
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -22,6 +25,34 @@ from agents.prompts.memory_prompts import (
logger = logging.getLogger(__name__)
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
"""同步更新任务状态(在 Celery 任务中使用)"""
try:
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = redis.from_url(redis_url, decode_responses=True)
key = f"task:user:{user_id}:tasks"
# 获取现有任务信息
data = r.hget(key, task_id)
if data:
task_info = json.loads(data)
else:
task_info = {"task_id": task_id}
task_info["status"] = status
task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
if result is not None:
task_info["result"] = result
r.hset(key, task_id, json.dumps(task_info))
r.expire(key, 3600) # 1小时过期
logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
@@ -115,7 +146,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
user_id: 用户 ID
segment_ids: 段落 ID 列表
"""
logger.info(f"开始处理回忆录段落: user_id={user_id}, segments={len(segment_ids)}")
task_id = self.request.id
logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
# 更新任务状态为 running
_update_task_status_sync(user_id, task_id, "running")
try:
db = SessionLocal()
@@ -265,7 +300,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
seg.processed = True
db.commit()
logger.info(f"回忆录处理完成: user_id={user_id}")
logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
# 更新任务状态为成功
_update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
return {"status": "success", "processed": len(segments)}
finally:
@@ -273,6 +312,10 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
except Exception as e:
logger.error(f"回忆录处理失败: {e}")
# 更新任务状态为失败
_update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
# 重试
raise self.retry(exc=e)

View File

@@ -117,6 +117,25 @@ class ConversationTester:
return {"message": "获取失败"}
return resp.json()
async def get_tasks_status(self):
"""获取任务状态"""
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.get(
f"{BASE_URL}/api/tasks/status",
headers={"Authorization": f"Bearer {self.token}"}
)
if resp.status_code != 200:
return {"total": 0, "all_completed": True, "tasks": []}
return resp.json()
async def clear_tasks(self):
"""清除任务记录"""
async with httpx.AsyncClient(timeout=60.0) as client:
await client.delete(
f"{BASE_URL}/api/tasks/clear",
headers={"Authorization": f"Bearer {self.token}"}
)
async def run_conversation(self):
"""运行多轮对话"""
print(f"\n🔗 连接 WebSocket: {self.conversation_id}")
@@ -187,18 +206,75 @@ class ConversationTester:
except asyncio.TimeoutError:
print("⏰ 等待结束确认超时(但后台处理可能仍在进行)")
async def wait_for_processing(self, max_wait_seconds: int = 300, check_interval: int = 3):
"""
等待后台处理完成
通过查询 Celery 任务状态来判断处理是否完成
Args:
max_wait_seconds: 最大等待时间(秒),默认 5 分钟
check_interval: 检查间隔(秒)
Returns:
是否在超时前完成
"""
print(f"\n⏳ 等待后台任务完成(最多 {max_wait_seconds} 秒)...")
print(" 提示: 通过 Celery 任务状态 API 追踪任务进度")
start_time = asyncio.get_event_loop().time()
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed >= max_wait_seconds:
print(f"\n⚠️ 已等待 {max_wait_seconds} 秒,超时退出")
return False
# 检查任务状态
tasks_status = await self.get_tasks_status()
total = tasks_status.get("total", 0)
pending = tasks_status.get("pending", 0)
running = tasks_status.get("running", 0)
success = tasks_status.get("success", 0)
failure = tasks_status.get("failure", 0)
all_completed = tasks_status.get("all_completed", False)
# 同时检查章节内容
chapters = await self.get_chapters()
chapter_count = len(chapters)
total_content_length = sum(len(ch.get('content', '')) for ch in chapters)
status_str = f"📊 总:{total} 等待:{pending} 运行:{running} 成功:{success} 失败:{failure}"
content_str = f"📚 章节:{chapter_count} 内容:{total_content_length}字符"
print(f" [{int(elapsed):3d}s] {status_str} | {content_str}")
# 判断是否完成:
# 1. 有任务且全部完成
# 2. 或者没有任务但有章节内容(兼容旧逻辑)
if total > 0 and all_completed:
print(f"\n✅ 所有任务已完成!共 {total} 个任务,等待 {int(elapsed)}")
return True
# 如果没有任务记录,等待一会儿任务提交
if total == 0 and elapsed < 15:
await asyncio.sleep(check_interval)
continue
# 如果长时间没有任务但有内容,也认为完成
if total == 0 and chapter_count > 0 and elapsed > 30:
print(f"\n✅ 无待处理任务,已有 {chapter_count} 个章节。等待 {int(elapsed)}")
return True
await asyncio.sleep(check_interval)
async def check_results(self):
"""检查回忆录生成结果"""
print(f"\n{'='*60}")
print("📊 检查结果")
print(f"{'='*60}")
# 等待后台处理完成(后台任务有5秒debounce + 多次LLM调用
# 每条消息可能触发2-3次LLM调用7条消息 = 约20次调用每次约3秒
print("\n⏳ 等待后台处理完成...")
for i in range(6): # 60秒每10秒检查一次
print(f" 等待中... {(i+1)*10}")
await asyncio.sleep(10)
# 等待后台处理完成(使用智能轮询
await self.wait_for_processing(max_wait_seconds=180, check_interval=5)
# 获取回忆录状态
print("\n📋 回忆录状态:")
@@ -223,7 +299,8 @@ class ConversationTester:
if chapters:
for ch in chapters:
is_new = "🆕" if ch.get("is_new") else ""
print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')}")
content_len = len(ch.get('content', ''))
print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')} ({content_len} 字符)")
else:
print(" (暂无章节)")
@@ -248,7 +325,7 @@ class ConversationTester:
content = ch.get('content', '')
print(f"\n{''*60}")
print(f"{title}")
print(f"{title}({category})")
print(f"{''*60}")
if content:
print(content)
@@ -270,15 +347,19 @@ async def main():
# 1. 注册/登录
await tester.register_or_login()
# 2. 查看初始状态
# 2. 清除旧的任务记录
await tester.clear_tasks()
print("\n🧹 已清除旧的任务记录")
# 3. 查看初始状态
print("\n📋 初始回忆录状态:")
state = await tester.get_memoir_state()
print(f" 当前阶段: {state.get('current_stage', 'N/A')}")
# 3. 运行多轮对话
# 4. 运行多轮对话
await tester.run_conversation()
# 4. 检查结果
# 5. 检查结果
await tester.check_results()
except Exception as e: