Files
life-echo/api/app/tasks/memoir_tasks.py
2026-03-20 15:15:35 +08:00

741 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录处理 Celery 任务
"""
import json
from app.core.logging import get_logger
import uuid
from io import BytesIO
from typing import Dict, List, Set
from datetime import datetime, timezone
import redis
from celery import shared_task
from PIL import Image
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, joinedload
from app.core.db import get_sync_db
from app.features.conversation.models import Segment
from app.features.memoir.models import (
Book,
Chapter,
MemoirImage,
MemoirState,
)
from app.features.user.models import User
from app.core.dependencies import get_llm_provider
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.agents.memoir import MemoirOrchestrator
from app.agents.chat.prompts_profile import format_user_profile_context
from app.features.memoir.memoir_images.parser import (
build_initial_image_assets,
parse_image_placeholders,
)
import hashlib
from app.core.dependencies import get_image_generator
from app.agents.image_prompt import ImagePromptOrchestrator
from app.features.memoir.memoir_images.schema import (
completed_image_assets,
IMAGE_STATUS_COMPLETED,
IMAGE_STATUS_FAILED,
IMAGE_STATUS_PENDING,
IMAGE_STATUS_PROCESSING,
normalize_image_assets,
)
from app.features.memoir.memoir_images.serializers import (
image_dict_to_row_kwargs,
memoir_image_to_dict,
)
from app.features.memoir.memoir_images.settings import MemoirImageSettings
from app.ports.image_gen import TaskStatus
from app.features.memoir.memoir_images.storage import (
TencentCosStorageService,
CosUploadError,
)
from app.features.memoir.cover_eligibility import (
chapter_needs_cover_enqueue,
cover_memoir_image_pending_or_failed,
)
from app.features.memoir.story_pipeline_sync import (
run_story_pipeline_for_category_batch,
)
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Cache of Redis clients, keyed by the decode_responses flag each client was
# created with (populated lazily by _get_redis_client).
_REDIS_CLIENTS: dict[bool, redis.Redis] = {}
def _get_llm():
    """Return the LangChain LLM exposed by the configured LLM provider.

    Best-effort lookup used inside Celery tasks.

    Returns:
        The provider's ``langchain_llm`` attribute, or ``None`` when the
        provider has no such attribute or resolving it raised.
    """
    try:
        return getattr(get_llm_provider(), "langchain_llm", None)
    except Exception as exc:
        # Callers accept None, so stay best-effort, but leave a trace in the
        # logs instead of hiding a misconfigured provider completely.
        logger.warning("LLM provider unavailable, continuing without LLM: %s", exc)
        return None
def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis:
    """Return a cached Redis client for the requested ``decode_responses`` mode.

    A client is created from ``settings.redis_url`` the first time a given
    mode is requested and stored in ``_REDIS_CLIENTS`` for later calls.
    """
    from app.core.config import settings

    cached = _REDIS_CLIENTS.get(decode_responses)
    if cached is not None:
        return cached

    fresh = redis.from_url(settings.redis_url, decode_responses=decode_responses)
    _REDIS_CLIENTS[decode_responses] = fresh
    return fresh
def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool:
    """Try to take the distributed lock guarding one user's chapter stage.

    Args:
        user_id: Owner of the chapter.
        stage: Stage name; together with ``user_id`` it forms the lock key.
        timeout: Expiry of the lock key in seconds.

    Returns:
        ``True`` when the lock key was created by this call, ``False`` when
        it already existed.
    """
    r = _get_redis_client()
    lock_key = f"lock:chapter:{user_id}:{stage}"
    # SET ... NX answers None (not False) when the key already exists;
    # coerce so the declared bool return type actually holds.
    return bool(r.set(lock_key, "1", nx=True, ex=timeout))
def _release_chapter_lock(user_id: str, stage: str):
    """Delete the Redis key used as the lock for this user's chapter stage."""
    _get_redis_client().delete(f"lock:chapter:{user_id}:{stage}")
