From 6b930808a30ec9959a73d7d07ec274520e944f3f Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 2 Apr 2026 16:37:14 +0800 Subject: [PATCH] =?UTF-8?q?feat(memoir):=20=E5=9B=9E=E5=BF=86=E5=BD=95?= =?UTF-8?q?=E5=88=86=E6=AE=B5=E4=B8=A4=E9=98=B6=E6=AE=B5=E7=AE=A1=E7=BA=BF?= =?UTF-8?q?=EF=BC=88Phase1=20=E5=88=86=E7=B1=BB=20/=20Phase2=20=E5=8F=99?= =?UTF-8?q?=E4=BA=8B=EF=BC=89=E4=B8=8E=E9=85=8D=E7=BD=AE=E3=80=81=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/.env.example | 4 + api/.env.production | 56 +- .../0006_segment_memoir_phase_flags.py | 42 ++ api/app/agents/chat/prompts_conversation.py | 3 + api/app/agents/chat/stage_detection.py | 130 +--- api/app/agents/chat/stage_prompts.py | 4 +- api/app/agents/memoir/classification_agent.py | 22 +- api/app/agents/memoir/extraction_agent.py | 7 +- api/app/agents/memoir/narrative_agent.py | 7 +- api/app/agents/memoir/orchestrator.py | 5 + api/app/agents/stage_constants.py | 198 ++++++ api/app/core/config.py | 13 + api/app/features/conversation/models.py | 5 + api/app/features/conversation/repo.py | 33 +- api/app/features/conversation/service.py | 29 +- api/app/features/conversation/ws/pipeline.py | 24 +- api/app/features/conversation/ws/router.py | 16 +- api/app/features/memoir/background_runner.py | 50 +- api/app/features/memoir/state_service.py | 160 ++++- .../features/memoir/story_pipeline_sync.py | 80 ++- api/app/tasks/__init__.py | 8 +- api/app/tasks/memoir_tasks.py | 672 +++++++++++------- api/tests/test_background_runner.py | 44 +- api/tests/test_memoir_skip_story.py | 7 +- api/tests/test_memoir_two_phase.py | 85 +++ api/tests/test_stage_validation.py | 83 +++ .../test_state_service_batch_stage_policy.py | 193 +++++ 27 files changed, 1550 insertions(+), 430 deletions(-) create mode 100644 api/alembic/versions/0006_segment_memoir_phase_flags.py create mode 100644 api/tests/test_memoir_two_phase.py create mode 100644 api/tests/test_stage_validation.py create mode 100644 api/tests/test_state_service_batch_stage_policy.py diff --git a/api/.env.example b/api/.env.example index a529a21..fee9738 100644 --- a/api/.env.example +++ b/api/.env.example @@ -58,6 +58,10 @@ EMBEDDING_MODEL=embedding-3 # CHAT_MEMORY_TOP_K=8 # CHAT_MEMORY_EVIDENCE_MAX_CHARS=4096 +# Memoir:批处理/抽取更新 slot 时是否允许改写 MemoirState.current_stage(默认 false,访谈 switch_stage 仍可推进) +# True 时仅当 proposed 与 existing 在同一 chat_bucket 才对齐 current_stage +# MEMOIR_EXTRACTION_UPDATES_CURRENT_STAGE=false + # Memoir:叙事前口述归一(segment 原文仍落库;仅 story 流水线派生输入) # MEMOIR_ORAL_NORMALIZE_ENABLED=true # off | rules | llm(llm 为先规则再 LLM 纠错,失败回退规则结果) diff --git a/api/.env.production b/api/.env.production index ae19db9..26b31fe 100644 --- a/api/.env.production +++ b/api/.env.production @@ -5,10 +5,24 @@ # 若仓库可被非授权人员访问,请不要在此文件中保留真实密钥。 # ============================================================================= +# ============================================================================= +# Docker Compose(宿主机独立 Caddy 反代到本 API) +# ============================================================================= +# 映射到宿主机的端口,默认 8000;与同机其它项目冲突时改为未占用端口,并在独立 Caddy 的 Caddyfile 中 reverse_proxy 到 127.0.0.1:该端口。 +# LIFE_ECHO_API_HOST_PORT=8000 +# 若 Caddy 跑在独立容器且非 host 网络,不要用 127.0.0.1,应把 Caddy 加入与本 compose 相同的 Docker 网络,并对 http://life-echo-api-prod:8000 做 reverse_proxy。 + # ============================================================================= # Logging(loguru sink 最低级别:TRACE / DEBUG / INFO / WARNING / ERROR / CRITICAL) # ============================================================================= LOG_LEVEL=INFO +# Agent 单行 INFO 摘要(耗时、路由、段落规模);与 LOG_LEVEL 独立,便于生产短暂排查 +# LOG_AGENT_VERBOSE=0 +# DEBUG 下 prompt/响应预览最大字符数 +# AGENT_LOG_MAX_CHARS=4096 +# 第三方 stdlib logging(空=自动:LOG_LEVEL 为 DEBUG/TRACE 时 Celery→INFO、httpx/httpcore→WARNING,减少刷屏) +# CELERY_LOG_LEVEL= +# HTTPX_LOG_LEVEL= # ============================================================================= # LLM / DeepSeek @@ -42,6 +56,10 @@ EMBEDDING_MODEL=embedding-3 # CHAT_MEMORY_TOP_K=8 # CHAT_MEMORY_EVIDENCE_MAX_CHARS=4096 +# Memoir:批处理/抽取更新 slot 时是否允许改写 MemoirState.current_stage(默认 false,访谈 switch_stage 仍可推进) +# True 时仅当 proposed 与 existing 在同一 chat_bucket 才对齐 current_stage +# MEMOIR_EXTRACTION_UPDATES_CURRENT_STAGE=false + # Memoir:叙事前口述归一(segment 原文仍落库;仅 story 流水线派生输入) MEMOIR_ORAL_NORMALIZE_ENABLED=true # off | rules | llm(llm 为先规则再 LLM 纠错,失败回退规则结果) @@ -49,6 +67,13 @@ MEMOIR_ORAL_NORMALIZE_MODE=llm MEMOIR_ORAL_NORMALIZE_LLM_MAX_TOKENS=512 MEMOIR_ORAL_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 +# Chat:模型消费净稿(segment 原文仍落库;访谈编排层归一后注入 Agent / 记忆检索) +# CHAT_INPUT_NORMALIZE_ENABLED=true +# off | rules | llm(llm 为先规则再 LLM;失败回退规则;编排层已带 LLM 时不重复在 Agent 调) +# CHAT_INPUT_NORMALIZE_MODE=rules +# CHAT_INPUT_NORMALIZE_LLM_MAX_TOKENS=512 +# CHAT_INPUT_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 + # ============================================================================= # Database # ============================================================================= @@ -57,6 +82,11 @@ MEMOIR_ORAL_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 # Docker / 服务端(主机名一般为 compose 服务名 postgres): # DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo +# 启动时 Alembic(main.py);生产可设 ALEMBIC_STARTUP_FAIL_FAST=true,迁移失败则拒绝启动 +# ALEMBIC_RUN_ON_STARTUP=true +# ALEMBIC_STARTUP_FAIL_FAST=false +# ALEMBIC_STARTUP_MAX_RETRIES=3 +# ALEMBIC_STARTUP_RETRY_BASE_SECONDS=1.0 # ============================================================================= # Redis @@ -82,6 +112,18 @@ REDIS_SESSION_TTL=86400 # MEMORY_COMPACTION_TEXT_JACCARD_MIN=0.55 # MEMORY_COMPACTION_METADATA_EVENT_YEAR_WINDOW=1 +# ============================================================================= +# Story 流水线(post-commit、章节物化、append 上限、evidence 检索) +# ============================================================================= +# STORY_IMAGE_ENQUEUE_DEDUP_TTL=300 +# RECOMPOSE_CHAPTER_DELAY_SECONDS=8 +# CHAPTER_PIPELINE_LOCK_TTL_SECONDS=120 +# STORY_APPEND_MAX_CANONICAL_CHARS=12000 +# STORY_APPEND_MAX_VERSIONS=20 +# EVIDENCE_TOP_K_DEFAULT=10 +# EVIDENCE_TOP_K_LARGE_BATCH=5 +# EVIDENCE_LARGE_BATCH_THRESHOLD=3 + # ============================================================================= # Auth # ============================================================================= @@ -135,7 +177,8 @@ TENCENT_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 # ENABLE_TTS:仅控制是否合成并下发 TTS_AUDIO;不影响用户语音转写(ASR) ENABLE_TTS=true TTS_PROVIDER=tencent -# 仅 TTS_PROVIDER=openai 时在控制台创建密钥后配置 OPENAI_API_KEY +# 仅 TTS_PROVIDER=openai 时需要 +# OPENAI_API_KEY=your_openai_api_key # 音色 ID 见 https://cloud.tencent.com/document/product/1073/92668 TTS_VOICE_TYPE=502001 TTS_CODEC=mp3 @@ -157,8 +200,11 @@ WECHAT_PAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/wechat # WECHAT_PAY_PLATFORM_PUBLIC_KEY_ID=PUB_KEY_ID_0116629790992026020700181671002400 # ============================================================================= -# Alipay(启用时在此填写 ALIPAY_APP_ID / ALIPAY_PRIVATE_KEY / ALIPAY_PUBLIC_KEY) +# Alipay # ============================================================================= +# ALIPAY_APP_ID=your_alipay_app_id +# ALIPAY_PRIVATE_KEY=your_alipay_private_key +# ALIPAY_PUBLIC_KEY=your_alipay_public_key ALIPAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/alipay # ============================================================================= @@ -180,8 +226,12 @@ STORY_IMAGE_MIN_BODY_CHARS=800 # 叙事模型输出相对口述过短则回退为口述原文 MEMOIR_NARRATIVE_FALLBACK_BODY_RATIO=0.5 MEMOIR_NARRATIVE_FALLBACK_MIN_CHARS=20 +# 回忆录 segment 入队:累计 strip 后字数未达此值则暂缓提交 Celery(0=关闭字数门闸,仅静默防抖后提交) +# MEMOIR_SEGMENT_BATCH_MIN_CHARS=50 +# 本批首条入队起最长等待(秒),超时仍提交;测试可调低,生产可调高 +# MEMOIR_SEGMENT_BATCH_MAX_WAIT_SECONDS=60 # 可选,Liblib 返回图片域名不在默认白名单时(逗号分隔) -# MEMOIR_IMAGE_DOWNLOAD_HOSTS= +# MEMOIR_IMAGE_DOWNLOAD_HOSTS=liblib.cloud,liblibai.cloud # ============================================================================= # Liblib image provider diff --git a/api/alembic/versions/0006_segment_memoir_phase_flags.py b/api/alembic/versions/0006_segment_memoir_phase_flags.py new file mode 100644 index 0000000..66cc362 --- /dev/null +++ b/api/alembic/versions/0006_segment_memoir_phase_flags.py @@ -0,0 +1,42 @@ +"""segments:Phase1/Phase2 标志(叙事延迟管线) + +Revision ID: 0006_segment_memoir_phases +Revises: 0005_cleanup_story_links +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "0006_segment_memoir_phases" +down_revision: Union[str, None] = "0005_cleanup_story_links" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "segments", + sa.Column( + "narrated", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + op.add_column( + "segments", + sa.Column( + "skip_narrative", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + + +def downgrade() -> None: + op.drop_column("segments", "skip_narrative") + op.drop_column("segments", "narrated") diff --git a/api/app/agents/chat/prompts_conversation.py b/api/app/agents/chat/prompts_conversation.py index c70ae39..99d6c1b 100644 --- a/api/app/agents/chat/prompts_conversation.py +++ b/api/app/agents/chat/prompts_conversation.py @@ -208,6 +208,9 @@ def _build_era_context(current_stage: str, user_profile_context: str) -> str: "career": (18, 50), "family": (20, 50), "belief": (30, 60), + # chapter / 防御性 key:与 belief 同档年龄参照 + "beliefs": (30, 60), + "summary": (30, 60), } age_range = stage_era_map.get(current_stage, (0, 30)) diff --git a/api/app/agents/chat/stage_detection.py b/api/app/agents/chat/stage_detection.py index 0f62721..4cfc8f3 100644 --- a/api/app/agents/chat/stage_detection.py +++ b/api/app/agents/chat/stage_detection.py @@ -12,6 +12,11 @@ from app.agents.chat.stage_prompts import ( get_chat_stage_detection_prompt, life_stage_display_zh, ) +from app.agents.stage_constants import ( + CHAT_STAGES, + STAGE_KEYWORD_WEIGHTS, + normalize_chat_stage, +) from app.core.config import settings from app.core.langchain_llm import ainvoke_json_object from app.core.logging import get_logger @@ -19,117 +24,26 @@ from app.core.json_utils import extract_json_payload logger = get_logger(__name__) -# 关键词按阶段打分;同一词不重复出现在多阶段,避免「父母」独占童年。 -_STAGE_KEYWORD_WEIGHTS: dict[str, list[tuple[str, int]]] = { - "childhood": [ - ("童年", 3), - ("小时候", 3), - ("幼年", 2), - ("出生", 2), - ("家乡", 2), - ("老家", 2), - ("小镇", 1), - ("幼儿园", 2), - ("玩伴", 1), - ], - "education": [ - ("上学", 2), - ("学校", 2), - ("老师", 2), - ("同学", 2), - ("教育", 1), - ("大学", 3), - ("高中", 2), - ("初中", 2), - ("小学", 2), - ("考试", 1), - ("毕业", 2), - ("读书", 1), - ("高考", 2), - ("课堂", 1), - ("宿舍", 1), - ], - "career": [ - ("工作", 3), - ("职业", 2), - ("事业", 2), - ("公司", 2), - ("同事", 2), - ("创业", 2), - ("升职", 1), - ("跳槽", 1), - ("老板", 1), - ("行业", 1), - ("项目", 1), - ("加班", 1), - ("薪水", 1), - ("面试", 1), - ("职场", 2), - ("离职", 1), - ], - "family": [ - ("伴侣", 2), - ("孩子", 2), - ("家庭", 2), - ("家人", 2), - ("结婚", 2), - ("爱人", 1), - ("老婆", 1), - ("老公", 1), - ("丈夫", 1), - ("妻子", 1), - ("儿子", 1), - ("女儿", 1), - ("婚礼", 1), - ("恋爱", 1), - ("父母", 2), - ("爸妈", 2), - ("父亲", 2), - ("母亲", 2), - ("爷爷", 1), - ("奶奶", 1), - ("外公", 1), - ("外婆", 1), - ], - "belief": [ - ("信念", 2), - ("价值观", 2), - ("座右铭", 2), - ("坚持", 1), - ("原则", 1), - ("信仰", 1), - ("意义", 1), - ("感悟", 1), - ("遗憾", 1), - ("骄傲", 1), - ], -} - def normalize_life_stage(raw: Optional[str], fallback: str) -> str: - if not raw or not isinstance(raw, str): - return fallback - s = raw.strip().lower() - if s in VALID_CHAT_LIFE_STAGES: - return s - return fallback + """兼容旧名:统一走 normalize_chat_stage。""" + return normalize_chat_stage(raw, fallback) def keyword_fallback_primary_stage(user_message: str) -> str: - """多阶段打分,取最高分;平局时按 stage_order 靠后的优先(更具体场景常后验)。""" + """多阶段打分,取最高分;平局按 CHAT_STAGES 逆序优先(与历史 tie_order 派生一致,可能有小幅行为差异)。""" if not (user_message or "").strip(): return "" text = user_message - scores: dict[str, int] = {k: 0 for k in _STAGE_KEYWORD_WEIGHTS} - for stage, pairs in _STAGE_KEYWORD_WEIGHTS.items(): + scores: dict[str, int] = {k: 0 for k in STAGE_KEYWORD_WEIGHTS} + for stage, pairs in STAGE_KEYWORD_WEIGHTS.items(): for word, w in pairs: if word in text: scores[stage] += w best = max(scores.values()) if best <= 0: return "" - # 平局:education > career > family > belief > childhood(避免童年默认胜出) - tie_order = ["childhood", "belief", "family", "career", "education"] + tie_order = list(reversed(CHAT_STAGES)) candidates = [s for s, v in scores.items() if v == best] for s in reversed(tie_order): if s in candidates: @@ -145,14 +59,14 @@ async def detect_primary_life_stage( """ 返回合法的人生阶段 key;失败时回退为 current_stage。 """ - fb = normalize_life_stage(current_stage, "childhood") + fb = normalize_chat_stage(current_stage, "childhood") if not settings.chat_stage_detection_enabled: k = keyword_fallback_primary_stage(user_message) - return normalize_life_stage(k, fb) if k else fb + return normalize_chat_stage(k, fb) if k else fb if not llm: k = keyword_fallback_primary_stage(user_message) - return normalize_life_stage(k, fb) if k else fb + return normalize_chat_stage(k, fb) if k else fb try: prompt = get_chat_stage_detection_prompt(user_message, fb) @@ -164,16 +78,26 @@ async def detect_primary_life_stage( ) if not raw.strip(): k = keyword_fallback_primary_stage(user_message) - return normalize_life_stage(k, fb) if k else fb + return normalize_chat_stage(k, fb) if k else fb parsed = json.loads(extract_json_payload(raw)) detected = parsed.get("detected_stage", fb) - return normalize_life_stage(str(detected) if detected is not None else "", fb) + return normalize_chat_stage(str(detected) if detected is not None else "", fb) except (json.JSONDecodeError, Exception) as e: logger.warning("detect_primary_life_stage 解析失败,使用关键词回退: {}", e) k = keyword_fallback_primary_stage(user_message) - return normalize_life_stage(k, fb) if k else fb + return normalize_chat_stage(k, fb) if k else fb def life_stage_display_name(stage: str) -> str: """供提示词展示的中文名。""" return life_stage_display_zh(stage) + + +# re-export for modules that still import VALID_CHAT_LIFE_STAGES from stage_detection +__all__ = [ + "VALID_CHAT_LIFE_STAGES", + "detect_primary_life_stage", + "keyword_fallback_primary_stage", + "life_stage_display_name", + "normalize_life_stage", +] diff --git a/api/app/agents/chat/stage_prompts.py b/api/app/agents/chat/stage_prompts.py index d91e187..145af12 100644 --- a/api/app/agents/chat/stage_prompts.py +++ b/api/app/agents/chat/stage_prompts.py @@ -2,9 +2,9 @@ 访谈「人生阶段」判定专用短提示词(与回忆录五阶段 slots 一致)。 """ -from app.agents.stage_constants import CHAT_STAGES, STAGE_DISPLAY_ZH +from app.agents.stage_constants import CHAT_STAGES, STAGE_DISPLAY_ZH, VALID_CHAT_STAGES -VALID_CHAT_LIFE_STAGES = frozenset(CHAT_STAGES) +VALID_CHAT_LIFE_STAGES = VALID_CHAT_STAGES def life_stage_display_zh(stage: str) -> str: diff --git a/api/app/agents/memoir/classification_agent.py b/api/app/agents/memoir/classification_agent.py index c745e28..c281cc2 100644 --- a/api/app/agents/memoir/classification_agent.py +++ b/api/app/agents/memoir/classification_agent.py @@ -14,8 +14,11 @@ from dataclasses import dataclass from typing import Any from app.agents.memoir.prompts import get_chapter_classification_json_prompt -from app.agents.stage_constants import CHAPTER_CATEGORIES -from app.agents.stage_constants import STAGE_TO_DEFAULT_CATEGORY +from app.agents.stage_constants import ( + CHAPTER_CATEGORIES, + STAGE_KEYWORD_WEIGHTS, + STAGE_TO_DEFAULT_CATEGORY, +) from app.core.json_utils import extract_json_payload from app.core.langchain_llm import invoke_json_object from app.core.logging import get_logger @@ -40,21 +43,12 @@ _SHORT_HUKOU_STYLE = re.compile( re.UNICODE, ) -# 5-stage 关键词(用于 LLM 失败时的兜底);注意勿含易与「仅年份句」共现的泛词,以免误推类别 -STAGE_KEYWORDS = { - "childhood": ["童年", "小时候", "家乡", "小镇"], - "education": ["上学", "学校", "老师", "同学", "教育", "大学"], - "career": ["工作", "职业", "事业", "公司", "同事", "创业"], - "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], - "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], -} - def _detect_stage(text: str, fallback_stage: str) -> str: - """根据关键词检测消息所属的 5-stage 阶段""" + """根据关键词检测消息所属的 5-stage 阶段(与 stage_constants.STAGE_KEYWORD_WEIGHTS 同源;匹配方式为子串,非加权)。""" message = (text or "").lower() - for stage, keywords in STAGE_KEYWORDS.items(): - if any(word in message for word in keywords): + for stage, pairs in STAGE_KEYWORD_WEIGHTS.items(): + if any(word in message for word, _w in pairs): return stage return fallback_stage diff --git a/api/app/agents/memoir/extraction_agent.py b/api/app/agents/memoir/extraction_agent.py index 43134cd..090d3c1 100644 --- a/api/app/agents/memoir/extraction_agent.py +++ b/api/app/agents/memoir/extraction_agent.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Any, Dict from app.agents.memoir.prompts import get_state_extraction_prompt +from app.agents.stage_constants import normalize_chat_stage from app.core.langchain_llm import invoke_json_object from app.core.logging import get_logger from app.core.json_utils import extract_json_payload @@ -63,7 +64,11 @@ class ExtractionAgent: agent="ExtractionAgent.extract", ) parsed = json.loads(extract_json_payload(raw)) - detected_stage = parsed.get("detected_stage", detected_stage) + raw_detected = parsed.get("detected_stage", detected_stage) + detected_stage = normalize_chat_stage( + str(raw_detected) if raw_detected is not None else None, + fallback=current_stage, + ) raw_slots = parsed.get("slots", {}) or {} extracted_slots = { k: v if isinstance(v, str) else str(v) for k, v in raw_slots.items() diff --git a/api/app/agents/memoir/narrative_agent.py b/api/app/agents/memoir/narrative_agent.py index c6ae769..f3dd35d 100644 --- a/api/app/agents/memoir/narrative_agent.py +++ b/api/app/agents/memoir/narrative_agent.py @@ -8,6 +8,7 @@ from __future__ import annotations import json from typing import Any, Dict, Optional +from app.agents.stage_constants import CHAPTER_CATEGORIES from app.agents.memoir.prompts import ( get_creative_title_json_prompt, get_narrative_json_prompt, @@ -34,7 +35,7 @@ class NarrativeAgent: ) -> str: """生成创意标题。若无 LLM 则返回默认标题""" if not llm: - return f"{stage} 回忆" + return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" try: prompt = get_creative_title_json_prompt( stage=stage, @@ -53,10 +54,10 @@ class NarrativeAgent: title = (data.get("title") or "").strip() if isinstance(data, dict) else "" if title: return title.strip('"') - return f"{stage} 回忆" + return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" except Exception as e: logger.warning("NarrativeAgent 生成标题失败: {}", e) - return f"{stage} 回忆" + return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" def generate_narrative( self, diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 0b90417..06d787c 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -33,6 +33,8 @@ class PreparedMemoirBatches: category_to_segments: Dict[str, List[Segment]] #: segment id 在「LLM 判 none 且 extraction slots 为空」时加入;batch 级短路见 memoir_tasks segment_skip_story_ids: Set[str] + #: 每个 segment → Phase 1 分类 chapter_category(持久化到 Segment.topic_category) + segment_chapter_category: Dict[str, str] class MemoirOrchestrator: @@ -64,6 +66,7 @@ class MemoirOrchestrator: state = get_or_create_state() category_to_segments: Dict[str, List[Segment]] = {} segment_skip_story_ids: Set[str] = set() + segment_chapter_category: Dict[str, str] = {} classify_extract_llm = llm_fast if llm_fast is not None else llm for segment in segments: @@ -103,6 +106,7 @@ class MemoirOrchestrator: chapter_category = classify_result.category if (not result.slots) and classify_result.llm_said_none: segment_skip_story_ids.add(str(segment.id)) + segment_chapter_category[str(segment.id)] = chapter_category if agent_summary_enabled(): logger.info( @@ -126,6 +130,7 @@ class MemoirOrchestrator: state=state, category_to_segments=category_to_segments, segment_skip_story_ids=segment_skip_story_ids, + segment_chapter_category=segment_chapter_category, ) def run( diff --git a/api/app/agents/stage_constants.py b/api/app/agents/stage_constants.py index a3bebd6..4e86cba 100644 --- a/api/app/agents/stage_constants.py +++ b/api/app/agents/stage_constants.py @@ -2,6 +2,12 @@ from __future__ import annotations +from typing import Any + +from app.core.logging import get_logger + +logger = get_logger(__name__) + # 访谈五阶段(与 MemoirStateSchema.default_slots 顺序一致) CHAT_STAGES: tuple[str, ...] = ( "childhood", @@ -11,6 +17,8 @@ CHAT_STAGES: tuple[str, ...] = ( "belief", ) +VALID_CHAT_STAGES: frozenset[str] = frozenset(CHAT_STAGES) + STAGE_DISPLAY_ZH = { "childhood": "童年时光", "education": "求学经历", @@ -38,6 +46,8 @@ CHAPTER_CATEGORIES = { "summary": "人生总结", } +VALID_CHAPTER_CATEGORIES: frozenset[str] = frozenset(CHAPTER_CATEGORIES.keys()) + CHAPTER_ORDER = [ "childhood", "education", @@ -49,6 +59,7 @@ CHAPTER_ORDER = [ "summary", ] +# career / career_early 与 belief / beliefs 共用序号:chat-stage 别名与 chapter category 兼容映射 STAGE_TO_ORDER = { "childhood": 0, "education": 1, @@ -72,3 +83,190 @@ CATEGORY_TO_CHAT_STAGE: dict[str, str] = { "beliefs": "belief", "summary": "belief", } + +# 访谈关键词加权(chat 路径打分);classification 与子串检测共用此数据源 +STAGE_KEYWORD_WEIGHTS: dict[str, list[tuple[str, int]]] = { + "childhood": [ + ("童年", 3), + ("小时候", 3), + ("幼年", 2), + ("出生", 2), + ("家乡", 2), + ("老家", 2), + ("小镇", 1), + ("幼儿园", 2), + ("玩伴", 1), + ], + "education": [ + ("上学", 2), + ("学校", 2), + ("老师", 2), + ("同学", 2), + ("教育", 1), + ("大学", 3), + ("高中", 2), + ("初中", 2), + ("小学", 2), + ("考试", 1), + ("毕业", 2), + ("读书", 1), + ("高考", 2), + ("课堂", 1), + ("宿舍", 1), + ], + "career": [ + ("工作", 3), + ("职业", 2), + ("事业", 2), + ("公司", 2), + ("同事", 2), + ("创业", 2), + ("升职", 1), + ("跳槽", 1), + ("老板", 1), + ("行业", 1), + ("项目", 1), + ("加班", 1), + ("薪水", 1), + ("面试", 1), + ("职场", 2), + ("离职", 1), + ], + "family": [ + ("伴侣", 2), + ("孩子", 2), + ("家庭", 2), + ("家人", 2), + ("结婚", 2), + ("爱人", 1), + ("老婆", 1), + ("老公", 1), + ("丈夫", 1), + ("妻子", 1), + ("儿子", 1), + ("女儿", 1), + ("婚礼", 1), + ("恋爱", 1), + ("父母", 2), + ("爸妈", 2), + ("父亲", 2), + ("母亲", 2), + ("爷爷", 1), + ("奶奶", 1), + ("外公", 1), + ("外婆", 1), + ], + "belief": [ + ("信念", 2), + ("价值观", 2), + ("座右铭", 2), + ("坚持", 1), + ("原则", 1), + ("信仰", 1), + ("意义", 1), + ("感悟", 1), + ("遗憾", 1), + ("骄傲", 1), + ], +} + +# 模型/任务偶发使用 chapter key 当 chat stage 时收束到 bucket +_CHAT_STAGE_SYNONYMS: dict[str, str] = { + "beliefs": "belief", +} + + +def chat_bucket(stage_or_category: str | None) -> str: + """将 chat stage 或 chapter category 收束到 CHAT_STAGES 之一。""" + if not stage_or_category or not isinstance(stage_or_category, str): + return "" + s = stage_or_category.strip().lower() + if s in VALID_CHAT_STAGES: + return s + if s in _CHAT_STAGE_SYNONYMS: + return _CHAT_STAGE_SYNONYMS[s] + if s in CATEGORY_TO_CHAT_STAGE: + return CATEGORY_TO_CHAT_STAGE[s] + return "" + + +def normalize_chat_stage( + raw: str | None, + fallback: str, + *, + log_context: dict[str, Any] | None = None, +) -> str: + """校验并归一化 chat stage;非法非空输入回落到 fallback 并可选结构化日志。""" + if not raw or not isinstance(raw, str): + return fallback + stripped = raw.strip() + if not stripped: + return fallback + s = stripped.lower() + if s in VALID_CHAT_STAGES: + return s + if s in _CHAT_STAGE_SYNONYMS: + return _CHAT_STAGE_SYNONYMS[s] + if s in CATEGORY_TO_CHAT_STAGE: + mapped = CATEGORY_TO_CHAT_STAGE[s] + if log_context: + logger.bind(**log_context).info( + "event=normalize_chat_stage_mapped raw={} mapped={}", + raw, + mapped, + ) + return mapped + fb = ( + fallback.strip().lower() + if isinstance(fallback, str) and fallback.strip() + else "childhood" + ) + if fb in VALID_CHAT_STAGES: + pass + elif fb in _CHAT_STAGE_SYNONYMS: + fb = _CHAT_STAGE_SYNONYMS[fb] + elif fb in CATEGORY_TO_CHAT_STAGE: + fb = CATEGORY_TO_CHAT_STAGE[fb] + else: + fb = "childhood" + if log_context: + logger.bind(**log_context).info( + "event=normalize_chat_stage_fallback raw={} fallback={}", + raw, + fb, + ) + return fb + + +def normalize_chapter_category( + raw: str | None, + fallback: str, + *, + log_context: dict[str, Any] | None = None, +) -> str: + """校验 chapter category key(CHAPTER_CATEGORIES 的键)。""" + if not raw or not isinstance(raw, str): + out = fallback if fallback in VALID_CHAPTER_CATEGORIES else "summary" + return out + s = raw.strip().lower() + if s.startswith("`"): + s = s.strip("`").strip() + if (s.startswith('"') and s.endswith('"')) or ( + s.startswith("'") and s.endswith("'") + ): + s = s[1:-1].strip().lower() + if s in VALID_CHAPTER_CATEGORIES: + return s + fb = ( + fallback + if isinstance(fallback, str) + and fallback.strip().lower() in VALID_CHAPTER_CATEGORIES + else "summary" + ) + if log_context: + logger.bind(**log_context).info( + "event=normalize_chapter_category_fallback raw={} fallback={}", + raw, + fb, + ) + return fb diff --git a/api/app/core/config.py b/api/app/core/config.py index 34882c9..548874a 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -213,6 +213,19 @@ class Settings(BaseSettings): memoir_segment_batch_max_wait_seconds: float = Field( default=60.0, ge=0.0, le=3600.0 ) + # 回忆录叙事 Phase 2( Celery)触发:单条口述达到该 strip 字数则立即跑叙事 + memoir_narrative_immediate_char_threshold: int = Field(default=50, ge=0, le=50_000) + # 同一 topic_category 下未叙事段数达到该值则触发 Phase 2 + memoir_narrative_batch_min_segments: int = Field(default=3, ge=1, le=500) + # 同上,累计 user_input_text 字符数(strip 后由 Segment 列 length 近似) + memoir_narrative_batch_min_chars: int = Field(default=80, ge=0, le=500_000) + # Phase 1 完成后未触发 Phase 2 时,延迟任务兜底(秒);新 Phase 1 会 revoke 旧定时 + memoir_narrative_batch_max_wait_seconds: float = Field( + default=120.0, ge=1.0, le=3600.0 + ) + # False:Celery/批处理更新 slot 时不改写 MemoirState.current_stage(访谈路径仍可由 switch_stage 推进) + # True:仅当 chat_bucket( proposed ) == chat_bucket( existing ) 时允许批处理对齐 current_stage + memoir_extraction_updates_current_stage: bool = False # ── Memory 检索与富化 ───────────────────────────────────── # True:query 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk FTS) diff --git a/api/app/features/conversation/models.py b/api/app/features/conversation/models.py index fd9f102..a8811b9 100644 --- a/api/app/features/conversation/models.py +++ b/api/app/features/conversation/models.py @@ -50,7 +50,12 @@ class Segment(Base): audio_duration_seconds = Column(Integer, nullable=True) created_at = Column(DateTime(timezone=True), default=utc_now) processed = Column(Boolean, default=False) + # Phase 1 分类结果(回忆录 chapter 类目);非空表示 Phase 1 已完成 topic_category = Column(String, nullable=True) + # Phase 2 已消费该段并完成叙事落库 + narrated = Column(Boolean, default=False, server_default="false") + # Phase 1 判定无需进故事管线(无 slots 且 LLM 判 none) + skip_narrative = Column(Boolean, default=False, server_default="false") agent_response = Column(Text, nullable=True) tts_audio_urls = Column(JSON, nullable=True) diff --git a/api/app/features/conversation/repo.py b/api/app/features/conversation/repo.py index 447a56b..0287277 100644 --- a/api/app/features/conversation/repo.py +++ b/api/app/features/conversation/repo.py @@ -84,20 +84,43 @@ async def get_segments_for_conversation( async def get_segments_for_organize( conversation_id: str, db: AsyncSession ) -> list[Segment]: - """Unprocessed segments first; if none, all segments.""" + """兼容旧语义:优先返回 Phase1 未完成的片段;若无则返回本会话全部片段。""" + pending = await get_segments_pending_phase1(conversation_id, db) + if pending: + return pending + return await get_segments_for_conversation(conversation_id, db) + + +async def get_segments_pending_phase1( + conversation_id: str, db: AsyncSession +) -> list[Segment]: + """尚未跑 Phase1 分类的 segments(topic_category 为空且未标记 narrated)。""" stmt = ( select(Segment) .where( Segment.conversation_id == conversation_id, + Segment.topic_category.is_(None), + Segment.narrated.is_(False), Segment.processed.is_(False), ) .order_by(Segment.created_at) ) result = await db.execute(stmt) - segments = list(result.scalars().all()) - if not segments: - return await get_segments_for_conversation(conversation_id, db) - return segments + return list(result.scalars().all()) + + +async def conversation_has_pending_phase2( + conversation_id: str, db: AsyncSession +) -> bool: + """Phase1 已完成但叙事未消费的片段(本会话范围内)。""" + stmt = select(func.count(Segment.id)).where( + Segment.conversation_id == conversation_id, + Segment.topic_category.isnot(None), + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + ) + result = await db.execute(stmt) + return int(result.scalar() or 0) > 0 async def count_segments_for_user(user_id: str, db: AsyncSession) -> int: diff --git a/api/app/features/conversation/service.py b/api/app/features/conversation/service.py index 6ca3816..6a9b3ad 100644 --- a/api/app/features/conversation/service.py +++ b/api/app/features/conversation/service.py @@ -1,5 +1,6 @@ """Conversation service — 对话编排(列表、创建、结束、删除、消息、整理)。""" +import asyncio import uuid from datetime import datetime, timezone @@ -23,7 +24,10 @@ from app.features.conversation.tts_delivery import apply_presigned_tts_urls_to_m from app.features.memory import repo as memory_repo from app.features.quota.service import QuotaService from app.ports.storage import ObjectStorage -from app.tasks.memoir_tasks import process_memoir_segments +from app.tasks.memoir_tasks import ( + dispatch_pending_memoir_phase2_for_user, + process_memoir_phase1, +) logger = get_logger(__name__) @@ -318,23 +322,26 @@ class ConversationService: self, conversation_id: str, user_id: str, subscription_type: str ) -> dict: conv = await self.get_or_404(conversation_id, user_id) - segments = await repo.get_segments_for_organize(conversation_id, self._db) - if not segments: + pending_p1 = await repo.get_segments_pending_phase1(conversation_id, self._db) + has_p2 = await repo.conversation_has_pending_phase2(conversation_id, self._db) + if not pending_p1 and not has_p2: raise HTTPException(status_code=400, detail="该对话没有可整理的内容") can_submit, quota_message = await self._quota.check_can_submit_organize( user_id, subscription_type ) if not can_submit: raise HTTPException(status_code=403, detail=quota_message) - segment_ids = [s.id for s in segments] - process_memoir_segments.delay(conv.user_id, segment_ids) - logger.info( - "手动触发对话整理: conversation_id={}, segments={}", - conversation_id, - len(segment_ids), - ) + if pending_p1: + segment_ids = [s.id for s in pending_p1] + process_memoir_phase1.delay(conv.user_id, segment_ids) + logger.info( + "手动触发 Phase1: conversation_id={}, segments={}", + conversation_id, + len(segment_ids), + ) + await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, conv.user_id) return { "message": "对话整理任务已提交", "conversation_id": conversation_id, - "segments_count": len(segment_ids), + "segments_count": len(pending_p1), } diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index d127cf9..206c8a6 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -792,11 +792,13 @@ async def process_conversation_segments( conversation_id: str, db: AsyncSession, quota_service: "QuotaService" ): """ - 处理对话段落,生成章节(对话结束时调用) + 对话结束时:把本对话仍待 Phase1 的段落交给回忆录管线。 - 注意:大部分处理已通过 Celery 任务增量完成 - 这里立即提交所有待处理的段落到 Celery - 配额检查通过注入的 quota_service 完成,不直接 import quota 内部函数。 + 经 `BackgroundTaskRunner.flush_pending` 将内存防抖 batch 与当前查询到的 + `topic_category IS NULL` 段 ID 合并、去重后**单次**提交 `process_memoir_phase1`, + 并在 flush 末尾触发待叙事 Phase2 派发;避免会话结束路径与 debounce flush 双发 Phase1。 + + 配额检查通过注入的 `quota_service` 完成,不直接 import quota 内部函数。 """ conversation = await db.get(Conversation, conversation_id) if not conversation or conversation.deleted_at is not None: @@ -805,6 +807,7 @@ async def process_conversation_segments( stmt = select(Segment).where( Segment.conversation_id == conversation_id, Segment.processed == False, + Segment.topic_category.is_(None), ) result = await db.execute(stmt) segments = result.scalars().all() @@ -827,13 +830,14 @@ async def process_conversation_segments( segment_ids = [seg.id for seg in segments] try: - from app.tasks.memoir_tasks import process_memoir_segments - - process_memoir_segments.delay(conversation.user_id, segment_ids) + await background_runner.flush_pending( + conversation.user_id, extra_segment_ids=segment_ids + ) logger.info( - f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}" + "对话结束,合并批内 segment 与 DB 待分类段,单次提交 Phase1: " + "conversation_id={} segments={}", + conversation_id, + len(segment_ids), ) except Exception as e: logger.error(f"提交 Celery 任务失败: {e}") - - await background_runner.flush_pending(conversation.user_id) diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index c7ff776..0518449 100644 --- a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -12,6 +12,7 @@ from fastapi import WebSocket, WebSocketDisconnect, status from starlette.websockets import WebSocketState from app.agents.chat.background_voice import infer_background_voice +from app.agents.stage_constants import STAGE_TO_ORDER from app.agents.chat.prompts_profile import format_user_profile_context from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider @@ -140,6 +141,19 @@ async def websocket_endpoint( ) return + # 冷启动对齐 conversation_stage 与 MemoirState.current_stage; + # 若对话行已有更靠前的人生阶段(STAGE_TO_ORDER 更大),不覆盖以免回退。 + memoir_state = await get_or_create_state(user_id, db) + ms = (memoir_state.current_stage or "").strip() + cs = (conversation.conversation_stage or "").strip() + if ms: + if not cs: + conversation.conversation_stage = ms + elif STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1): + conversation.conversation_stage = ms + await db.commit() + await db.refresh(conversation) + history = await conversation_service.ensure_redis_history_from_db( conversation_id ) @@ -180,7 +194,7 @@ async def websocket_endpoint( logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) else: try: - state = await get_or_create_state(user_id, db) + state = memoir_state user_profile_context = format_user_profile_context( birth_year=user.birth_year, birth_place=user.birth_place, diff --git a/api/app/features/memoir/background_runner.py b/api/app/features/memoir/background_runner.py index b0331fc..74afd61 100644 --- a/api/app/features/memoir/background_runner.py +++ b/api/app/features/memoir/background_runner.py @@ -1,11 +1,11 @@ -"""回忆录后台任务聚合:debounce 后派发 process_memoir_segments(feature 层)。""" +"""回忆录后台任务聚合:debounce 后派发 process_memoir_phase1;flush 时触发待叙事 Phase2。""" from __future__ import annotations import asyncio import time from dataclasses import dataclass, field -from typing import Dict, List +from typing import Dict, List, Sequence from app.core.config import settings from app.core.logging import get_logger @@ -60,11 +60,24 @@ class BackgroundTaskRunner: ids = st.segment_ids return ids + async def _flush_pending_phase2(self, user_id: str) -> None: + try: + from app.tasks.memoir_tasks import dispatch_pending_memoir_phase2_for_user + + await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, user_id) + except Exception as e: + logger.error( + "flush Phase2 失败: user_id={} exc_type={} exc={}", + user_id, + type(e).__name__, + e, + ) + async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None: try: - from app.tasks.memoir_tasks import process_memoir_segments + from app.tasks.memoir_tasks import process_memoir_phase1 - result = process_memoir_segments.delay(user_id, segment_ids) + result = process_memoir_phase1.delay(user_id, segment_ids) task_id = result.id await task_tracker.add_task(user_id, task_id, "memoir") logger.info( @@ -149,11 +162,30 @@ class BackgroundTaskRunner: self._timers[user_id] = asyncio.create_task(delayed_submit()) - async def flush_pending(self, user_id: str) -> str | None: + def _dedupe_preserve_order(self, ids: Sequence[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for sid in ids: + if sid not in seen: + seen.add(sid) + out.append(sid) + return out + + async def flush_pending( + self, + user_id: str, + *, + extra_segment_ids: Sequence[str] | None = None, + ) -> str | None: if user_id in self._timers: self._timers[user_id].cancel() del self._timers[user_id] - segment_ids = self._pop_batch(user_id) - if segment_ids: - return await self._submit_task(user_id, segment_ids) - return None + popped = self._pop_batch(user_id) + merged = self._dedupe_preserve_order( + list(popped) + list(extra_segment_ids or ()) + ) + task_id: str | None = None + if merged: + task_id = await self._submit_task(user_id, merged) + await self._flush_pending_phase2(user_id) + return task_id diff --git a/api/app/features/memoir/state_service.py b/api/app/features/memoir/state_service.py index 2690a65..a74b668 100644 --- a/api/app/features/memoir/state_service.py +++ b/api/app/features/memoir/state_service.py @@ -1,6 +1,6 @@ """ 回忆录状态服务:get_or_create_state、update_slot、mark_stage_complete 等。 -供 memoir service、conversation ws 使用;Celery 任务内使用同步版本(见 tasks/memoir_tasks)。 +供 memoir service、conversation ws 使用;Celery 任务内使用同步版本。 """ import uuid @@ -8,12 +8,25 @@ from typing import Dict, List from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from app.agents.state_schema import MemoirStateSchema, SlotData, default_state +from app.agents.stage_constants import ( + chat_bucket, + normalize_chat_stage, +) +from app.core.config import settings from app.features.memoir.models import MemoirState as MemoirStateModel -def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema: +def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]: + """浅拷贝 slots,避免就地改 JSON 列同一 dict 引用导致 ORM 不标记 dirty。""" + if not raw or not isinstance(raw, dict): + return {} + return {k: dict(v or {}) for k, v in raw.items()} + + +def coerce_memoir_state(model: MemoirStateModel) -> MemoirStateSchema: return MemoirStateSchema.model_validate( { "stage_order": model.stage_order or default_state().stage_order, @@ -31,7 +44,7 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche result = await db.execute(stmt) state = result.scalar_one_or_none() if state: - return _coerce_state(state) + return coerce_memoir_state(state) default = default_state() state = MemoirStateModel( @@ -48,7 +61,27 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche db.add(state) await db.commit() await db.refresh(state) - return _coerce_state(state) + return coerce_memoir_state(state) + + +def _apply_current_stage_policy( + state: MemoirStateModel, + stage_norm: str, + *, + memoir_batch: bool, +) -> None: + """按 memoir_extraction_updates_current_stage 与 chat_bucket 真值表更新 current_stage。""" + current_from_db = state.current_stage or "childhood" + if not memoir_batch: + state.current_stage = stage_norm + return + + if not settings.memoir_extraction_updates_current_stage: + return + cur_b = chat_bucket(state.current_stage or current_from_db) + new_b = chat_bucket(stage_norm) + if new_b == cur_b: + state.current_stage = stage_norm async def update_slot( @@ -58,8 +91,14 @@ async def update_slot( snippet: str, segment_ids: List[str], db: AsyncSession, + *, + memoir_batch: bool = False, ) -> MemoirStateSchema: - stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) + stmt = ( + select(MemoirStateModel) + .where(MemoirStateModel.user_id == user_id) + .with_for_update() + ) result = await db.execute(stmt) state = result.scalar_one_or_none() if not state: @@ -67,25 +106,35 @@ async def update_slot( result = await db.execute(stmt) state = result.scalar_one() - slots: Dict[str, Dict] = state.slots or {} - stage_slots = slots.get(stage, {}) + current_from_db = state.current_stage or "childhood" + stage_norm = normalize_chat_stage( + stage, + fallback=current_from_db, + log_context={"user_id": user_id}, + ) + + slots = _slots_snapshot_for_merge( + state.slots if isinstance(state.slots, dict) else None + ) + stage_slots = dict(slots.get(stage_norm, {}) or {}) existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) stage_slots[slot_name] = SlotData( snippet=snippet, segment_ids=merged_segment_ids ).model_dump() - slots[stage] = stage_slots + slots[stage_norm] = stage_slots state.slots = slots - state.current_stage = stage + _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) await db.commit() await db.refresh(state) - return _coerce_state(state) + return coerce_memoir_state(state) async def mark_stage_complete( user_id: str, stage: str, db: AsyncSession ) -> MemoirStateSchema: + """推进 covered_stages 并在当前阶段匹配时尝试进入下一阶段。当前无调用方,预留未来阶段推进逻辑。""" stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) result = await db.execute(stmt) state = result.scalar_one_or_none() @@ -106,7 +155,7 @@ async def mark_stage_complete( state.current_stage = default_state().current_stage await db.commit() await db.refresh(state) - return _coerce_state(state) + return coerce_memoir_state(state) async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]: @@ -117,13 +166,94 @@ async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]: async def switch_stage( user_id: str, new_stage: str, db: AsyncSession ) -> MemoirStateSchema: - stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) + stmt = ( + select(MemoirStateModel) + .where(MemoirStateModel.user_id == user_id) + .with_for_update() + ) result = await db.execute(stmt) state = result.scalar_one_or_none() if not state: - return await get_or_create_state(user_id, db) + await get_or_create_state(user_id, db) + result = await db.execute(stmt) + state = result.scalar_one() - state.current_stage = new_stage + fb = state.current_stage or "childhood" + state.current_stage = normalize_chat_stage( + new_stage, fallback=fb, log_context={"user_id": user_id} + ) await db.commit() await db.refresh(state) - return _coerce_state(state) + return coerce_memoir_state(state) + + +def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: + stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) + result = db.execute(stmt) + state = result.scalar_one_or_none() + if state: + return coerce_memoir_state(state) + + default = default_state() + state = MemoirStateModel( + id=str(uuid.uuid4()), + user_id=user_id, + stage_order=default.stage_order, + current_stage=default.current_stage, + covered_stages=default.covered_stages, + slots={ + k: {sk: sv.model_dump() for sk, sv in v.items()} + for k, v in default.slots.items() + }, + ) + db.add(state) + db.commit() + db.refresh(state) + return coerce_memoir_state(state) + + +def update_slot_sync( + user_id: str, + stage: str, + slot_name: str, + snippet: str, + segment_ids: List[str], + db: Session, + *, + memoir_batch: bool = True, +) -> MemoirStateSchema: + stmt = ( + select(MemoirStateModel) + .where(MemoirStateModel.user_id == user_id) + .with_for_update() + ) + result = db.execute(stmt) + state = result.scalar_one_or_none() + if not state: + get_or_create_state_sync(user_id, db) + result = db.execute(stmt) + state = result.scalar_one() + + current_from_db = state.current_stage or "childhood" + stage_norm = normalize_chat_stage( + stage, + fallback=current_from_db, + log_context={"user_id": user_id}, + ) + + slots = _slots_snapshot_for_merge( + state.slots if isinstance(state.slots, dict) else None + ) + stage_slots = dict(slots.get(stage_norm, {}) or {}) + existing = stage_slots.get(slot_name, {}) + + merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) + stage_slots[slot_name] = SlotData( + snippet=snippet, segment_ids=merged_segment_ids + ).model_dump() + slots[stage_norm] = stage_slots + state.slots = slots + _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) + db.commit() + db.refresh(state) + return coerce_memoir_state(state) diff --git a/api/app/features/memoir/story_pipeline_sync.py b/api/app/features/memoir/story_pipeline_sync.py index 01a7c53..df5070b 100644 --- a/api/app/features/memoir/story_pipeline_sync.py +++ b/api/app/features/memoir/story_pipeline_sync.py @@ -22,6 +22,7 @@ from app.agents.memoir.prompts import ( from app.agents.stage_constants import ( CATEGORY_TO_CHAT_STAGE, CHAPTER_CATEGORIES, + CHAT_STAGES, STAGE_TO_ORDER, ) from app.agents.memoir.story_route_agent import ( @@ -56,6 +57,71 @@ from app.features.story.sync_write import ( logger = get_logger(__name__) +# summary 章节跨阶段汇总 slots 时的上限(防叙事 prompt 膨胀) +MAX_SUMMARY_SLOT_KEYS = 80 +MAX_SUMMARY_SLOT_CHARS = 12_000 + + +def _slot_snippets_for_narrative( + *, + state: MemoirStateSchema, + chapter_category: str, + user_id: str, +) -> dict[str, str]: + """按章节类目收集 slot 片段;summary 时跨 CHAT_STAGES 汇总并做 key/字符上限。""" + slot_snippets: dict[str, str] = {} + if chapter_category == "summary": + total_chars = 0 + keys_added = 0 + capped = False + for chat_stage_key in CHAT_STAGES: + if keys_added >= MAX_SUMMARY_SLOT_KEYS: + capped = True + break + stage_slots = state.slots.get(chat_stage_key, {}) or {} + for key in sorted(stage_slots.keys()): + if keys_added >= MAX_SUMMARY_SLOT_KEYS: + capped = True + break + value = stage_slots[key] + snip = getattr(value, "snippet", None) or ( + value.get("snippet") if isinstance(value, dict) else None + ) + if not snip: + continue + composite = f"{chat_stage_key}_{key}" + s = str(snip).strip() + if total_chars + len(s) > MAX_SUMMARY_SLOT_CHARS: + remain = MAX_SUMMARY_SLOT_CHARS - total_chars + if remain > 32: + slot_snippets[composite] = s[:remain] + "…" + capped = True + break + slot_snippets[composite] = s + total_chars += len(s) + keys_added += 1 + if capped: + break + if capped: + logger.info( + "event=summary_slot_snippets_capped user_id={} keys={} chars={}", + user_id, + len(slot_snippets), + total_chars, + ) + return slot_snippets + + chat_stage = CATEGORY_TO_CHAT_STAGE.get(chapter_category, chapter_category) + stage_slots = state.slots.get(chat_stage, {}) or {} + for key in sorted(stage_slots.keys()): + value = stage_slots[key] + snip = getattr(value, "snippet", None) or ( + value.get("snippet") if isinstance(value, dict) else None + ) + if snip: + slot_snippets[key] = str(snip).strip() + return slot_snippets + def _placeholder_title(chapter_category: str) -> str: return CHAPTER_CATEGORIES.get(chapter_category, chapter_category) @@ -513,15 +579,11 @@ def run_story_pipeline_for_category_batch( ) chapter = session.execute(stmt_chapter).unique().scalar_one_or_none() - slot_snippets: dict[str, str] = {} - chat_stage = CATEGORY_TO_CHAT_STAGE.get(chapter_category, chapter_category) - stage_slots = state.slots.get(chat_stage, {}) or {} - for key, value in stage_slots.items(): - snip = getattr(value, "snippet", None) or ( - value.get("snippet") if isinstance(value, dict) else None - ) - if snip: - slot_snippets[key] = snip + slot_snippets = _slot_snippets_for_narrative( + state=state, + chapter_category=chapter_category, + user_id=user_id, + ) title = chapter.title if chapter else _placeholder_title(chapter_category) diff --git a/api/app/tasks/__init__.py b/api/app/tasks/__init__.py index b29f4e5..efeae56 100644 --- a/api/app/tasks/__init__.py +++ b/api/app/tasks/__init__.py @@ -4,12 +4,18 @@ Celery 任务模块 from .celery_app import celery_app from .chapter_cover_tasks import generate_chapter_cover -from .memoir_tasks import process_memoir_segments +from .memoir_tasks import ( + process_memoir_phase1, + process_memoir_phase2, + process_memoir_segments, +) from .memory_compaction_tasks import memory_compaction_run from .story_image_tasks import generate_story_image __all__ = [ "celery_app", + "process_memoir_phase1", + "process_memoir_phase2", "process_memoir_segments", "generate_chapter_cover", "generate_story_image", diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index 8d9d17d..d98bcc0 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -9,13 +9,16 @@ from typing import Dict, List, Set import redis from celery import shared_task -from sqlalchemy import select +from celery.exceptions import Retry +from celery.result import AsyncResult +from sqlalchemy import func, select from sqlalchemy.orm import Session from app.agents.chat.background_voice import infer_background_voice from app.agents.chat.prompts_profile import format_user_profile_context from app.agents.memoir import MemoirOrchestrator -from app.agents.state_schema import MemoirStateSchema, SlotData, default_state +from app.agents.stage_constants import normalize_chapter_category +from app.agents.state_schema import MemoirStateSchema, default_state from app.core.chapter_pipeline_lock import ( acquire_chapter_pipeline_lock as _acquire_chapter_lock, ) @@ -26,7 +29,9 @@ from app.core.config import settings from app.core.db import get_sync_db from app.core.dependencies import get_llm_provider, get_llm_provider_fast from app.core.logging import get_logger -from app.features.conversation.models import Segment +from app.features.conversation.models import Conversation, Segment + +from app.tasks.celery_app import celery_app from app.features.memoir.cover_eligibility import ( chapter_needs_cover_enqueue, ) @@ -46,7 +51,10 @@ from app.features.memoir.memoir_images.settings import MemoirImageSettings from app.features.memoir.models import ( Book, MemoirImage, - MemoirState, +) +from app.features.memoir.state_service import ( + get_or_create_state_sync, + update_slot_sync, ) from app.features.memoir.story_pipeline_sync import ( run_story_pipeline_for_category_batch, @@ -187,109 +195,361 @@ def _memoir_image_from_asset( ) -def _coerce_state(model: MemoirState) -> MemoirStateSchema: - """将数据库模型转换为 Schema""" - return MemoirStateSchema.model_validate( - { - "stage_order": model.stage_order or default_state().stage_order, - "current_stage": model.current_stage, - "covered_stages": model.covered_stages or [], - "slots": model.slots - if isinstance(model.slots, dict) - else default_state().slots, - } - ) +def _phase2_timeout_task_id(user_id: str, chapter_category: str) -> str: + return f"phase2-timeout-{user_id}-{chapter_category}" -def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: - """同步获取或创建状态""" - stmt = select(MemoirState).where(MemoirState.user_id == user_id) - result = db.execute(stmt) - state = result.scalar_one_or_none() - if state: - return _coerce_state(state) - - default = default_state() - state = MemoirState( - id=str(uuid.uuid4()), - user_id=user_id, - stage_order=default.stage_order, - current_stage=default.current_stage, - covered_stages=default.covered_stages, - slots={ - k: {sk: sv.model_dump() for sk, sv in v.items()} - for k, v in default.slots.items() - }, - ) - db.add(state) - db.commit() - db.refresh(state) - return _coerce_state(state) +def _revoke_phase2_timeout(user_id: str, chapter_category: str) -> None: + tid = _phase2_timeout_task_id(user_id, chapter_category) + try: + AsyncResult(tid, app=celery_app).revoke(terminate=False) + except Exception as e: + logger.debug( + "event=phase2_timeout_revoke_skipped task_id={} exc={}", + tid, + e, + ) -def _update_slot_sync( - user_id: str, - stage: str, - slot_name: str, - snippet: str, - segment_ids: List[str], +def _should_trigger_phase2( db: Session, -) -> MemoirStateSchema: - """同步更新 slot""" - stmt = select(MemoirState).where(MemoirState.user_id == user_id) - result = db.execute(stmt) - state = result.scalar_one_or_none() - if not state: - _get_or_create_state_sync(user_id, db) - result = db.execute(stmt) - state = result.scalar_one() + user_id: str, + chapter_category: str, + current_segment_chars: int, +) -> bool: + if current_segment_chars >= int(settings.memoir_narrative_immediate_char_threshold): + return True + user_convs = select(Conversation.id).where( + Conversation.user_id == user_id, + Conversation.deleted_at.is_(None), + ) + stmt = select( + func.count(Segment.id), + func.coalesce(func.sum(func.length(Segment.user_input_text)), 0), + ).where( + Segment.conversation_id.in_(user_convs), + Segment.topic_category == chapter_category, + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + ) + row = db.execute(stmt).one() + count, total_chars = int(row[0] or 0), int(row[1] or 0) + if count >= int(settings.memoir_narrative_batch_min_segments): + return True + if total_chars >= int(settings.memoir_narrative_batch_min_chars): + return True + return False - slots: Dict[str, Dict] = state.slots or {} - stage_slots = slots.get(stage, {}) - existing = stage_slots.get(slot_name, {}) - merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) - stage_slots[slot_name] = SlotData( - snippet=snippet, segment_ids=merged_segment_ids - ).model_dump() - slots[stage] = stage_slots - state.slots = slots - state.current_stage = stage - db.commit() - db.refresh(state) - return _coerce_state(state) +def _schedule_phase2_timeout(user_id: str, chapter_category: str) -> None: + """Reset countdown for Phase 2 narrative for one category.""" + _revoke_phase2_timeout(user_id, chapter_category) + countdown = float(max(1.0, settings.memoir_narrative_batch_max_wait_seconds)) + celery_app.send_task( + "app.tasks.memoir_tasks.process_memoir_phase2", + args=[user_id, chapter_category], + countdown=countdown, + task_id=_phase2_timeout_task_id(user_id, chapter_category), + ) + logger.info( + "event=phase2_timeout_scheduled user_id={} chapter_category={} countdown={}", + user_id, + chapter_category, + countdown, + ) + + +def _dispatch_phase2_immediate(user_id: str, chapter_category: str) -> None: + _revoke_phase2_timeout(user_id, chapter_category) + celery_app.send_task( + "app.tasks.memoir_tasks.process_memoir_phase2", + args=[user_id, chapter_category], + ) + logger.info( + "event=phase2_dispatched_immediate user_id={} chapter_category={}", + user_id, + chapter_category, + ) + + +def dispatch_pending_memoir_phase2_for_user(user_id: str) -> None: + """会话结束等场景:为该用户所有待叙事类目各发一条 Phase2(幂等)。""" + try: + with get_sync_db() as db: + user_convs = select(Conversation.id).where( + Conversation.user_id == user_id, + Conversation.deleted_at.is_(None), + ) + stmt = ( + select(Segment.topic_category) + .where( + Segment.conversation_id.in_(user_convs), + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + Segment.topic_category.isnot(None), + ) + .distinct() + ) + cats = [r[0] for r in db.execute(stmt).all() if r[0]] + for chapter_category in cats: + _revoke_phase2_timeout(user_id, chapter_category) + celery_app.send_task( + "app.tasks.memoir_tasks.process_memoir_phase2", + args=[user_id, chapter_category], + ) + logger.info( + "event=phase2_dispatched_flush user_id={} chapter_category={}", + user_id, + chapter_category, + ) + except Exception as e: + logger.error( + "event=phase2_flush_failed user_id={} exc_type={} exc={}", + user_id, + type(e).__name__, + e, + ) + + +@shared_task(bind=True, max_retries=3, default_retry_delay=30) +def process_memoir_phase2(self, user_id: str, chapter_category: str): + """Phase 2:叙事 / 路由 / 忠实度 / 标题;按类目加锁,消费未叙事且非 skip 的 segments。""" + task_id = self.request.id + logger.info( + "event=memoir_phase2_start user_id={} task_id={} chapter_category={}", + user_id, + task_id, + chapter_category, + ) + try: + with get_sync_db() as db: + user_convs = select(Conversation.id).where( + Conversation.user_id == user_id, + Conversation.deleted_at.is_(None), + ) + stmt = ( + select(Segment) + .where( + Segment.conversation_id.in_(user_convs), + Segment.topic_category == chapter_category, + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + ) + .order_by(Segment.created_at) + ) + category_segments = list(db.execute(stmt).scalars().all()) + + if not category_segments: + logger.info( + "event=memoir_phase2_noop user_id={} chapter_category={}", + user_id, + chapter_category, + ) + return {"status": "noop"} + + llm = _get_llm() + user_obj = db.get(User, user_id) + user_profile = "" + user_birth_year = None + background_voice = "default" + user_occupation = "" + if user_obj: + user_birth_year = user_obj.birth_year + user_profile = format_user_profile_context( + birth_year=user_obj.birth_year, + birth_place=user_obj.birth_place, + grew_up_place=user_obj.grew_up_place, + occupation=user_obj.occupation, + ) + background_voice = infer_background_voice(user_obj.occupation) + user_occupation = user_obj.occupation or "" + + image_settings = MemoirImageSettings.from_env() + story_dispatch_ids: Set[str] = set() + chapters_to_enqueue: Set[str] = set() + affected_chapter_ids: Set[str] = set() + + lock_handle = _acquire_chapter_lock( + user_id, chapter_category, ttl_seconds=_chapter_lock_ttl() + ) + if lock_handle is None: + logger.warning( + "event=memoir_phase2_lock_busy user_id={} chapter_category={}", + user_id, + chapter_category, + ) + raise self.retry(countdown=10) + + try: + # 锁内再查一次,避免等待锁期间状态已变 + category_segments = list(db.execute(stmt).scalars().all()) + if not category_segments: + return {"status": "noop"} + + state = get_or_create_state_sync(user_id, db) + chapter, needs_cover, disp = run_story_pipeline_for_category_batch( + db, + user_id=user_id, + chapter_category=chapter_category, + category_segments=category_segments, + state=state, + user_profile=user_profile, + user_birth_year=user_birth_year, + llm=llm, + background_voice=background_voice, + occupation=user_occupation, + ) + story_dispatch_ids |= disp + db.flush() + if chapter is None: + logger.error( + "event=memoir_phase2_no_chapter user_id={} chapter_category={}", + user_id, + chapter_category, + ) + db.rollback() + raise self.retry( + exc=RuntimeError("story_pipeline returned no chapter"), + countdown=30, + ) + + db.refresh(chapter) + affected_chapter_ids.add(chapter.id) + + needs_cover_enqueue = ( + image_settings.enabled and chapter_needs_cover_enqueue(chapter) + ) + + stmt_book = ( + select(Book) + .where(Book.user_id == user_id) + .order_by(Book.updated_at.desc()) + ) + result_book = db.execute(stmt_book) + book = result_book.scalar_one_or_none() + if not book: + book = Book( + id=str(uuid.uuid4()), + user_id=user_id, + title="我的回忆录", + total_pages=0, + total_words=0, + cover_image_url=None, + ) + db.add(book) + book.has_update = True + book.last_update_chapter_id = chapter.id + + if needs_cover_enqueue: + chapters_to_enqueue.add(chapter.id) + + for seg in category_segments: + seg.narrated = True + seg.processed = True + + db.commit() + + from app.features.story.post_commit import ( + enqueue_story_post_commit_effects, + ) + + pc = enqueue_story_post_commit_effects( + user_id=user_id, + story_ids=set(story_dispatch_ids), + chapter_ids=affected_chapter_ids, + trigger_source="pipeline_phase2", + need_compaction=True, + compaction_extra={ + "pipeline_run_id": str(task_id), + "story_dispatch_ids": sorted(story_dispatch_ids), + "chapters_to_enqueue": sorted(chapters_to_enqueue), + "chapter_category": chapter_category, + }, + ) + logger.info( + "event=story_post_commit user_id={} trigger=pipeline_phase2 " + "enqueued_story_image_count={} enqueued_chapter_recompose_count={} " + "compaction_scheduled={} errors={}", + user_id, + pc.enqueued_story_image_count, + pc.enqueued_chapter_recompose_count, + pc.compaction_scheduled, + pc.errors, + ) + + from app.tasks.chapter_cover_enqueue import ( + try_enqueue_generate_chapter_cover, + ) + + for chapter_id in sorted(chapters_to_enqueue): + if try_enqueue_generate_chapter_cover( + chapter_id, source="pipeline_phase2" + ): + logger.info(f"派发章节封面任务: chapter={chapter_id}") + + logger.info( + "event=memoir_phase2_done user_id={} task_id={} chapter_category={} " + "segment_count={}", + user_id, + task_id, + chapter_category, + len(category_segments), + ) + return { + "status": "success", + "chapter_category": chapter_category, + "segments": len(category_segments), + } + finally: + _release_chapter_lock(lock_handle) + + except Retry: + raise + except Exception as e: + logger.error( + "event=memoir_phase2_failed user_id={} chapter_category={} exc={}", + user_id, + chapter_category, + e, + ) + raise self.retry(exc=e) from e @shared_task(bind=True, max_retries=3, default_retry_delay=60) -def process_memoir_segments(self, user_id: str, segment_ids: List[str]): +def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): """ - 处理回忆录段落的 Celery 任务 - - Args: - user_id: 用户 ID - segment_ids: 段落 ID 列表 + Phase 1:记忆 ingest + 抽取/分类;持久化 topic_category / skip_narrative; + 按需派发 Phase 2(阈值或延迟兜底)。 """ task_id = self.request.id logger.info( - f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}" + "event=memoir_phase1_start user_id={} task_id={} segments={}", + user_id, + task_id, + len(segment_ids), ) - - # 更新任务状态为 running _update_task_status_sync(user_id, task_id, "running") try: with get_sync_db() as db: - # 获取段落 - stmt = select(Segment).where(Segment.id.in_(segment_ids)) - result = db.execute(stmt) - segments = result.scalars().all() + stmt = ( + select(Segment) + .where(Segment.id.in_(segment_ids)) + .order_by(Segment.created_at.asc(), Segment.id.asc()) + ) + rows = db.execute(stmt).scalars().all() + segments = [s for s in rows if not s.narrated] if not segments: - logger.warning(f"未找到段落: {segment_ids}") + logger.warning("event=memoir_phase1_no_segments ids={}", segment_ids) + _update_task_status_sync( + user_id, + task_id, + "success", + {"processed": 0, "categories": []}, + ) return {"status": "no_segments"} - # Memory ingest 先于回忆录流水线 commit,保证后续 retrieve_evidence_sync 可见本批 chunk - # (见 api/docs/memory-retrieval.md) conv_id = getattr(segments[0], "conversation_id", None) or "" transcript = "\n\n".join(seg.user_input_text or "" for seg in segments) if transcript.strip(): @@ -322,169 +582,78 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): "event=llm_fast_tier_used pipeline=memoir_prepare_batches model={}", settings.llm_fast_model, ) - image_settings = MemoirImageSettings.from_env() - - user_obj = db.get(User, user_id) - user_profile = "" - user_birth_year = None - background_voice = "default" - user_occupation = "" - if user_obj: - user_birth_year = user_obj.birth_year - user_profile = format_user_profile_context( - birth_year=user_obj.birth_year, - birth_place=user_obj.birth_place, - grew_up_place=user_obj.grew_up_place, - occupation=user_obj.occupation, - ) - background_voice = infer_background_voice(user_obj.occupation) - user_occupation = user_obj.occupation or "" - - story_dispatch_ids: Set[str] = set() memoir_orchestrator = MemoirOrchestrator() prepared = memoir_orchestrator.prepare_batches( segments=list(segments), llm=llm, llm_fast=llm_fast, - get_or_create_state=lambda: _get_or_create_state_sync(user_id, db), - update_slot=lambda stage, slot_name, snippet, seg_ids: ( - _update_slot_sync(user_id, stage, slot_name, snippet, seg_ids, db) + get_or_create_state=lambda: get_or_create_state_sync(user_id, db), + update_slot=lambda stage, slot_name, snippet, seg_ids: update_slot_sync( + user_id, + stage, + slot_name, + snippet, + seg_ids, + db, + memoir_batch=True, ), ) - chapters_to_enqueue: Set[str] = set() - affected_chapter_ids: Set[str] = set() - for ( - chapter_category, - category_segments, - ) in prepared.category_to_segments.items(): - lock_handle = _acquire_chapter_lock( - user_id, chapter_category, ttl_seconds=_chapter_lock_ttl() + skip_ids = prepared.segment_skip_story_ids + missing_cat = [ + seg.id + for seg in segments + if not prepared.segment_chapter_category.get(str(seg.id)) + ] + if missing_cat: + logger.error( + "event=memoir_phase1_missing_category abort segment_ids={}", + missing_cat, + ) + raise RuntimeError( + f"memoir_phase1_missing_category: {len(missing_cat)} segments" ) - if lock_handle is None: - logger.warning( - "章节锁竞争: category={}, 延迟重试", - chapter_category, - ) - raise self.retry(countdown=10) - try: - batch_ids = {str(s.id) for s in category_segments} - skip_ids = prepared.segment_skip_story_ids - in_skip = batch_ids & skip_ids - if in_skip: - logger.info( - "event=memoir_skip_story_signal chapter_category={} " - "segment_ids_in_skip_set={}", - chapter_category, - sorted(in_skip), - ) - if batch_ids and batch_ids <= skip_ids: - logger.info( - "event=story_pipeline_skipped reason=no_substantive_after_none " - "chapter_category={} segment_ids={}", - chapter_category, - sorted(batch_ids), - ) - continue - - chapter, needs_cover, disp = run_story_pipeline_for_category_batch( - db, - user_id=user_id, - chapter_category=chapter_category, - category_segments=category_segments, - state=prepared.state, - user_profile=user_profile, - user_birth_year=user_birth_year, - llm=llm, - background_voice=background_voice, - occupation=user_occupation, - ) - story_dispatch_ids |= disp - db.flush() - db.refresh(chapter) - affected_chapter_ids.add(chapter.id) - - needs_cover_enqueue = ( - image_settings.enabled and chapter_needs_cover_enqueue(chapter) - ) - - stmt_book = ( - select(Book) - .where(Book.user_id == user_id) - .order_by(Book.updated_at.desc()) - ) - result_book = db.execute(stmt_book) - book = result_book.scalar_one_or_none() - if not book: - book = Book( - id=str(uuid.uuid4()), - user_id=user_id, - title="我的回忆录", - total_pages=0, - total_words=0, - cover_image_url=None, - ) - db.add(book) - book.has_update = True - book.last_update_chapter_id = chapter.id - - if chapter and needs_cover_enqueue: - chapters_to_enqueue.add(chapter.id) - finally: - _release_chapter_lock(lock_handle) - - # 标记段落为已处理 for seg in segments: - seg.processed = True + cat = prepared.segment_chapter_category[str(seg.id)] + seg.topic_category = cat + is_skip = str(seg.id) in skip_ids + seg.skip_narrative = is_skip + seg.narrated = False + if is_skip: + seg.processed = True + + db.flush() + + categories_for_phase2: Set[str] = set() + phase2_immediate: list[str] = [] + phase2_timeout: list[str] = [] + for chapter_category, cat_segments in prepared.category_to_segments.items(): + batch_non_skip = [ + s + for s in cat_segments + if str(s.id) not in prepared.segment_skip_story_ids + ] + if not batch_non_skip: + continue + max_chars = max( + len((s.user_input_text or "").strip()) for s in batch_non_skip + ) + categories_for_phase2.add(chapter_category) + if _should_trigger_phase2(db, user_id, chapter_category, max_chars): + phase2_immediate.append(chapter_category) + else: + phase2_timeout.append(chapter_category) db.commit() - from app.features.story.post_commit import enqueue_story_post_commit_effects - - pc = enqueue_story_post_commit_effects( - user_id=user_id, - story_ids=set(story_dispatch_ids), - chapter_ids=affected_chapter_ids, - trigger_source="pipeline", - need_compaction=True, - compaction_extra={ - "pipeline_run_id": str(task_id), - "story_dispatch_ids": sorted(story_dispatch_ids), - "chapters_to_enqueue": sorted(chapters_to_enqueue), - }, - ) - logger.info( - "event=story_post_commit user_id={} trigger=pipeline " - "enqueued_story_image_count={} enqueued_chapter_recompose_count={} " - "compaction_scheduled={} errors={}", - user_id, - pc.enqueued_story_image_count, - pc.enqueued_chapter_recompose_count, - pc.compaction_scheduled, - pc.errors, - ) - - from app.tasks.chapter_cover_enqueue import ( - try_enqueue_generate_chapter_cover, - ) - - for chapter_id in sorted(chapters_to_enqueue): - if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"): - logger.info(f"派发章节封面任务: chapter={chapter_id}") + for cc in phase2_immediate: + _dispatch_phase2_immediate(user_id, cc) + for cc in phase2_timeout: + _schedule_phase2_timeout(user_id, cc) categories_processed = sorted(prepared.category_to_segments.keys()) - logger.info( - "回忆录处理完成: user_id={} task_id={} segment_count={} " - "categories_processed={}", - user_id, - task_id, - len(segments), - categories_processed, - ) - - # 更新任务状态为成功 _update_task_status_sync( user_id, task_id, @@ -492,25 +661,35 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): { "processed": len(segments), "categories_processed": categories_processed, + "phase2_watch_categories": sorted(categories_for_phase2), }, ) - + logger.info( + "event=memoir_phase1_done user_id={} task_id={} segment_count={} " + "categories={}", + user_id, + task_id, + len(segments), + categories_processed, + ) return { "status": "success", "processed": len(segments), "categories_processed": categories_processed, } + except Retry: + raise except Exception as e: - logger.error(f"回忆录处理失败: {e}") - - # 更新任务状态为失败 + logger.error("event=memoir_phase1_failed user_id={} exc={}", user_id, e) _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) - - # 重试 raise self.retry(exc=e) from e +# 兼容旧 Celery/文档入口名 +process_memoir_segments = process_memoir_phase1 + + @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ @@ -521,6 +700,7 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): stage: 阶段 new_content: 新内容 """ + stage = normalize_chapter_category(stage, fallback="summary") logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") try: @@ -547,7 +727,7 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): self.id = str(uuid.uuid4()) self.user_input_text = text - state = _get_or_create_state_sync(user_id, db) + state = get_or_create_state_sync(user_id, db) chapter, _, dispatch_ids = run_story_pipeline_for_category_batch( db, user_id=user_id, @@ -560,14 +740,24 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): background_voice=background_voice, occupation=user_occupation, ) + db.flush() + if chapter is None: + logger.error( + "event=generate_chapter_content_no_chapter user_id={} stage={}", + user_id, + stage, + ) + db.rollback() + raise self.retry( + exc=RuntimeError("story_pipeline returned no chapter"), + countdown=30, + ) db.commit() db.refresh(chapter) from app.features.story.post_commit import enqueue_story_post_commit_effects - ch_ids: set[str] = set() - if chapter is not None: - ch_ids.add(str(chapter.id)) + ch_ids: set[str] = {str(chapter.id)} pc = enqueue_story_post_commit_effects( user_id=user_id, story_ids=set(dispatch_ids), @@ -599,6 +789,8 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): try_enqueue_generate_chapter_cover(chapter.id, source="pipeline") return {"status": "success"} + except Retry: + raise except Exception as e: logger.error(f"章节生成失败: {e}") raise self.retry(exc=e) from e diff --git a/api/tests/test_background_runner.py b/api/tests/test_background_runner.py index 24a2868..4a0eede 100644 --- a/api/tests/test_background_runner.py +++ b/api/tests/test_background_runner.py @@ -73,13 +73,55 @@ async def test_flush_pending_submits_without_gate( first_queued_monotonic=0.0, ) - with patch.object(runner, "_submit_task", new=AsyncMock(side_effect=fake_submit)): + with ( + patch.object(runner, "_submit_task", new=AsyncMock(side_effect=fake_submit)), + patch.object( + runner, + "_flush_pending_phase2", + new=AsyncMock(return_value=None), + ), + ): await runner.flush_pending(uid) assert submitted == [("u1", ["s1", "s2"])] assert uid not in runner._batch +@pytest.mark.asyncio +async def test_flush_pending_merges_batch_and_extra_deduped( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 9999) + monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 9999.0) + + submitted: list[tuple[str, list[str]]] = [] + + async def fake_submit(uid: str, ids: list[str]) -> str: + submitted.append((uid, ids)) + return "tid" + + runner = br.BackgroundTaskRunner(debounce_seconds=30) + uid = "u1" + runner._batch[uid] = br._MemoirBatchState( + segment_ids=["s1", "s2"], + total_text_chars=3, + first_queued_monotonic=0.0, + ) + + with ( + patch.object(runner, "_submit_task", new=AsyncMock(side_effect=fake_submit)), + patch.object( + runner, + "_flush_pending_phase2", + new=AsyncMock(return_value=None), + ), + ): + await runner.flush_pending(uid, extra_segment_ids=["s2", "s3", "s1"]) + + assert submitted == [("u1", ["s1", "s2", "s3"])] + assert uid not in runner._batch + + @pytest.mark.asyncio async def test_queue_message_min_chars_zero_submits_after_debounce( monkeypatch: pytest.MonkeyPatch, diff --git a/api/tests/test_memoir_skip_story.py b/api/tests/test_memoir_skip_story.py index 6ce7e72..17174ff 100644 --- a/api/tests/test_memoir_skip_story.py +++ b/api/tests/test_memoir_skip_story.py @@ -9,12 +9,15 @@ from app.agents.memoir.classification_agent import ( ) from app.agents.memoir.extraction_agent import ExtractionResult from app.agents.memoir.orchestrator import MemoirOrchestrator -from app.agents.state_schema import MemoirStateSchema +from app.agents.stage_constants import CHAT_STAGES +from app.agents.state_schema import DEFAULT_STAGE_ORDER, MemoirStateSchema def _empty_state() -> MemoirStateSchema: + """与生产默认一致的五阶段 stage_order(计划 §5-C 全量阶段管道覆盖)。""" + assert list(CHAT_STAGES) == DEFAULT_STAGE_ORDER return MemoirStateSchema( - stage_order=["childhood"], + stage_order=list(CHAT_STAGES), current_stage="childhood", covered_stages=[], slots={}, diff --git a/api/tests/test_memoir_two_phase.py b/api/tests/test_memoir_two_phase.py new file mode 100644 index 0000000..702c1fe --- /dev/null +++ b/api/tests/test_memoir_two_phase.py @@ -0,0 +1,85 @@ +"""回忆录两阶段管线:Phase2 触发条件与 orchestrator 字段。""" + +from unittest.mock import MagicMock + +import pytest + +from app.agents.memoir.orchestrator import MemoirOrchestrator +from app.agents.memoir.extraction_agent import ExtractionResult +from app.agents.memoir.classification_agent import ChapterClassifyResult +from app.agents.state_schema import MemoirStateSchema + +from app.tasks.memoir_tasks import _should_trigger_phase2 + + +def test_segment_chapter_category_populated() -> None: + orch = MemoirOrchestrator() + orch.extraction_agent.extract = MagicMock( + return_value=ExtractionResult( + detected_stage="childhood", slots={"toy": "布娃娃"} + ) + ) + orch.classification_agent.classify = MagicMock( + return_value=ChapterClassifyResult(category="childhood", llm_said_none=False) + ) + st = MemoirStateSchema( + stage_order=["childhood"], + current_stage="childhood", + covered_stages=[], + slots={}, + ) + + def get_state() -> MemoirStateSchema: + return st + + def update_slot( + stage: str, slot_name: str, snippet: str, seg_ids: list[str] + ) -> MemoirStateSchema: + return st + + class _Seg: + def __init__(self, sid: str, text: str) -> None: + self.id = sid + self.user_input_text = text + + s1 = _Seg("a1", "小时候喜欢玩布娃娃") + p = orch.prepare_batches( + segments=[s1], + llm=MagicMock(), + get_or_create_state=get_state, + update_slot=update_slot, + ) + assert p.segment_chapter_category["a1"] == "childhood" + + +@pytest.mark.parametrize( + "count,total_chars,current_chars,expect", + [ + (1, 10, 60, True), # immediate via current segment chars + (3, 5, 5, True), # batch min segments + (2, 100, 5, True), # batch min total chars + (2, 50, 5, False), # below both accum thresholds + ], +) +def test_should_trigger_phase2_matrix( + monkeypatch: pytest.MonkeyPatch, + count: int, + total_chars: int, + current_chars: int, + expect: bool, +) -> None: + monkeypatch.setattr( + "app.tasks.memoir_tasks.settings.memoir_narrative_immediate_char_threshold", + 50, + ) + monkeypatch.setattr( + "app.tasks.memoir_tasks.settings.memoir_narrative_batch_min_segments", + 3, + ) + monkeypatch.setattr( + "app.tasks.memoir_tasks.settings.memoir_narrative_batch_min_chars", + 80, + ) + db = MagicMock() + db.execute.return_value.one.return_value = (count, total_chars) + assert _should_trigger_phase2(db, "user-1", "childhood", current_chars) == expect diff --git a/api/tests/test_stage_validation.py b/api/tests/test_stage_validation.py new file mode 100644 index 0000000..c66be73 --- /dev/null +++ b/api/tests/test_stage_validation.py @@ -0,0 +1,83 @@ +"""阶段 / 章节归一化纯函数(VALID + normalize + chat_bucket)。""" + +import json +from unittest.mock import MagicMock + +import pytest + +from app.agents.memoir.extraction_agent import ExtractionAgent +from app.agents.stage_constants import ( + chat_bucket, + normalize_chapter_category, + normalize_chat_stage, +) + + +def test_normalize_chat_stage_valid_and_alias() -> None: + assert normalize_chat_stage("career", "childhood") == "career" + assert normalize_chat_stage("BELIEF", "childhood") == "belief" + assert normalize_chat_stage("beliefs", "childhood") == "belief" + + +def test_normalize_chat_stage_chapter_to_chat_bucket() -> None: + assert normalize_chat_stage("career_early", "childhood") == "career" + + +def test_normalize_chat_stage_invalid_fallback() -> None: + assert normalize_chat_stage("not_a_stage", "education") == "education" + assert normalize_chat_stage(None, "family") == "family" + assert normalize_chat_stage("", "family") == "family" + + +def test_normalize_chapter_category() -> None: + assert normalize_chapter_category("childhood", "summary") == "childhood" + assert normalize_chapter_category("bogus", "childhood") == "childhood" + assert normalize_chapter_category(None, "invalid_fallback") == "summary" + + +def test_chat_bucket() -> None: + assert chat_bucket("career_early") == "career" + assert chat_bucket("belief") == "belief" + assert chat_bucket("beliefs") == "belief" + + +def test_extraction_agent_normalizes_detected_stage( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = ExtractionAgent() + llm = MagicMock() + monkeypatch.setattr( + "app.agents.memoir.extraction_agent.invoke_json_object", + lambda *_a, **_k: json.dumps( + {"detected_stage": "career_early", "slots": {}}, + ensure_ascii=False, + ), + ) + r = agent.extract( + user_message="hello", + current_stage="childhood", + stage_slots={}, + llm=llm, + ) + assert r.detected_stage == "career" + + +def test_extraction_agent_invalid_detected_falls_back( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = ExtractionAgent() + llm = MagicMock() + monkeypatch.setattr( + "app.agents.memoir.extraction_agent.invoke_json_object", + lambda *_a, **_k: json.dumps( + {"detected_stage": "llm_hallucination", "slots": {}}, + ensure_ascii=False, + ), + ) + r = agent.extract( + user_message="hello", + current_stage="education", + stage_slots={}, + llm=llm, + ) + assert r.detected_stage == "education" diff --git a/api/tests/test_state_service_batch_stage_policy.py b/api/tests/test_state_service_batch_stage_policy.py new file mode 100644 index 0000000..704c8a8 --- /dev/null +++ b/api/tests/test_state_service_batch_stage_policy.py @@ -0,0 +1,193 @@ +"""Memoir batch 路径下 current_stage 真值表(memoir_extraction_updates_current_stage)。""" + +# 与 alembic/env.py 一致:注册全部 ORM,避免 User relationship 解析失败 +from app.features.asset import models as _asset_models # noqa: F401 +from app.features.auth import models as _auth_models # noqa: F401 +from app.features.conversation import models as _conv_models # noqa: F401 +from app.features.memoir import models as _memoir_models # noqa: F401 +from app.features.memory import models as _memory_models # noqa: F401 +from app.features.payment import models as _payment_models # noqa: F401 +from app.features.story import models as _story_models # noqa: F401 +from app.features.user import models as _user_models # noqa: F401 + +import uuid +from types import SimpleNamespace + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker + +from app.agents.state_schema import default_state +from app.core.config import settings +from app.core.db import Base +from app.features.memoir.models import MemoirState as MemoirStateModel +from app.features.memoir.state_service import ( + _apply_current_stage_policy, + update_slot_sync, +) +from app.features.user.models import User + + +@pytest.fixture +def sqlite_session_factory(): + engine = create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all( + engine, + tables=[ + User.__table__, + MemoirStateModel.__table__, + ], + ) + yield sessionmaker(bind=engine, expire_on_commit=False, future=True) + engine.dispose() + + +def _add_user_and_state( + db, + *, + user_id: str, + current_stage: str, +) -> None: + db.add( + User( + id=user_id, + phone=f"p-{user_id}", + password_hash="x", + nickname="t", + ) + ) + default = default_state() + db.add( + MemoirStateModel( + id=str(uuid.uuid4()), + user_id=user_id, + stage_order=default.stage_order, + current_stage=current_stage, + covered_stages=default.covered_stages, + slots={ + k: {sk: sv.model_dump() for sk, sv in v.items()} + for k, v in default.slots.items() + }, + ) + ) + db.commit() + + +def test_apply_current_stage_policy_live_path_always_writes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + state = SimpleNamespace(current_stage="childhood") + _apply_current_stage_policy(state, "career", memoir_batch=False) + assert state.current_stage == "career" + + +def test_apply_current_stage_policy_batch_flag_off_short_circuit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + state = SimpleNamespace(current_stage="childhood") + _apply_current_stage_policy(state, "career", memoir_batch=True) + assert state.current_stage == "childhood" + + +def test_apply_current_stage_policy_batch_same_bucket_updates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + state = SimpleNamespace(current_stage="career") + _apply_current_stage_policy(state, "career", memoir_batch=True) + assert state.current_stage == "career" + + +def test_apply_current_stage_policy_batch_same_bucket_repairs_chapter_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + state = SimpleNamespace(current_stage="career_early") + _apply_current_stage_policy(state, "career", memoir_batch=True) + assert state.current_stage == "career" + + +def test_apply_current_stage_policy_batch_cross_bucket_blocked( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + state = SimpleNamespace(current_stage="childhood") + _apply_current_stage_policy(state, "career", memoir_batch=True) + assert state.current_stage == "childhood" + + +def test_update_slot_sync_batch_respects_flag_false( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + uid = "u-batch-off" + db = sqlite_session_factory() + _add_user_and_state(db, user_id=uid, current_stage="childhood") + + update_slot_sync( + uid, + "career", + "job", + "snippet", + ["s1"], + db, + memoir_batch=True, + ) + st = db.execute( + select(MemoirStateModel).where(MemoirStateModel.user_id == uid) + ).scalar_one() + assert st.current_stage == "childhood" + assert "career" in (st.slots or {}) + + +def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + uid = "u-batch-on" + db = sqlite_session_factory() + _add_user_and_state(db, user_id=uid, current_stage="career") + + update_slot_sync( + uid, + "career_achievement", + "peak", + "won prize", + ["s2"], + db, + memoir_batch=True, + ) + st = db.execute( + select(MemoirStateModel).where(MemoirStateModel.user_id == uid) + ).scalar_one() + assert st.current_stage == "career" + assert st.slots.get("career", {}).get("peak") is not None + + +def test_update_slot_sync_batch_flag_true_cross_bucket_unchanged( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + uid = "u-cross" + db = sqlite_session_factory() + _add_user_and_state(db, user_id=uid, current_stage="childhood") + + update_slot_sync( + uid, + "career", + "job", + "actor", + ["s3"], + db, + memoir_batch=True, + ) + st = db.execute( + select(MemoirStateModel).where(MemoirStateModel.user_id == uid) + ).scalar_one() + assert st.current_stage == "childhood" + assert st.slots.get("career", {}).get("job", {}).get("snippet") == "actor"