Files
life-echo/api/tasks/memoir_tasks.py
penghanyuan 8b4a058640 feat: 优化回忆录内容处理和章节分类逻辑
- 更新 get_system_prompt 函数,增强对话内容的核心信息提炼和分类能力,确保只保留与人生经历相关的实质内容。
- 修改 _classify_chapter_category 函数,增加对无实质回忆录价值内容的处理,返回 None 以跳过无效段落。
- 在 Android 客户端中,更新章节阅读视图以移除内嵌章节标题,提升排版一致性。
- 新增 TextUtils 工具函数,专门用于移除 LLM 生成的内嵌章节标题,确保正文内容的流畅性。
2026-03-02 19:47:32 +01:00

500 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录处理 Celery 任务
"""
import json
import logging
import os
import uuid
from typing import Dict, List
from datetime import datetime, timezone
import redis
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState, User
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
get_chapter_classification_prompt,
STAGE_TO_ORDER,
CHAPTER_CATEGORIES,
)
from agents.prompts.profile_prompts import format_user_profile_context
logger = logging.getLogger(__name__)
def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool:
    """Try to take the distributed lock guarding one user's chapter.

    The lock is a Redis key written with ``SET ... NX EX``: the write only
    succeeds when the key does not exist yet, and the key expires by itself
    after ``timeout`` seconds even if it is never released.

    Args:
        user_id: Owner of the chapter being written.
        stage: Chapter key the lock is scoped to.
        timeout: Lock lifetime in seconds.

    Returns:
        True when the lock was acquired, False when the key already exists.
    """
    r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
    lock_key = f"lock:chapter:{user_id}:{stage}"
    # redis-py answers None (not False) when NX blocks the write; normalise so
    # the function honours its declared ``bool`` return type.
    return bool(r.set(lock_key, "1", nx=True, ex=timeout))
def _release_chapter_lock(user_id: str, stage: str):
    """Delete the Redis key that represents the chapter lock for this user and chapter."""
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    client = redis.from_url(redis_url)
    client.delete(f"lock:chapter:{user_id}:{stage}")
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
    """Synchronously record a task's status in Redis (for use inside Celery tasks).

    Each user has one Redis hash holding a JSON document per task. The entry
    for ``task_id`` is read, its ``status`` / ``updated_at`` (and ``result``,
    when given) are overwritten, and it is written back. Any failure is
    logged and swallowed so status bookkeeping never breaks the task itself.

    Args:
        user_id: User whose task hash is updated.
        task_id: Field name inside the hash.
        status: New status string.
        result: Optional payload stored under ``"result"``.
    """
    try:
        client = redis.from_url(
            os.getenv("REDIS_URL", "redis://localhost:6379/0"),
            decode_responses=True,
        )
        hash_key = f"task:user:{user_id}:tasks"

        # Start from the stored entry when there is one, otherwise a fresh record.
        stored = client.hget(hash_key, task_id)
        task_info = json.loads(stored) if stored else {"task_id": task_id}

        task_info["status"] = status
        task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
        if result is not None:
            task_info["result"] = result

        client.hset(hash_key, task_id, json.dumps(task_info))
        # The TTL applies to the whole hash and is reset to one hour on every update.
        client.expire(hash_key, 3600)
        logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
    except Exception as e:
        logger.error(f"更新任务状态失败: {e}")
# Keywords that hint at which of the five life stages a message belongs to.
# Insertion order matters: _detect_stage returns the first stage whose
# keyword list produces a hit.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
# 5-stage -> default 8-category mapping (fallback used when LLM
# classification fails).
_STAGE_TO_DEFAULT_CATEGORY = {
    "childhood": "childhood",
    "education": "education",
    "career": "career_early",
    "family": "family",
    "belief": "beliefs",
}
def _detect_stage(user_message: str, fallback_stage: str) -> str:
"""检测消息所属的 5-stage 阶段(用于状态跟踪)"""
message = user_message.lower()
for stage, keywords in STAGE_KEYWORDS.items():
if any(word in message for word in keywords):
return stage
return fallback_stage
def _classify_chapter_category(text: str, fallback_stage: str, llm=None) -> str | None:
    """
    Assign a piece of content to one of the chapter categories.

    The LLM is asked first. If it answers ``none`` the content is considered
    to have no memoir value and ``None`` is returned. If it answers with a
    known category, that category is returned. In every other case (no LLM,
    unknown answer, or an exception) the 5-stage keyword detection is mapped
    onto a default category.
    """
    if llm:
        try:
            reply = llm.invoke(get_chapter_classification_prompt(text))
            label = reply.content.strip().lower()
            if label == "none":
                logger.info(f"LLM 判定内容无回忆录价值,跳过: {text[:80]}...")
                return None
            if label in CHAPTER_CATEGORIES:
                return label
        except Exception as e:
            logger.warning(f"LLM 章节分类失败: {e}")

    # Keyword fallback: detected stage first, then the caller's fallback
    # stage, and finally "childhood" when neither has a default category.
    stage = _detect_stage(text, fallback_stage)
    if stage in _STAGE_TO_DEFAULT_CATEGORY:
        return _STAGE_TO_DEFAULT_CATEGORY[stage]
    return _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood")
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
    """Convert a MemoirState database row into a validated MemoirStateSchema.

    A falsy ``stage_order`` and a non-dict ``slots`` value are replaced by the
    corresponding values from ``default_state()``; a falsy ``covered_stages``
    becomes an empty list.
    """
    stage_order = model.stage_order or default_state().stage_order
    current_stage = model.current_stage
    covered_stages = model.covered_stages or []
    if isinstance(model.slots, dict):
        slots = model.slots
    else:
        slots = default_state().slots

    payload = {
        "stage_order": stage_order,
        "current_stage": current_stage,
        "covered_stages": covered_stages,
        "slots": slots,
    }
    return MemoirStateSchema.model_validate(payload)
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Return the user's memoir state, creating and committing a default row if none exists."""
    lookup = select(MemoirState).where(MemoirState.user_id == user_id)
    existing = db.execute(lookup).scalar_one_or_none()
    if existing:
        return _coerce_state(existing)

    template = default_state()
    # Slot objects are dumped to plain dicts before being stored on the row.
    serialized_slots = {
        stage_name: {
            slot_name: slot.model_dump() for slot_name, slot in stage_slots.items()
        }
        for stage_name, stage_slots in template.slots.items()
    }
    created = MemoirState(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(created)
    db.commit()
    db.refresh(created)
    return _coerce_state(created)
def _update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
) -> MemoirStateSchema:
    """Synchronously write one slot of the user's memoir state and commit.

    The slot's snippet is replaced by ``snippet`` and its ``segment_ids`` are
    merged with the ones already stored (duplicates removed, first-seen order
    kept). ``current_stage`` is set to ``stage`` only when it was empty.

    Args:
        user_id: User whose state is updated; the row is created if missing.
        stage: Stage the slot belongs to.
        slot_name: Slot key inside the stage.
        snippet: New snippet text for the slot.
        segment_ids: Segment IDs that support the snippet.
        db: Open SQLAlchemy session; this function commits on it.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    stmt = select(MemoirState).where(MemoirState.user_id == user_id)
    state = db.execute(stmt).scalar_one_or_none()
    if not state:
        _get_or_create_state_sync(user_id, db)
        state = db.execute(stmt).scalar_one()

    # Work on copies rather than mutating the loaded structure in place.
    # NOTE(review): if ``MemoirState.slots`` is a plain JSON column (model not
    # visible here), SQLAlchemy does not detect in-place mutation, and
    # re-assigning the same mutated object compares equal to the "old" value,
    # so no UPDATE would be emitted. Assigning a new dict avoids that.
    slots: Dict[str, Dict] = dict(state.slots or {})
    stage_slots = dict(slots.get(stage) or {})
    existing = stage_slots.get(slot_name) or {}

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage] = stage_slots

    state.slots = slots
    state.current_stage = state.current_stage or stage
    db.commit()
    db.refresh(state)
    return _coerce_state(state)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
    """
    Celery task that turns transcript segments into memoir chapters.

    Steps, in order:
      1. Load the segments plus the user's memoir state and profile.
      2. Per segment: extract slots for the 5-stage state tracking, then
         classify the text into a chapter category (segments the classifier
         rejects are skipped).
      3. Per category: while holding the chapter lock, append a narrative to
         the active chapter (or create a chapter) and flag the user's book as
         updated.
      4. Mark every loaded segment as processed and commit.

    Args:
        user_id: User ID.
        segment_ids: IDs of the segments to process.

    Returns:
        ``{"status": "no_segments"}`` when none of the IDs can be found,
        otherwise ``{"status": "success", "processed": <segment count>}``.

    Raises:
        celery.exceptions.Retry: when a chapter lock is already held (retry
            after 10 seconds) or when processing fails (task default delay).
    """
    # Local import: celery is already a dependency of this module, and keeping
    # the import here leaves the file-level import block untouched.
    from celery.exceptions import Retry

    task_id = self.request.id
    logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
    # Mark the task as running.
    _update_task_status_sync(user_id, task_id, "running")
    try:
        db = SessionLocal()
        try:
            # Load the segments.
            stmt = select(Segment).where(Segment.id.in_(segment_ids))
            segments = db.execute(stmt).scalars().all()
            if not segments:
                logger.warning(f"未找到段落: {segment_ids}")
                # Record a terminal status so the task does not stay "running".
                _update_task_status_sync(user_id, task_id, "success", {"processed": 0})
                return {"status": "no_segments"}

            # Load user state and profile.
            state = _get_or_create_state_sync(user_id, db)
            llm = llm_service.get_llm()
            user_obj = db.get(User, user_id)
            user_profile = ""
            user_birth_year = None
            if user_obj:
                user_birth_year = user_obj.birth_year
                user_profile = format_user_profile_context(
                    birth_year=user_obj.birth_year,
                    birth_place=user_obj.birth_place,
                    grew_up_place=user_obj.grew_up_place,
                    occupation=user_obj.occupation,
                )

            # Two passes per segment:
            #   1) 5-stage state tracking (slots)
            #   2) 8-category chapter classification (chapter creation)
            category_to_segments: Dict[str, List[Segment]] = {}
            for segment in segments:
                text = segment.transcript_text
                detected_stage = _detect_stage(text, state.current_stage)

                # Extract slots (5-stage state tracking).
                extracted_slots = {}
                if llm:
                    try:
                        prompt = get_state_extraction_prompt(
                            user_message=text,
                            current_stage=state.current_stage,
                            stage_slots=state.slots.get(detected_stage, {}),
                        )
                        response = llm.invoke(prompt)
                        content = response.content.strip()
                        parsed = json.loads(content)
                        detected_stage = parsed.get("detected_stage", detected_stage)
                        extracted_slots = parsed.get("slots", {}) or {}
                    except Exception as e:
                        # Covers JSON decoding errors as well as LLM failures.
                        logger.warning(f"LLM 解析失败: {e}")

                for slot_name, snippet in extracted_slots.items():
                    state = _update_slot_sync(
                        user_id=user_id,
                        stage=detected_stage,
                        slot_name=slot_name,
                        snippet=snippet,
                        segment_ids=[segment.id],
                        db=db,
                    )

                # 8-category chapter classification.
                chapter_category = _classify_chapter_category(text, detected_stage, llm)
                if chapter_category is None:
                    logger.info(f"段落无回忆录价值,跳过: segment_id={segment.id}")
                    continue
                category_to_segments.setdefault(chapter_category, []).append(segment)

            # Generate chapter content per category.
            for category, category_segments in category_to_segments.items():
                if not _acquire_chapter_lock(user_id, category):
                    logger.warning(f"章节锁竞争: user={user_id}, stage={category}, 延迟重试")
                    raise self.retry(countdown=10)
                try:
                    combined_text = "\n\n".join(
                        seg.transcript_text for seg in category_segments
                    )
                    source_ids = [seg.id for seg in category_segments]

                    # Look up the active chapter (cleared chapters are not
                    # updated any more; a new one is created instead).
                    stmt_chapter = select(Chapter).where(
                        Chapter.user_id == user_id,
                        Chapter.category == category,
                        Chapter.is_active == True,  # noqa: E712 - SQL expression
                    )
                    chapter = db.execute(stmt_chapter).scalar_one_or_none()

                    # Collect slot snippets.
                    # NOTE(review): slots appear to be keyed by 5-stage names
                    # while ``category`` is an 8-category key, so this lookup
                    # may often be empty - confirm intent.
                    slot_snippets = {
                        key: value.snippet
                        for key, value in (state.slots.get(category, {}) or {}).items()
                        if value.snippet
                    }

                    # Title and content.
                    title = chapter.title if chapter else f"{category} 回忆"
                    existing_content = chapter.content if chapter else ""
                    # Default to appending the raw text so existing content is
                    # never replaced when no LLM is available or it fails.
                    if existing_content:
                        narrative = f"{existing_content}\n\n{combined_text}"
                    else:
                        narrative = combined_text
                    if llm:
                        try:
                            if not chapter:
                                title_prompt = get_creative_title_prompt(
                                    stage=category,
                                    emotion="neutral",
                                    slots=slot_snippets,
                                    user_profile=user_profile,
                                    birth_year=user_birth_year,
                                )
                                title_response = llm.invoke(title_prompt)
                                title = title_response.content.strip().strip('"')
                            narrative_prompt = get_narrative_prompt(
                                stage=category,
                                slots=slot_snippets,
                                new_content=combined_text,
                                existing_content=existing_content,
                                user_profile=user_profile,
                                birth_year=user_birth_year,
                            )
                            narrative_response = llm.invoke(narrative_prompt)
                            new_narrative = narrative_response.content.strip()
                            # Append rather than replace.
                            if existing_content:
                                narrative = f"{existing_content}\n\n{new_narrative}"
                            else:
                                narrative = new_narrative
                        except Exception as e:
                            # ``narrative`` keeps the raw-text append set above.
                            logger.warning(f"LLM 生成失败: {e}")

                    # Safety check: new content must not be shorter than the old.
                    if existing_content and len(narrative) < len(existing_content) * 0.8:
                        logger.warning(
                            f"内容长度异常: existing={len(existing_content)}, "
                            f"new={len(narrative)}, stage={category}. 回退为追加模式"
                        )
                        narrative = f"{existing_content}\n\n{combined_text}"

                    # Update or create the chapter.
                    if chapter:
                        chapter.content = narrative
                        chapter.title = title
                        chapter.is_new = True
                        chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
                    else:
                        # Derive the ordering index from the category.
                        calculated_order_index = STAGE_TO_ORDER.get(category, 999)
                        chapter = Chapter(
                            id=str(uuid.uuid4()),
                            user_id=user_id,
                            title=title,
                            content=narrative,
                            order_index=calculated_order_index,
                            status="completed",
                            category=category,
                            images=[],
                            is_new=True,
                            source_segments=source_ids,
                        )
                        db.add(chapter)
                    db.flush()

                    # Update the Book: take the most recently updated one.
                    # ``first()`` tolerates users with several books, where
                    # ``scalar_one_or_none()`` would raise.
                    stmt_book = (
                        select(Book)
                        .where(Book.user_id == user_id)
                        .order_by(Book.updated_at.desc())
                        .limit(1)
                    )
                    book = db.execute(stmt_book).scalars().first()
                    if not book:
                        book = Book(
                            id=str(uuid.uuid4()),
                            user_id=user_id,
                            title="我的回忆录",
                            total_pages=0,
                            total_words=0,
                            cover_image_url=None,
                        )
                        db.add(book)
                    book.has_update = True
                    book.last_update_chapter_id = chapter.id
                finally:
                    # NOTE(review): the lock is released before the final
                    # commit below, so another worker could take it before
                    # these changes are visible - confirm this is acceptable.
                    _release_chapter_lock(user_id, category)

            # Mark segments as processed.
            for seg in segments:
                seg.processed = True
            db.commit()
            logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
            # Mark the task as successful.
            _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
            return {"status": "success", "processed": len(segments)}
        finally:
            db.close()
    except Retry:
        # self.retry() above has already scheduled the retry. Let it propagate
        # instead of reporting a failure and scheduling the task a second time.
        raise
    except Exception as e:
        logger.error(f"回忆录处理失败: {e}")
        # Mark the task as failed.
        _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
        # Retry.
        raise self.retry(exc=e)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
    """
    Standalone task that appends content to one chapter (for real-time updates).

    The text to append is the LLM narrative when an LLM is available,
    otherwise ``new_content`` itself. It is added to the user's active chapter
    for ``stage``; when no such chapter exists a new one is created. Any
    exception triggers a Celery retry.

    Args:
        user_id: User ID.
        stage: Stage / category key of the chapter.
        new_content: New content to incorporate.

    Returns:
        ``{"status": "success"}`` once the change has been committed.
    """
    logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")
    try:
        db = SessionLocal()
        try:
            llm = llm_service.get_llm()

            # Only the active chapter is updated; cleared chapters are left
            # alone and a new chapter is created instead.
            active_chapter_query = select(Chapter).where(
                Chapter.user_id == user_id,
                Chapter.category == stage,
                Chapter.is_active == True,
            )
            chapter = db.execute(active_chapter_query).scalar_one_or_none()
            existing_content = chapter.content if chapter else ""

            # Decide what gets appended: generated narrative or the raw text.
            if llm:
                reply = llm.invoke(
                    get_narrative_prompt(
                        stage=stage,
                        slots={},
                        new_content=new_content,
                        existing_content=existing_content,
                    )
                )
                addition = reply.content.strip()
            else:
                addition = new_content

            # Append rather than replace.
            if existing_content:
                narrative = f"{existing_content}\n\n{addition}"
            else:
                narrative = addition

            # Safety check: the result must not be shorter than the old content.
            too_short = bool(existing_content) and len(narrative) < len(existing_content) * 0.8
            if too_short:
                logger.warning(
                    f"内容长度异常: existing={len(existing_content)}, "
                    f"new={len(narrative)}, stage={stage}. 回退为追加模式"
                )
                narrative = f"{existing_content}\n\n{new_content}"

            if not chapter:
                # Ordering index comes from the stage; unknown stages sort last.
                chapter = Chapter(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title=f"{stage} 回忆",
                    content=narrative,
                    order_index=STAGE_TO_ORDER.get(stage, 999),
                    status="completed",
                    category=stage,
                    images=[],
                    is_new=True,
                    source_segments=[],
                )
                db.add(chapter)
            else:
                chapter.content = narrative
                chapter.is_new = True

            db.commit()
            return {"status": "success"}
        finally:
            db.close()
    except Exception as e:
        logger.error(f"章节生成失败: {e}")
        raise self.retry(exc=e)