feat: 添加Redis支持和Celery任务处理

- 新增Redis服务模块用于会话状态存储和缓存
- 集成Celery用于后台任务处理
- 更新Docker Compose配置以支持开发环境
- 优化API以支持异步调用和Redis会话存储
- 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
penghanyuan
2026-01-21 23:06:47 +01:00
parent 44bd478c1e
commit dbbb924625
16 changed files with 1339 additions and 309 deletions

7
api/tasks/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
Celery 任务模块
"""
from .celery_app import celery_app
from .memoir_tasks import process_memoir_segments
__all__ = ["celery_app", "process_memoir_segments"]

58
api/tasks/celery_app.py Normal file
View File

@@ -0,0 +1,58 @@
"""
Celery 应用配置
"""
import os
from celery import Celery
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
# Redis URL
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# 创建 Celery 应用
celery_app = Celery(
"life_echo",
broker=REDIS_URL,
backend=REDIS_URL,
include=["tasks.memoir_tasks"]
)
# Celery 配置
celery_app.conf.update(
# 任务序列化
task_serializer="json",
accept_content=["json"],
result_serializer="json",
# 时区
timezone="UTC",
enable_utc=True,
# 任务结果过期时间1小时
result_expires=3600,
# 任务执行设置
task_soft_time_limit=300, # 5分钟软超时
task_time_limit=600, # 10分钟硬超时
# 并发设置
worker_prefetch_multiplier=1, # 每次只预取一个任务
worker_concurrency=4, # 并发 worker 数量
# 任务重试设置
task_acks_late=True, # 任务完成后再确认
task_reject_on_worker_lost=True, # worker 丢失时拒绝任务
# 不设置自定义队列路由,使用 Celery 默认队列
)
# 定时任务配置(如果需要)
celery_app.conf.beat_schedule = {
# 示例:每小时清理过期会话
# "cleanup-expired-sessions": {
# "task": "tasks.cleanup.cleanup_sessions",
# "schedule": 3600.0,
# },
}

345
api/tasks/memoir_tasks.py Normal file
View File

@@ -0,0 +1,345 @@
"""
回忆录处理 Celery 任务
"""
import json
import logging
import uuid
from typing import Dict, List
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = logging.getLogger(__name__)
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
def _detect_stage(user_message: str, fallback_stage: str) -> str:
"""检测消息所属阶段"""
message = user_message.lower()
for stage, keywords in STAGE_KEYWORDS.items():
if any(word in message for word in keywords):
return stage
return fallback_stage
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
"""将数据库模型转换为 Schema"""
return MemoirStateSchema.model_validate(
{
"stage_order": model.stage_order or default_state().stage_order,
"current_stage": model.current_stage,
"covered_stages": model.covered_stages or [],
"slots": model.slots if isinstance(model.slots, dict) else default_state().slots,
}
)
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
"""同步获取或创建状态"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if state:
return _coerce_state(state)
default = default_state()
state = MemoirState(
id=str(uuid.uuid4()),
user_id=user_id,
stage_order=default.stage_order,
current_stage=default.current_stage,
covered_stages=default.covered_stages,
slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()},
)
db.add(state)
db.commit()
db.refresh(state)
return _coerce_state(state)
def _update_slot_sync(
user_id: str,
stage: str,
slot_name: str,
snippet: str,
segment_ids: List[str],
db: Session,
) -> MemoirStateSchema:
"""同步更新 slot"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if not state:
_get_or_create_state_sync(user_id, db)
result = db.execute(stmt)
state = result.scalar_one()
slots: Dict[str, Dict] = state.slots or {}
stage_slots = slots.get(stage, {})
existing = stage_slots.get(slot_name, {})
merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids})
stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump()
slots[stage] = stage_slots
state.slots = slots
state.current_stage = state.current_stage or stage
db.commit()
db.refresh(state)
return _coerce_state(state)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
"""
处理回忆录段落的 Celery 任务
Args:
user_id: 用户 ID
segment_ids: 段落 ID 列表
"""
logger.info(f"开始处理回忆录段落: user_id={user_id}, segments={len(segment_ids)}")
try:
db = SessionLocal()
try:
# 获取段落
stmt = select(Segment).where(Segment.id.in_(segment_ids))
result = db.execute(stmt)
segments = result.scalars().all()
if not segments:
logger.warning(f"未找到段落: {segment_ids}")
return {"status": "no_segments"}
# 获取用户状态
state = _get_or_create_state_sync(user_id, db)
llm = llm_service.get_llm()
# 按阶段分组处理
stage_to_segments: Dict[str, List[Segment]] = {}
for segment in segments:
text = segment.transcript_text
detected_stage = _detect_stage(text, state.current_stage)
# 尝试使用 LLM 提取信息
extracted_slots = {}
if llm:
try:
prompt = get_state_extraction_prompt(
user_message=text,
current_stage=state.current_stage,
stage_slots=state.slots.get(detected_stage, {}),
)
response = llm.invoke(prompt)
content = response.content.strip()
parsed = json.loads(content)
detected_stage = parsed.get("detected_stage", detected_stage)
extracted_slots = parsed.get("slots", {}) or {}
except (json.JSONDecodeError, Exception) as e:
logger.warning(f"LLM 解析失败: {e}")
# 更新 slots
for slot_name, snippet in extracted_slots.items():
state = _update_slot_sync(
user_id=user_id,
stage=detected_stage,
slot_name=slot_name,
snippet=snippet,
segment_ids=[segment.id],
db=db,
)
stage_to_segments.setdefault(detected_stage, []).append(segment)
# 生成章节内容
for stage, stage_segments in stage_to_segments.items():
segment_texts = [seg.transcript_text for seg in stage_segments]
combined_text = "\n\n".join(segment_texts)
source_ids = [seg.id for seg in stage_segments]
# 查找或创建章节
stmt_chapter = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
)
result_chapter = db.execute(stmt_chapter)
chapter = result_chapter.scalar_one_or_none()
# 获取 slot snippets
slot_snippets = {
key: value.snippet
for key, value in (state.slots.get(stage, {}) or {}).items()
if value.snippet
}
# 生成标题和内容
title = chapter.title if chapter else f"{stage} 回忆"
existing_content = chapter.content if chapter else ""
narrative = combined_text
if llm:
try:
if not chapter:
title_prompt = get_creative_title_prompt(
stage=stage,
emotion="neutral",
slots=slot_snippets
)
title_response = llm.invoke(title_prompt)
title = title_response.content.strip().strip('"')
narrative_prompt = get_narrative_prompt(
stage=stage,
slots=slot_snippets,
new_content=combined_text,
existing_content=existing_content,
)
narrative_response = llm.invoke(narrative_prompt)
narrative = narrative_response.content.strip()
except Exception as e:
logger.warning(f"LLM 生成失败: {e}")
if existing_content:
narrative = f"{existing_content}\n\n{combined_text}"
# 更新或创建章节
if chapter:
chapter.content = narrative
chapter.title = title
chapter.is_new = True
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
else:
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
content=narrative,
order_index=999,
status="completed",
category=stage,
images=[],
is_new=True,
source_segments=source_ids,
)
db.add(chapter)
db.flush()
# 更新 Book
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
result_book = db.execute(stmt_book)
book = result_book.scalar_one_or_none()
if not book:
book = Book(
id=str(uuid.uuid4()),
user_id=user_id,
title="我的回忆录",
total_pages=0,
total_words=0,
cover_image_url=None,
)
db.add(book)
book.has_update = True
book.last_update_chapter_id = chapter.id
# 标记段落为已处理
for seg in segments:
seg.processed = True
db.commit()
logger.info(f"回忆录处理完成: user_id={user_id}")
return {"status": "success", "processed": len(segments)}
finally:
db.close()
except Exception as e:
logger.error(f"回忆录处理失败: {e}")
# 重试
raise self.retry(exc=e)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
"""
单独生成章节内容的任务(用于实时更新)
Args:
user_id: 用户 ID
stage: 阶段
new_content: 新内容
"""
logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")
try:
db = SessionLocal()
try:
llm = llm_service.get_llm()
# 查找章节
stmt = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
)
result = db.execute(stmt)
chapter = result.scalar_one_or_none()
existing_content = chapter.content if chapter else ""
if llm:
prompt = get_narrative_prompt(
stage=stage,
slots={},
new_content=new_content,
existing_content=existing_content,
)
response = llm.invoke(prompt)
narrative = response.content.strip()
else:
narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content
if chapter:
chapter.content = narrative
chapter.is_new = True
else:
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
title=f"{stage} 回忆",
content=narrative,
order_index=999,
status="completed",
category=stage,
images=[],
is_new=True,
source_segments=[],
)
db.add(chapter)
db.commit()
return {"status": "success"}
finally:
db.close()
except Exception as e:
logger.error(f"章节生成失败: {e}")
raise self.retry(exc=e)