def _acquire_chapter_image_lock(chapter_id: str, timeout: int = 600) -> bool:
    """Try to take the distributed lock that serializes image backfill per chapter.

    Args:
        chapter_id: Chapter whose image generation is being guarded.
        timeout: Expiry of the lock key in seconds.

    Returns:
        ``True`` when the lock key was created by this call, ``False`` when
        it already existed.
    """
    r = _get_redis_client()
    lock_key = f"lock:chapter-images:{chapter_id}"
    # SET ... NX answers None (not False) when the key already exists;
    # coerce so the declared bool return type actually holds.
    return bool(r.set(lock_key, "1", nx=True, ex=timeout))
def _release_chapter_image_lock(chapter_id: str):
    """Delete the Redis key used as the image-backfill lock for this chapter."""
    _get_redis_client().delete(f"lock:chapter-images:{chapter_id}")
def _update_task_status_sync(
    user_id: str, task_id: str, status: str, result: Dict = None
):
    """Synchronously record a task's status in Redis (for use in Celery tasks).

    The per-user hash ``task:user:{user_id}:tasks`` maps task ids to JSON
    documents. The existing document (if any) is updated with the new status,
    a UTC timestamp and, when given, the result; the hash expiry is then reset
    to one hour. Any error is logged and swallowed.
    """
    try:
        client = _get_redis_client(decode_responses=True)
        hash_key = f"task:user:{user_id}:tasks"

        # Start from the stored record when there is one.
        stored = client.hget(hash_key, task_id)
        record = json.loads(stored) if stored else {"task_id": task_id}

        record["status"] = status
        record["updated_at"] = datetime.now(timezone.utc).isoformat()
        if result is not None:
            record["result"] = result

        client.hset(hash_key, task_id, json.dumps(record))
        client.expire(hash_key, 3600)  # expire after one hour
        logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
    except Exception as e:
        logger.error(f"更新任务状态失败: {e}")
def _merge_chapter_image_assets(
    existing_images: list[dict] | None,
    placeholders: list[dict],
    provider: str,
    style: str,
    size: str,
    now_iso: str,
) -> list[dict]:
    """Build one image asset per placeholder, reusing existing assets when possible.

    Existing assets are matched by their ``placeholder`` value. A matched asset
    is copied, re-pointed at the placeholder's index/description, and has its
    missing provider/style/size/timestamps filled in. A matched asset marked
    completed but carrying neither ``storage_key`` nor ``url`` is downgraded to
    failed. Unmatched placeholders get a freshly built initial asset.
    """
    # Index the normalized existing assets by placeholder (later entries win).
    known_assets: dict = {}
    for asset in normalize_image_assets(existing_images):
        if asset.get("placeholder"):
            known_assets[asset.get("placeholder")] = dict(asset)

    fallbacks = (
        ("provider", provider),
        ("style", style),
        ("size", size),
        ("created_at", now_iso),
        ("updated_at", now_iso),
    )

    result: list[dict] = []
    for slot in placeholders:
        previous = known_assets.get(slot["placeholder"])
        if not previous:
            fresh = build_initial_image_assets(
                placeholders=[slot],
                provider=provider,
                style=style,
                size=size,
                now_iso=now_iso,
            )
            result.append(fresh[0])
            continue

        entry = dict(previous)
        for field in ("index", "placeholder", "description"):
            entry[field] = slot[field]
        for field, fallback in fallbacks:
            entry[field] = entry.get(field) or fallback

        # A "completed" asset without any location cannot be shown: mark failed.
        has_location = bool(entry.get("storage_key") or entry.get("url"))
        if entry.get("status") == IMAGE_STATUS_COMPLETED and not has_location:
            entry["status"] = IMAGE_STATUS_FAILED
            entry["error"] = entry.get("error") or "missing image url"
        result.append(entry)
    return result
def chapter_has_images_to_generate(images: list[dict] | None) -> bool:
    """Tell whether any normalized image asset is still pending or failed."""
    actionable = {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}
    for asset in normalize_image_assets(images):
        if asset.get("status") in actionable:
            return True
    return False
def _memoir_image_from_asset(
    chapter_id: str,
    order_index: int,
    image_asset: dict,
) -> MemoirImage:
    """Build a MemoirImage row from a single image dict.

    The row gets a fresh 32-character hex id; the remaining columns come from
    ``image_dict_to_row_kwargs(image_asset)``.
    """
    row_fields = image_dict_to_row_kwargs(image_asset)
    new_id = uuid.uuid4().hex  # 32 hex chars, i.e. the UUID without dashes
    return MemoirImage(
        id=new_id,
        chapter_id=chapter_id,
        order_index=order_index,
        **row_fields,
    )
def _select_placeholders_for_effective_max(
    placeholders: list[dict],
    existing_images: list[dict] | None,
    effective_max: int,
) -> list[dict]:
    """Pick the placeholders to keep under an image budget and re-index them.

    Placeholders that already have an existing image are always kept. New
    placeholders are admitted in order until ``effective_max`` minus the
    number of already-known placeholders is used up. The returned copies carry
    consecutive ``index`` values starting at 0.
    """
    known = {
        asset.get("placeholder")
        for asset in normalize_image_assets(existing_images)
        if asset.get("placeholder")
    }

    def _is_known(entry: dict) -> bool:
        return entry.get("placeholder") in known

    already_present = len([entry for entry in placeholders if _is_known(entry)])
    budget = max(0, effective_max - already_present)

    kept: list[dict] = []
    for entry in placeholders:
        if _is_known(entry):
            kept.append(entry)
        elif budget > 0:
            kept.append(entry)
            budget -= 1

    reindexed: list[dict] = []
    for position, entry in enumerate(kept):
        reindexed.append({**entry, "index": position})
    return reindexed
def initialize_chapter_images(_chapter):
    """Legacy entry point kept for old callers.

    Does nothing besides logging that covers are handled by
    generate_chapter_cover, and always returns an empty list.
    """
    skip_notice = "initialize_chapter_images: 封面由 generate_chapter_cover 处理,跳过"
    logger.info(skip_notice)
    return []
def _normalize_image_bytes_for_storage(image_bytes: bytes) -> bytes:
    """Re-encode arbitrary image bytes as PNG bytes.

    Images in mode RGBA or LA are saved as they are, palette ("P") images are
    converted to RGBA first, and every other mode is converted to RGB.
    """
    buffer = BytesIO()
    with Image.open(BytesIO(image_bytes)) as source:
        mode = source.mode
        if mode == "P":
            prepared = source.convert("RGBA")
        elif mode in ("RGBA", "LA"):
            prepared = source
        else:
            prepared = source.convert("RGB")
        prepared.save(buffer, format="PNG")
    return buffer.getvalue()
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
    """Convert a MemoirState database row into a MemoirStateSchema.

    Falls back to the default state's ``stage_order`` when the row has none,
    to an empty list for missing ``covered_stages``, and to the default
    state's ``slots`` when the row's ``slots`` is not a dict.
    """
    stage_order = model.stage_order or default_state().stage_order
    covered = model.covered_stages or []
    if isinstance(model.slots, dict):
        slots = model.slots
    else:
        slots = default_state().slots

    payload = {
        "stage_order": stage_order,
        "current_stage": model.current_stage,
        "covered_stages": covered,
        "slots": slots,
    }
    return MemoirStateSchema.model_validate(payload)
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronously fetch the user's memoir state, creating it when absent.

    A missing row is created from ``default_state()`` (slots serialized with
    ``model_dump``), committed and refreshed before being converted.
    """
    lookup = select(MemoirState).where(MemoirState.user_id == user_id)
    row = db.execute(lookup).scalar_one_or_none()
    if row:
        return _coerce_state(row)

    template = default_state()
    serialized_slots = {}
    for stage_name, stage_slots in template.slots.items():
        serialized_slots[stage_name] = {
            slot_name: slot.model_dump() for slot_name, slot in stage_slots.items()
        }

    row = MemoirState(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        slots=serialized_slots,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return _coerce_state(row)
def _update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
) -> MemoirStateSchema:
    """Synchronously write one slot of the user's memoir state.

    Args:
        user_id: Owner of the memoir state.
        stage: Stage whose slot is updated.
        slot_name: Name of the slot inside the stage.
        snippet: New snippet text stored in the slot.
        segment_ids: Segment ids to merge into the slot's existing ids.
        db: Open synchronous SQLAlchemy session; it is committed here.

    Returns:
        The refreshed state converted to ``MemoirStateSchema``.
    """
    stmt = select(MemoirState).where(MemoirState.user_id == user_id)
    state = db.execute(stmt).scalar_one_or_none()
    if state is None:
        # Create the default row first, then load the ORM object itself.
        _get_or_create_state_sync(user_id, db)
        state = db.execute(stmt).scalar_one()

    # Build fresh dicts instead of mutating the loaded value in place.
    # Re-assigning the very same (mutated) object can look unchanged to the
    # ORM when the column is a plain JSON type without mutation tracking
    # (the column definition is not visible here), silently dropping the update.
    slots: Dict[str, Dict] = dict(state.slots or {})
    stage_slots = dict(slots.get(stage) or {})
    existing = stage_slots.get(slot_name) or {}

    # De-duplicate while keeping first-seen order so the stored list is stable.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(
        snippet=snippet, segment_ids=merged_segment_ids
    ).model_dump()
    slots[stage] = stage_slots

    state.slots = slots
    state.current_stage = state.current_stage or stage
    db.commit()
    db.refresh(state)
    return _coerce_state(state)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
    """Celery task that processes conversation segments into memoir content.

    Steps: ingest the joined transcript into memory (best-effort), run the
    memoir orchestrator over the segments, mark the segments processed, then
    dispatch story-image, chapter-recompose and chapter-cover tasks. Task
    status is mirrored to Redis via ``_update_task_status_sync``.

    Args:
        user_id: User id.
        segment_ids: Ids of the segments to process.

    Returns:
        ``{"status": "no_segments"}`` when none of the ids exist, otherwise
        ``{"status": "success", "processed": <count>}``.

    Raises:
        Whatever ``self.retry`` raises: any failure is reported as "failure"
        and handed to Celery's retry machinery.
    """
    # Local import: needed only to tell Celery's retry signal apart from errors.
    from celery.exceptions import Retry

    task_id = self.request.id
    logger.info(
        f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}"
    )
    # Mark the task as running.
    _update_task_status_sync(user_id, task_id, "running")
    try:
        with get_sync_db() as db:
            # Load the segments.
            stmt = select(Segment).where(Segment.id.in_(segment_ids))
            segments = db.execute(stmt).scalars().all()
            if not segments:
                logger.warning(f"未找到段落: {segment_ids}")
                return {"status": "no_segments"}

            # Memory ingest: transcript -> memory_sources, chunks, FTS
            conv_id = getattr(segments[0], "conversation_id", None) or ""
            transcript = "\n\n".join(seg.transcript_text or "" for seg in segments)
            if transcript.strip():
                try:
                    from app.features.memory.service import ingest_transcript_sync

                    ingest_transcript_sync(db, user_id, conv_id, transcript)
                except Exception as e:
                    # Ingest is optional; keep processing the memoir.
                    logger.warning("Memory ingest 跳过: %s", e)

            # Make sure the user's state row exists, then load profile data.
            _get_or_create_state_sync(user_id, db)
            llm = _get_llm()
            image_settings = MemoirImageSettings.from_env()
            user_obj = db.get(User, user_id)
            user_profile = ""
            user_birth_year = None
            if user_obj:
                user_birth_year = user_obj.birth_year
                user_profile = format_user_profile_context(
                    birth_year=user_obj.birth_year,
                    birth_place=user_obj.birth_place,
                    grew_up_place=user_obj.grew_up_place,
                    occupation=user_obj.occupation,
                )

            # Story ids collected across categories for follow-up tasks.
            story_dispatch_ids: Set[str] = set()

            def _process_category(
                chapter_category: str,
                category_segments: List,
                state: MemoirStateSchema,
                profile: str,
                birth_year,
                llm,
            ):
                """Run the story pipeline for one category batch and flag the book as updated."""
                chapter, _needs_cover, disp = run_story_pipeline_for_category_batch(
                    db,
                    user_id=user_id,
                    chapter_category=chapter_category,
                    category_segments=category_segments,
                    state=state,
                    user_profile=profile,
                    user_birth_year=birth_year,
                    llm=llm,
                )
                story_dispatch_ids.update(disp)
                db.flush()
                db.refresh(chapter)
                needs_cover_enqueue = (
                    image_settings.enabled and chapter_needs_cover_enqueue(chapter)
                )
                # Most recently updated book. limit(1) keeps
                # scalar_one_or_none() from raising when a user owns several.
                stmt_book = (
                    select(Book)
                    .where(Book.user_id == user_id)
                    .order_by(Book.updated_at.desc())
                    .limit(1)
                )
                book = db.execute(stmt_book).scalar_one_or_none()
                if not book:
                    book = Book(
                        id=str(uuid.uuid4()),
                        user_id=user_id,
                        title="我的回忆录",
                        total_pages=0,
                        total_words=0,
                        cover_image_url=None,
                    )
                    db.add(book)
                book.has_update = True
                book.last_update_chapter_id = chapter.id
                return chapter, needs_cover_enqueue

            def _raise_retry():
                raise self.retry(countdown=10)

            memoir_orchestrator = MemoirOrchestrator()
            chapters_to_enqueue, _ = memoir_orchestrator.run(
                segments=segments,
                llm=llm,
                user_profile=user_profile,
                user_birth_year=user_birth_year,
                get_or_create_state=lambda: _get_or_create_state_sync(user_id, db),
                update_slot=lambda stage, slot_name, snippet, seg_ids: (
                    _update_slot_sync(user_id, stage, slot_name, snippet, seg_ids, db)
                ),
                acquire_lock=lambda stage: _acquire_chapter_lock(user_id, stage),
                release_lock=lambda stage: _release_chapter_lock(user_id, stage),
                process_category=_process_category,
                raise_retry=_raise_retry,
            )

            # Mark the segments as processed.
            for seg in segments:
                seg.processed = True
            db.commit()

            from app.tasks.chapter_compose_tasks import recompose_chapters_for_story
            from app.tasks.story_image_tasks import generate_story_image

            # Dispatch failures are logged only; the processed data is committed.
            for sid in story_dispatch_ids:
                try:
                    generate_story_image.delay(sid)
                except Exception as exc:
                    logger.warning("generate_story_image delay: %s", exc)
                try:
                    recompose_chapters_for_story.delay(sid)
                except Exception as exc:
                    logger.warning("recompose_chapters_for_story delay: %s", exc)

            from app.tasks.chapter_cover_enqueue import (
                try_enqueue_generate_chapter_cover,
            )

            for chapter_id in sorted(chapters_to_enqueue):
                if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"):
                    logger.info(f"派发章节封面任务: chapter={chapter_id}")

            logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
            # Mark the task as successful.
            _update_task_status_sync(
                user_id, task_id, "success", {"processed": len(segments)}
            )
            return {"status": "success", "processed": len(segments)}
    except Retry:
        # _raise_retry already scheduled the retry. Falling into the generic
        # handler would mark the task failed and schedule a second retry.
        raise
    except Exception as e:
        logger.error(f"回忆录处理失败: {e}")
        # Mark the task as failed.
        _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
        # Hand over to Celery's retry machinery.
        raise self.retry(exc=e)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
    """Celery task that generates chapter content on its own (real-time updates).

    Wraps ``new_content`` in a minimal segment-like object, runs the story
    pipeline for ``stage``, commits, then dispatches story-image, recompose
    and (when enabled and needed) chapter-cover tasks.

    Args:
        user_id: User id.
        stage: Stage, passed to the pipeline as the chapter category.
        new_content: New text to feed into the pipeline.

    Returns:
        ``{"status": "success"}``; any failure is passed to ``self.retry``.
    """
    logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")

    class _Seg:
        """Segment stand-in exposing only ``id`` and ``transcript_text``."""

        def __init__(self, text: str):
            self.id = str(uuid.uuid4())
            self.transcript_text = text

    try:
        with get_sync_db() as db:
            llm = _get_llm()

            account = db.get(User, user_id)
            if account:
                birth_year = account.birth_year
                profile_text = format_user_profile_context(
                    birth_year=account.birth_year,
                    birth_place=account.birth_place,
                    grew_up_place=account.grew_up_place,
                    occupation=account.occupation,
                )
            else:
                birth_year = None
                profile_text = ""

            current_state = _get_or_create_state_sync(user_id, db)
            chapter, _, dispatch_ids = run_story_pipeline_for_category_batch(
                db,
                user_id=user_id,
                chapter_category=stage,
                category_segments=[_Seg(new_content)],
                state=current_state,
                user_profile=profile_text,
                user_birth_year=birth_year,
                llm=llm,
            )
            db.commit()
            db.refresh(chapter)

            from app.tasks.chapter_compose_tasks import recompose_chapters_for_story
            from app.tasks.story_image_tasks import generate_story_image

            # Each follow-up task is dispatched independently; failures only warn.
            follow_ups = (
                (generate_story_image, "generate_story_image delay: %s"),
                (recompose_chapters_for_story, "recompose_chapters_for_story delay: %s"),
            )
            for story_id in dispatch_ids:
                for task, warn_fmt in follow_ups:
                    try:
                        task.delay(story_id)
                    except Exception as exc:
                        logger.warning(warn_fmt, exc)

            image_settings = MemoirImageSettings.from_env()
            wants_cover = (
                image_settings.enabled
                and chapter
                and chapter_needs_cover_enqueue(chapter)
            )
            if wants_cover:
                from app.tasks.chapter_cover_enqueue import (
                    try_enqueue_generate_chapter_cover,
                )

                try_enqueue_generate_chapter_cover(chapter.id, source="pipeline")
            return {"status": "success"}
    except Exception as e:
        logger.error(f"章节生成失败: {e}")
        raise self.retry(exc=e)
def build_cos_key(user_id: str, chapter_id: str, index: int | str, prompt: str) -> str:
short_hash = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10]
index_part = "cover" if index in (-1, "cover") else str(index)
return f"memoirs/{user_id}/{chapter_id}/{index_part}-{short_hash}.png"
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_images(self, chapter_id: str):
    """Celery task that backfills the chapter-level cover image.

    Only the chapter's pending/failed cover ``MemoirImage`` is handled here;
    per the original note, in-body story images go through story_image_tasks.
    The cover is marked processing, a prompt is built from the chapter title,
    category and the start of ``canonical_markdown``, the image is generated,
    re-encoded as PNG, uploaded, and the row is marked completed.

    Args:
        chapter_id: Id of the chapter whose cover should be generated.

    Returns:
        A dict whose ``status`` is one of ``no_chapter``, ``no_images``,
        ``disabled``, ``locked`` or ``success``. A non-retryable upload error
        deletes the cover row and still returns ``success``.

    Raises:
        Whatever ``self.retry`` raises: retryable failures mark the row failed
        and are handed to Celery's retry machinery.
    """
    lock_acquired = False
    with get_sync_db() as db:
        try:
            stmt = (
                select(Chapter)
                .where(Chapter.id == chapter_id)
                .options(joinedload(Chapter.images))
            )
            chapter = db.execute(stmt).unique().scalar_one_or_none()
            if not chapter:
                logger.info("章节补图跳过: chapter=%s, reason=not_found", chapter_id)
                return {"status": "no_chapter"}

            cover_to_generate = cover_memoir_image_pending_or_failed(chapter)
            if not cover_to_generate:
                logger.info(
                    "章节补图跳过: chapter=%s, reason=no_pending_cover", chapter_id
                )
                return {"status": "no_images"}

            settings = MemoirImageSettings.from_env()
            if not settings.enabled:
                logger.info("章节补图跳过: chapter=%s, reason=disabled", chapter_id)
                return {"status": "disabled"}

            # Only one worker at a time may backfill a given chapter.
            lock_acquired = _acquire_chapter_image_lock(chapter_id)
            if not lock_acquired:
                logger.info("章节补图跳过: chapter=%s, reason=locked", chapter_id)
                return {"status": "locked"}

            prompt_orchestrator = ImagePromptOrchestrator(_get_llm(), settings)
            image_generator = get_image_generator()
            storage = TencentCosStorageService.from_env()
            logger.info(
                "章节封面补图开始: chapter=%s, cover=%s",
                chapter_id,
                bool(cover_to_generate),
            )
            retryable_failures: list[str] = []
            permanent_failures: list[str] = []

            def _apply_item_to_memoir_image(rec: MemoirImage, d: dict):
                """Copy the image dict's fields onto the row and stamp updated_at."""
                rec.placeholder = d.get("placeholder")
                rec.description = d.get("description")
                rec.status = (d.get("status") or "pending").strip() or "pending"
                rec.prompt = d.get("prompt")
                rec.url = d.get("url")
                rec.storage_key = d.get("storage_key")
                rec.provider = d.get("provider")
                rec.style = d.get("style")
                rec.size = d.get("size")
                rec.error = d.get("error")
                rec.retryable = d.get("retryable")
                rec.updated_at = datetime.now(timezone.utc)

            # Cover image (body text comes from canonical_markdown).
            # Persist the "processing" state before the slow generation work.
            current_item = memoir_image_to_dict(cover_to_generate) or {}
            current_item.setdefault("placeholder", "")
            current_item.setdefault("description", "")
            current_item["status"] = IMAGE_STATUS_PROCESSING
            current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
            _apply_item_to_memoir_image(cover_to_generate, current_item)
            db.commit()
            try:
                raw_md = (getattr(chapter, "canonical_markdown", None) or "").strip()
                # First five lines, capped at 200 characters, as prompt context.
                context_excerpt = " ".join(raw_md.split("\n")[:5])[:200]
                prompt_data = prompt_orchestrator.build_cover_prompt(
                    chapter_title=chapter.title,
                    chapter_category=chapter.category or "",
                    context_excerpt=context_excerpt,
                )
                result = image_generator.generate(
                    prompt_data["prompt"],
                    prompt_data["size"],
                    prompt_data["style"],
                )
                if result.status != TaskStatus.COMPLETED or not result.image_url:
                    raise RuntimeError(result.error or "Image generation failed")
                image_bytes = _normalize_image_bytes_for_storage(
                    image_generator.download_image(result.image_url)
                )
                key = build_cos_key(
                    chapter.user_id, chapter.id, "cover", prompt_data["prompt"]
                )
                current_item["storage_key"] = key
                current_item["url"] = storage.upload_bytes(
                    image_bytes, key, "image/png"
                )
                current_item["prompt"] = prompt_data["prompt"]
                current_item["style"] = prompt_data["style"]
                current_item["size"] = prompt_data["size"]
                current_item["status"] = IMAGE_STATUS_COMPLETED
                current_item["error"] = None
                current_item["retryable"] = None
                current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
                _apply_item_to_memoir_image(cover_to_generate, current_item)
                db.commit()
                logger.info(
                    "章节封面图生成成功: chapter=%s, url=%s",
                    chapter_id,
                    current_item["url"],
                )
            except Exception as exc:
                failure_msg = f"cover, error={exc}"
                if isinstance(exc, CosUploadError) and not exc.retryable:
                    # Non-retryable upload error: drop the row instead of retrying.
                    permanent_failures.append(failure_msg)
                    logger.error(
                        "封面图上传不可重试,清理: chapter=%s, %s",
                        chapter_id,
                        failure_msg,
                    )
                    db.delete(cover_to_generate)
                    db.commit()
                else:
                    current_item = memoir_image_to_dict(cover_to_generate) or {}
                    current_item["status"] = IMAGE_STATUS_FAILED
                    current_item["error"] = str(exc)
                    current_item["retryable"] = True
                    current_item["updated_at"] = datetime.now(
                        timezone.utc
                    ).isoformat()
                    retryable_failures.append(failure_msg)
                    logger.warning(
                        "封面图生成失败(可重试): chapter=%s, %s",
                        chapter_id,
                        failure_msg,
                    )
                    _apply_item_to_memoir_image(cover_to_generate, current_item)
                    db.commit()

            if retryable_failures:
                # Raised so the outer handler schedules a Celery retry.
                raise RuntimeError(
                    f"章节补图存在可重试失败项: chapter={chapter_id}, failures={'; '.join(retryable_failures)}"
                )
            return {"status": "success"}
        except Exception as exc:
            logger.error("章节补图任务失败: chapter=%s, error=%s", chapter_id, exc)
            raise self.retry(exc=exc)
        finally:
            # Release the lock only if this invocation took it.
            if lock_acquired:
                _release_chapter_image_lock(chapter_id)