diff --git a/api/README.md b/api/README.md index a5a2474..114e0fe 100644 --- a/api/README.md +++ b/api/README.md @@ -10,7 +10,7 @@ Life Echo API 是一个智能对话系统,通过 WebSocket 实时连接,使 - **会话真源**:`conversation_messages`(DB)+ Redis 缓存;**实时编排入口**:`ChatOrchestrator`。 - **图像管线**:正文主图 `generate_story_image`;章节封面 `try_enqueue_generate_chapter_cover` → `generate_chapter_cover`。 -- **回忆录批次**:`MemoirOrchestrator.prepare_batches` 显式分桶后,`process_memoir_segments` 按类别加锁并调用 `run_story_pipeline_for_category_batch`(含 `StoryRouteAgent.plan_batch` 多 unit 写入)。 +- **回忆录批次**:`MemoirOrchestrator.prepare_batches` 显式分桶后,`process_memoir_phase1` 派发 Phase 2 按类别调用 `run_story_pipeline_for_category_batch`(含 `StoryRouteAgent.plan_batch` 多 unit 写入)。 ### LLM 与记忆(约定文档) diff --git a/api/alembic/versions/0017_segment_narrative_defer.py b/api/alembic/versions/0017_segment_narrative_defer.py new file mode 100644 index 0000000..cb8d47f --- /dev/null +++ b/api/alembic/versions/0017_segment_narrative_defer.py @@ -0,0 +1,75 @@ +"""segments:Phase2 低置信路由延迟元数据 + +Revision ID: 0017_segment_narrative_defer +Revises: 0016_memory_pipeline_status +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "0017_segment_narrative_defer" +down_revision: Union[str, None] = "0016_memory_pipeline_status" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _column_names(table_name: str) -> set[str]: + bind = op.get_bind() + inspector = sa.inspect(bind) + return {column["name"] for column in inspector.get_columns(table_name)} + + +def upgrade() -> None: + columns = _column_names("segments") + if "narrative_deferred_until" not in columns: + op.add_column( + "segments", + sa.Column( + "narrative_deferred_until", + sa.DateTime(timezone=True), + nullable=True, + ), + ) + if "narrative_defer_count" not in columns: + op.add_column( + "segments", + sa.Column( + "narrative_defer_count", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + ) + if "narrative_defer_reason" not in columns: + op.add_column( + "segments", + sa.Column( + "narrative_defer_reason", + sa.String(), + nullable=True, + ), + ) + if "narrative_last_attempt_at" not in columns: + op.add_column( + "segments", + sa.Column( + "narrative_last_attempt_at", + sa.DateTime(timezone=True), + nullable=True, + ), + ) + + +def downgrade() -> None: + columns = _column_names("segments") + for column in ( + "narrative_last_attempt_at", + "narrative_defer_reason", + "narrative_defer_count", + "narrative_deferred_until", + ): + if column in columns: + op.drop_column("segments", column) diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index 86b4328..a5f738a 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -22,7 +22,6 @@ from app.agents.chat.reply_limits import ( from app.agents.chat.schemas import ProfileExtractionOutput from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summary from app.core.config import settings -from app.core.dependencies import get_llm_provider from app.core.llm_call import allm_json_call from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger @@ -31,11 +30,53 @@ from app.ports.llm import LLMProvider logger = get_logger(__name__) -def _get_langchain_llm(): - try: - return LlmGateway().langchain_llm_for(LlmUseCase("chat.profile")) - except Exception: - return None +class _ProviderBackedProfileGateway: + def __init__(self, provider: LLMProvider) -> None: + self._provider = provider + + async def chat_text( + self, + messages: list[dict], + *, + use_case: LlmUseCase | None = None, + temperature: float | None = None, + model: str | None = None, + max_tokens: int | None = None, + ) -> str: + resolved_temperature = temperature + if resolved_temperature is None: + resolved_temperature = ( + use_case.temperature + if use_case and use_case.temperature is not None + else 0.7 + ) + return await self._provider.complete( + messages, + temperature=resolved_temperature, + model=model if model is not None else (use_case.model if use_case else None), + max_tokens=( + max_tokens + if max_tokens is not None + else (use_case.max_tokens if use_case else None) + ), + ) + + async def json_object( + self, + prompt: str, + schema: type[ProfileExtractionOutput], + *, + use_case: LlmUseCase, + fallback_factory: Any = None, + ) -> ProfileExtractionOutput: + return await allm_json_call( + getattr(self._provider, "langchain_llm", None), + prompt, + schema, + max_tokens=use_case.max_tokens or 1024, + agent=use_case.name, + fallback_factory=fallback_factory, + ) def _langchain_messages_to_port(messages: List[Any]) -> list[dict]: @@ -66,14 +107,17 @@ def _message_contents_char_count(messages: List[Any]) -> int: class ProfileAgent: """用户资料收集 Specialist Agent""" - def __init__(self, llm_provider: LLMProvider | None = None): - self._llm_provider = llm_provider - self.llm = _get_langchain_llm() - - def _provider(self) -> LLMProvider: - if self._llm_provider is not None: - return self._llm_provider - return get_llm_provider() + def __init__( + self, + llm_provider: LLMProvider | None = None, + llm_gateway: Any | None = None, + ) -> None: + if llm_gateway is not None: + self._llm_gateway = llm_gateway + elif llm_provider is not None: + self._llm_gateway = _ProviderBackedProfileGateway(llm_provider) + else: + self._llm_gateway = LlmGateway() async def _invoke_chat( self, @@ -88,8 +132,9 @@ class ProfileAgent: with agent_span( logger, f"{agent_name}.llm", conversation_id=conversation_id or "" ): - response_text = await self._provider().complete( + response_text = await self._llm_gateway.chat_text( port_messages, + use_case=LlmUseCase("chat.profile", max_tokens=max_tokens), max_tokens=max_tokens, ) logger.info( @@ -130,7 +175,7 @@ class ProfileAgent: conversation_id: Optional[str] = None, ) -> Dict[str, Any]: """从用户消息中提取资料字段,不持久化""" - if not self.llm or not missing_fields: + if not missing_fields: return {} recent_dialogue = "" if conversation_id: @@ -151,12 +196,13 @@ class ProfileAgent: prompt = get_profile_extraction_prompt( user_message, missing_fields, recent_dialogue=recent_dialogue or None ) - parsed = await allm_json_call( - self.llm, + parsed = await self._llm_gateway.json_object( prompt, ProfileExtractionOutput, - max_tokens=settings.chat_profile_extract_max_tokens, - agent="ProfileAgent.extract_profile_from_message", + use_case=LlmUseCase( + "ProfileAgent.extract_profile_from_message", + max_tokens=settings.chat_profile_extract_max_tokens, + ), fallback_factory=lambda: ProfileExtractionOutput(), ) result = {} @@ -197,8 +243,6 @@ class ProfileAgent: interview_stage_hint: str = "", ) -> List[str]: """生成资料追问回复,不持久化(由 Orchestrator 负责)""" - if not self.llm: - return ["谢谢!还能告诉我更多吗?"] try: prompt = get_profile_followup_prompt( missing_fields, @@ -260,8 +304,6 @@ class ProfileAgent: nickname: str = "", ) -> List[str]: """生成资料收集开场白,不持久化(由 Orchestrator 负责)""" - if not self.llm: - return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"] try: prompt = get_profile_greeting_prompt(missing_fields, nickname) hw = await get_history_with_window( diff --git a/api/app/agents/image_prompt/orchestrator.py b/api/app/agents/image_prompt/orchestrator.py index 89ecb57..332517c 100644 --- a/api/app/agents/image_prompt/orchestrator.py +++ b/api/app/agents/image_prompt/orchestrator.py @@ -9,8 +9,12 @@ from __future__ import annotations from typing import Any, Optional from app.agents.image_prompt.prompt_agent import PromptGenerationAgent +from app.core.config import settings +from app.core.logging import get_logger from app.features.memoir.memoir_images.settings import MemoirImageSettings +logger = get_logger(__name__) + class ImagePromptOrchestrator: """ @@ -76,5 +80,15 @@ def get_image_prompt_orchestrator() -> ImagePromptOrchestrator: """Celery / 后台任务入口:统一装配 LLM 与 MemoirImageSettings。""" from app.core.llm_gateway import LlmGateway, LlmUseCase - llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) - return ImagePromptOrchestrator(llm=llm, settings=MemoirImageSettings.from_env()) + image_settings = MemoirImageSettings.from_env() + try: + llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) + except Exception as e: + if settings.image_prompt_fallback_disabled: + raise + logger.warning( + "ImagePromptOrchestrator LLM 初始化失败,使用确定性 fallback: {}", + e, + ) + llm = None + return ImagePromptOrchestrator(llm=llm, settings=image_settings) diff --git a/api/app/agents/memoir/batch_phase1_prep.py b/api/app/agents/memoir/batch_phase1_prep.py index f76ac87..829ceca 100644 --- a/api/app/agents/memoir/batch_phase1_prep.py +++ b/api/app/agents/memoir/batch_phase1_prep.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Dict, List from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt from app.agents.memoir.schemas import BatchPhase1LLMOutput -from app.agents.stage_constants import STAGE_SLOT_KEYS from app.agents.state_schema import MemoirStateSchema from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call @@ -19,11 +18,6 @@ from app.features.conversation.models import Segment logger = get_logger(__name__) -STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = { - k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items() -} - - def _slots_snapshot(state: MemoirStateSchema) -> dict: snap: dict = {} for stage, buckets in (state.slots or {}).items(): diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 2ecbe20..56d7091 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -8,12 +8,9 @@ from __future__ import annotations import time from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set -from app.agents.memoir.batch_phase1_prep import ( - STAGE_ALLOWED_SLOTS, - run_batch_phase1_prep_chunked, -) +from app.agents.memoir.batch_phase1_prep import run_batch_phase1_prep_chunked from app.agents.memoir.classification_agent import ( ClassificationAgent, _looks_like_fragment_only, @@ -22,7 +19,11 @@ from app.agents.memoir.classification_agent import ( _detect_stage as detect_stage_from_keywords, ) from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult -from app.agents.stage_constants import normalize_chapter_category, normalize_chat_stage +from app.agents.stage_constants import ( + filter_stage_slots, + normalize_chapter_category, + normalize_chat_stage, +) from app.agents.state_schema import MemoirStateSchema from app.core.agent_logging import agent_span, agent_summary_enabled, log_agent_detail from app.core.config import settings @@ -92,7 +93,7 @@ class MemoirOrchestrator: ) if use_batch: try: - result = self._prepare_batches_via_batch_llm( + prepared_batch = self._prepare_batches_via_batch_llm( segments=segments, state=state, classify_extract_llm=classify_extract_llm, @@ -104,7 +105,7 @@ class MemoirOrchestrator: "msg=Phase1 批处理 LLM 路径已使用", len(segments), ) - return result + return prepared_batch except Exception as e: logger.warning( "event=phase1_batch_path_fallback segment_count={} exc={} " @@ -132,8 +133,12 @@ class MemoirOrchestrator: stage_slots=stage_slots_raw, llm=classify_extract_llm, ) - detected_stage = result.detected_stage - for slot_name, snippet in result.slots.items(): + fb = state.current_stage or "childhood" + detected_stage = normalize_chat_stage(result.detected_stage, fb) + result_slots = filter_stage_slots(detected_stage, result.slots, fb) + if not result_slots: + detected_stage = normalize_chat_stage(fb, fb) + for slot_name, snippet in result_slots.items(): state = update_slot(detected_stage, slot_name, snippet, [segment.id]) with agent_span( @@ -148,7 +153,7 @@ class MemoirOrchestrator: segment_id=segment.id, ) chapter_category = classify_result.category - if (not result.slots) and classify_result.llm_said_none: + if (not result_slots) and classify_result.llm_said_none: segment_skip_story_ids.add(str(segment.id)) segment_chapter_category[str(segment.id)] = chapter_category @@ -166,7 +171,7 @@ class MemoirOrchestrator: logger, "MemoirOrchestrator.segment_done segment_id={} slots={}", segment.id, - list((result.slots or {}).keys()), + list(result_slots.keys()), ) category_to_segments.setdefault(chapter_category, []).append(segment) @@ -211,8 +216,7 @@ class MemoirOrchestrator: else: detected_stage = normalize_chat_stage(row.detected_stage, fb) - allowed = STAGE_ALLOWED_SLOTS.get(detected_stage, frozenset()) - result_slots = {k: v for k, v in result_slots.items() if k in allowed} + result_slots = filter_stage_slots(detected_stage, result_slots, fb) if not result_slots: detected_stage = normalize_chat_stage(fb, fb) @@ -269,72 +273,3 @@ class MemoirOrchestrator: segment_skip_story_ids=segment_skip_story_ids, segment_chapter_category=segment_chapter_category, ) - - def run( - self, - *, - segments: List[Segment], - llm: Any, - user_profile: str = "", - user_birth_year: Any = None, - get_or_create_state: Callable[[], MemoirStateSchema], - update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema], - acquire_lock: Callable[[str], bool], - release_lock: Callable[[str], None], - process_category: Callable[ - [ - str, - List[Segment], - MemoirStateSchema, - str, - Any, - Any, - ], - Tuple[Any, bool], - ], - raise_retry: Callable[[], None], - llm_fast: Any | None = None, - ) -> Tuple[Set[str], int]: - """ - 执行回忆录流水线。 - process_category(category, segments, state, user_profile, user_birth_year, llm) - 返回 (chapter, has_images_to_generate)。 - 返回 (chapters_to_enqueue, processed_count)。 - raise_retry 用于锁竞争时抛出 Celery retry。 - """ - prepared = self.prepare_batches( - segments=segments, - llm=llm, - llm_fast=llm_fast, - get_or_create_state=get_or_create_state, - update_slot=update_slot, - on_phase1_chunk=None, - ) - state = prepared.state - chapters_to_enqueue: Set[str] = set() - category_to_segments = prepared.category_to_segments - - # 按 category 调用 process_category:叙事生成、持久化、封面入队标记 - for chapter_category, category_segments in category_to_segments.items(): - if not acquire_lock(chapter_category): - logger.warning( - "章节锁竞争: category={}, 延迟重试", - chapter_category, - ) - raise_retry() - - try: - chapter, has_images = process_category( - chapter_category, - category_segments, - state, - user_profile, - user_birth_year, - llm, - ) - if chapter and has_images: - chapters_to_enqueue.add(chapter.id) - finally: - release_lock(chapter_category) - - return chapters_to_enqueue, len(segments) diff --git a/api/app/agents/memoir/story_route_agent.py b/api/app/agents/memoir/story_route_agent.py index b18dab1..6e1a00d 100644 --- a/api/app/agents/memoir/story_route_agent.py +++ b/api/app/agents/memoir/story_route_agent.py @@ -31,6 +31,9 @@ PLAN_BATCH_MAX_SEGMENTS = 48 # 童年 / 求学 / 家庭:模型与后处理均倾向「少拆分、优先续写」 APPEND_FIRST_CHAPTER_CATEGORIES = frozenset({"childhood", "education", "family"}) +# These route outcomes are conservative fail-safes, not semantic append matches. +FALLBACK_NEW_STORY_REASONS = frozenset({"no_llm", "parse_error", "invalid_target"}) + def default_append_target_story_id( candidate_stories: list[Story], @@ -220,13 +223,6 @@ class StoryRouteAgent: story_meta: dict[str, dict[str, int]] | None = None, ) -> StoryRouteDecision: if not llm: - fb = default_append_target_story_id(candidate_stories, story_meta, settings) - if fb and fb in valid_story_ids: - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="no_llm_default_append", - ) return StoryRouteDecision( decision="new_story", new_story_title=None, @@ -241,13 +237,6 @@ class StoryRouteAgent: ) def _decide_fallback() -> StoryRouteDecision: - fb = default_append_target_story_id(candidate_stories, story_meta, settings) - if fb and fb in valid_story_ids: - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="parse_error_default_append", - ) return StoryRouteDecision( decision="new_story", new_story_title=None, @@ -266,22 +255,8 @@ class StoryRouteAgent: if decision.decision == "append_story": tid = decision.target_story_id if not tid or tid not in valid_story_ids: - fb = default_append_target_story_id( - candidate_stories, story_meta, settings - ) - if fb and fb in valid_story_ids: - logger.info( - "StoryRoute append 无效 target_story_id={},回退默认 append {}", - tid, - fb, - ) - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="invalid_target_default_append", - ) logger.warning( - "StoryRoute append 无效 target_story_id={},且无可用默认目标,回退 new_story", + "StoryRoute append 无效 target_story_id={},回退 new_story", tid, ) return StoryRouteDecision( diff --git a/api/app/agents/stage_constants.py b/api/app/agents/stage_constants.py index 0ff3001..4281831 100644 --- a/api/app/agents/stage_constants.py +++ b/api/app/agents/stage_constants.py @@ -68,6 +68,35 @@ STAGE_SLOT_KEYS: dict[str, tuple[str, ...]] = { "belief": ("value", "regret", "pride", "lesson"), } +STAGE_ALLOWED_SLOTS: dict[str, frozenset[str]] = { + k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items() +} + + +def allowed_slot_names_for_stage( + stage: str | None, + fallback: str = "childhood", +) -> frozenset[str]: + stage_norm = normalize_chat_stage(stage, fallback=fallback) + return STAGE_ALLOWED_SLOTS.get(stage_norm, frozenset()) + + +def is_valid_stage_slot( + stage: str | None, + slot_name: str, + fallback: str = "childhood", +) -> bool: + return slot_name in allowed_slot_names_for_stage(stage, fallback=fallback) + + +def filter_stage_slots( + stage: str | None, + slots: dict[str, str], + fallback: str = "childhood", +) -> dict[str, str]: + allowed = allowed_slot_names_for_stage(stage, fallback=fallback) + return {k: v for k, v in (slots or {}).items() if k in allowed} + # 人生阶段 / 章节类目的年龄参照(仅用于 prompt 时间提示;非业务校验) STAGE_ERA_HINTS: dict[str, tuple[int, int]] = { "childhood": (0, 12), diff --git a/api/app/agents/style_profiles.py b/api/app/agents/style_profiles.py index 018a511..e15cb89 100644 --- a/api/app/agents/style_profiles.py +++ b/api/app/agents/style_profiles.py @@ -22,7 +22,6 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import List, Tuple - # ============================================================================= # 共享:Memoir 评测维度单一事实源 # ============================================================================= diff --git a/api/app/core/config.py b/api/app/core/config.py index 393dda3..a167651 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -349,6 +349,13 @@ class Settings(BaseSettings): memoir_recompose_retry_on_lock_contention: bool = True # Phase2 立即派发使用固定 task_id,减少同类目重复入队(超时任务仍用独立 id) memoir_phase2_singleflight_immediate: bool = True + # True:Phase2 路由低置信(no_llm/parse_error/invalid_target)时不写 Story, + # 把 segment 标记为 narrative_deferred_until 之后再重试。 + memoir_route_defer_enabled: bool = True + # 低置信延迟时长(秒):到期前不消费这些 segment,避免后台空转 + memoir_route_defer_seconds: float = Field(default=120.0, ge=1.0, le=3600.0) + # 同一类目最多自动延迟次数;达到上限后 segment 仅靠新素材到达激活,不再自动重试 + memoir_route_defer_max_attempts: int = Field(default=3, ge=1, le=20) # True:Phase2 首稿后异步运行质量增强(fidelity recheck、标题润色、LLM 归一) memoir_quality_pass_enabled: bool = True memoir_quality_pass_delay_seconds: int = Field(default=5, ge=0, le=300) diff --git a/api/app/features/conversation/models.py b/api/app/features/conversation/models.py index 36f3bfe..56b1306 100644 --- a/api/app/features/conversation/models.py +++ b/api/app/features/conversation/models.py @@ -58,6 +58,13 @@ class Segment(Base): narrated = Column(Boolean, default=False, server_default="false") # Phase 1 判定无需进故事管线(无 slots 且 LLM 判 none) skip_narrative = Column(Boolean, default=False, server_default="false") + # Phase 2 路由低置信延迟:到期前不消费;新同类目素材到达可清空。 + narrative_deferred_until = Column(DateTime(timezone=True), nullable=True) + narrative_defer_count = Column( + Integer, nullable=False, default=0, server_default="0" + ) + narrative_defer_reason = Column(String, nullable=True) + narrative_last_attempt_at = Column(DateTime(timezone=True), nullable=True) agent_response = Column(Text, nullable=True) tts_audio_urls = Column(JSON, nullable=True) # 用户轮次 durable message id(与 lineage_json 同步;便于查询) diff --git a/api/app/features/memoir/state_service.py b/api/app/features/memoir/state_service.py index b32d6f2..8246d9e 100644 --- a/api/app/features/memoir/state_service.py +++ b/api/app/features/memoir/state_service.py @@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.agents.stage_constants import ( + allowed_slot_names_for_stage, chat_bucket, normalize_chat_stage, ) @@ -136,6 +137,8 @@ async def update_slot( fallback=current_from_db, log_context={"user_id": user_id}, ) + if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): + return coerce_memoir_state(state) slots = _slots_snapshot_for_merge( state.slots if isinstance(state.slots, dict) else None @@ -292,6 +295,8 @@ def update_slot_sync( fallback=current_from_db, log_context={"user_id": user_id}, ) + if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): + return coerce_memoir_state(state) slots = _slots_snapshot_for_merge( state.slots if isinstance(state.slots, dict) else None diff --git a/api/app/features/memoir/story_pipeline_sync.py b/api/app/features/memoir/story_pipeline_sync.py index d3338cd..93221c4 100644 --- a/api/app/features/memoir/story_pipeline_sync.py +++ b/api/app/features/memoir/story_pipeline_sync.py @@ -11,6 +11,7 @@ import re import time import uuid from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from typing import Any from sqlalchemy import func, select @@ -20,6 +21,7 @@ from app.agents.memoir.narrative_agent import NarrativeAgent from app.agents.memoir.prompts import format_narrative_user_content from app.agents.memoir.story_route_agent import ( APPEND_FIRST_CHAPTER_CATEGORIES, + FALLBACK_NEW_STORY_REASONS, PLAN_BATCH_MAX_SEGMENTS, StoryBatchPlan, StoryRouteAgent, @@ -70,6 +72,23 @@ from app.features.story.sync_write import ( logger = get_logger(__name__) +@dataclass +class StoryPipelineResult: + """Phase2 故事管线结果。 + + - 正常写入:``deferred=False``,``chapter`` 非空。 + - 低置信延迟:``deferred=True``,``chapter`` 为 None;调用方应把 ``defer_segment_ids`` + 标记为延迟态,不要置 ``narrated/processed``,也不要触发后置任务。 + """ + + chapter: Chapter | None + needs_cover: bool + dispatch_ids: set[str] + deferred: bool = False + defer_reason: str | None = None + defer_segment_ids: list[str] = field(default_factory=list) + + def _dialogue_lineage_dict_for_segment_ids( category_segments: list, segment_ids: list[str], @@ -662,6 +681,7 @@ def _resolve_append_target( route_decision == "new_story" and chapter_category in APPEND_FIRST_CHAPTER_CATEGORIES and candidate_stories + and decision_source not in FALLBACK_NEW_STORY_REASONS and len(oral_norm) <= int(settings.memoir_story_route_append_guardrail_oral_chars) ): @@ -952,9 +972,10 @@ def run_story_pipeline_for_category_batch( memoir_correlation_id: str | None = None, llm_fast: Any | None = None, memory_evidence: dict | None = None, -) -> tuple[Chapter | None, bool, set[str]]: - """ - 返回 (chapter, needs_cover_enqueue, story_ids_to_dispatch_after_commit)。 +) -> StoryPipelineResult: + """运行某 chapter_category 的 Phase2 写入管线。 + + 返回 :class:`StoryPipelineResult`。低置信路由会被延迟而不创建 Story/Chapter。 """ pipeline_phase_timings: dict[str, float] = {} narrative_agent = NarrativeAgent() @@ -1074,8 +1095,46 @@ def run_story_pipeline_for_category_batch( valid_story_ids=valid_ids, story_meta=story_meta, ) + + single_route: Any = None + if plan is None: + single_route = route_agent.decide( + chapter_category=chapter_category, + chapter_title=title, + batch_transcript=route_transcript, + candidate_stories=candidates, + llm=llm_route, + valid_story_ids=valid_ids, + story_meta=story_meta, + ) pipeline_phase_timings["route"] = time.perf_counter() - _t0 + if ( + plan is None + and single_route is not None + and single_route.reason in FALLBACK_NEW_STORY_REASONS + and bool(settings.memoir_route_defer_enabled) + ): + defer_ids = [str(s.id) for s in category_segments] + logger.info( + "event=memoir_pipeline_route_deferred memoir_correlation_id={} user_id={} " + "chapter_category={} segment_count={} reason={} " + "msg=Phase2 路由低置信,本批 segment 进入延迟池", + memoir_correlation_id or "", + user_id, + chapter_category, + len(defer_ids), + single_route.reason, + ) + return StoryPipelineResult( + chapter=None, + needs_cover=False, + dispatch_ids=set(), + deferred=True, + defer_reason=str(single_route.reason), + defer_segment_ids=defer_ids, + ) + chapter = _ensure_chapter_record( session, user_id=user_id, @@ -1110,17 +1169,12 @@ def run_story_pipeline_for_category_batch( fidelity_llm=llm_fidelity, ) else: - route = route_agent.decide( - chapter_category=chapter_category, - chapter_title=title, - batch_transcript=route_transcript, - candidate_stories=candidates, - llm=llm_route, - valid_story_ids=valid_ids, - story_meta=story_meta, + route = single_route + decision_source = ( + route.reason + if route.reason in FALLBACK_NEW_STORY_REASONS + else ("fallback_no_llm" if not llm_route else "single_decide") ) - - decision_source = "fallback_no_llm" if not llm else "single_decide" target_story_id, existing_for_narrative, decision_source = ( _resolve_append_target( session, @@ -1191,4 +1245,8 @@ def run_story_pipeline_for_category_batch( timing_parts, ) - return chapter, needs_cover, dispatch_ids + return StoryPipelineResult( + chapter=chapter, + needs_cover=needs_cover, + dispatch_ids=dispatch_ids, + ) diff --git a/api/app/tasks/__init__.py b/api/app/tasks/__init__.py index efeae56..324da23 100644 --- a/api/app/tasks/__init__.py +++ b/api/app/tasks/__init__.py @@ -7,7 +7,6 @@ from .chapter_cover_tasks import generate_chapter_cover from .memoir_tasks import ( process_memoir_phase1, process_memoir_phase2, - process_memoir_segments, ) from .memory_compaction_tasks import memory_compaction_run from .story_image_tasks import generate_story_image @@ -16,7 +15,6 @@ __all__ = [ "celery_app", "process_memoir_phase1", "process_memoir_phase2", - "process_memoir_segments", "generate_chapter_cover", "generate_story_image", "memory_compaction_run", diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index 908b97d..440d08c 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -6,7 +6,7 @@ import asyncio import json import time import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Dict, List, Set import redis @@ -336,6 +336,133 @@ def _phase2_immediate_task_id(user_id: str, chapter_category: str) -> str: return f"phase2-immediate-{user_id}-{chapter_category}" +def _wake_deferred_segments_for_category( + db: Session, + user_id: str, + chapter_category: str, +) -> int: + """清空该用户某 chapter_category 下旧的 defer 元数据,让其与新素材一起重判。 + + 返回被唤醒的 segment 数量,仅用于日志。 + """ + user_convs = select(Conversation.id).where( + Conversation.user_id == user_id, + Conversation.deleted_at.is_(None), + ) + stmt = select(Segment).where( + Segment.conversation_id.in_(user_convs), + Segment.topic_category == chapter_category, + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + Segment.narrative_deferred_until.isnot(None), + ) + rows = list(db.execute(stmt).scalars().all()) + if not rows: + return 0 + for seg in rows: + seg.narrative_deferred_until = None + seg.narrative_defer_count = 0 + seg.narrative_defer_reason = None + return len(rows) + + +def _persist_phase2_route_defer( + db: Session, + *, + user_id: str, + chapter_category: str, + task_id: str, + memoir_correlation_id: str | None, + defer_segment_ids: list[str], + defer_reason: str, + phase2_started: float, + pipeline_elapsed: float, + lock_elapsed: float, +) -> dict: + """把本批 segment 标记为延迟态,并按需再排一次 Phase2 timeout。 + + 返回 Celery 任务的 result dict(``status=deferred``)。 + """ + now_ts = datetime.now(timezone.utc) + max_attempts = int(settings.memoir_route_defer_max_attempts) + defer_seconds = float(settings.memoir_route_defer_seconds) + deferred_until_ts = now_ts + timedelta(seconds=max(defer_seconds, 1.0)) + + rows: list[Segment] = [] + if defer_segment_ids: + stmt = select(Segment).where(Segment.id.in_(list(defer_segment_ids))) + rows = list(db.execute(stmt).scalars().all()) + + saturated_segments = 0 + new_max_attempts_reached = False + for seg in rows: + prev_count = int(seg.narrative_defer_count or 0) + seg.narrative_defer_count = prev_count + 1 + seg.narrative_defer_reason = defer_reason + seg.narrative_last_attempt_at = now_ts + if seg.narrative_defer_count >= max_attempts: + seg.narrative_deferred_until = None + saturated_segments += 1 + new_max_attempts_reached = True + else: + seg.narrative_deferred_until = deferred_until_ts + + db.commit() + + next_task_id: str | None = None + if rows and not new_max_attempts_reached: + next_task_id = _schedule_phase2_timeout( + user_id, chapter_category, memoir_correlation_id + ) + + phase2_elapsed = time.perf_counter() - phase2_started + duration_ms = phase2_elapsed * 1000 + logger.info( + "event=memoir_phase2_route_deferred user_id={} task_id={} chapter_category={} " + "segment_count={} saturated_count={} reason={} memoir_correlation_id={} " + "lock_seconds={:.3f} pipeline_seconds={:.3f} " + "phase2_total_seconds={:.3f} duration_ms={:.1f} next_task_id={} " + "msg=Phase2 路由低置信,本批 segment 延迟", + user_id, + task_id, + chapter_category, + len(rows), + saturated_segments, + defer_reason, + memoir_correlation_id or "", + lock_elapsed, + pipeline_elapsed, + phase2_elapsed, + duration_ms, + next_task_id or "", + ) + merge_pipeline_run( + memoir_correlation_id, + { + "phase2": [ + { + "chapter_category": chapter_category, + "task_id": str(task_id), + "status": "deferred", + "detail": { + "segments": len(rows), + "reason": defer_reason, + "saturated_count": saturated_segments, + "next_task_id": next_task_id, + }, + } + ], + }, + ) + return { + "status": "deferred", + "chapter_category": chapter_category, + "segments": len(rows), + "reason": defer_reason, + "saturated_count": saturated_segments, + } + + def _schedule_phase2_timeout( user_id: str, chapter_category: str, memoir_correlation_id: str | None = None ) -> str | None: @@ -492,6 +619,7 @@ def process_memoir_phase2( Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) + now_utc = datetime.now(timezone.utc) stmt = ( select(Segment) .where( @@ -499,6 +627,10 @@ def process_memoir_phase2( Segment.topic_category == chapter_category, Segment.narrated.is_(False), Segment.skip_narrative.is_(False), + ( + Segment.narrative_deferred_until.is_(None) + | (Segment.narrative_deferred_until <= now_utc) + ), ) .order_by(Segment.created_at) ) @@ -606,7 +738,7 @@ def process_memoir_phase2( "relevant_stories": [], } pipeline_t0 = time.perf_counter() - chapter, needs_cover, disp = run_story_pipeline_for_category_batch( + pipeline_result = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=chapter_category, @@ -622,7 +754,24 @@ def process_memoir_phase2( memory_evidence=memory_evidence, ) pipeline_elapsed = time.perf_counter() - pipeline_t0 - story_dispatch_ids |= disp + + if pipeline_result.deferred: + deferred_response = _persist_phase2_route_defer( + db, + user_id=user_id, + chapter_category=chapter_category, + task_id=str(task_id), + memoir_correlation_id=cid, + defer_segment_ids=pipeline_result.defer_segment_ids, + defer_reason=pipeline_result.defer_reason or "unknown", + phase2_started=phase2_t0, + pipeline_elapsed=pipeline_elapsed, + lock_elapsed=lock_elapsed, + ) + return deferred_response + + chapter = pipeline_result.chapter + story_dispatch_ids |= pipeline_result.dispatch_ids db.flush() if chapter is None: logger.error( @@ -948,6 +1097,7 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): categories_for_phase2: Set[str] = set() phase2_immediate: list[str] = [] phase2_timeout: list[str] = [] + woke_up_by_category: dict[str, int] = {} for chapter_category, cat_segments in prepared.category_to_segments.items(): batch_non_skip = [ s @@ -956,6 +1106,11 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): ] if not batch_non_skip: continue + woke = _wake_deferred_segments_for_category( + db, user_id, chapter_category + ) + if woke: + woke_up_by_category[chapter_category] = woke max_chars = max( len((s.user_input_text or "").strip()) for s in batch_non_skip ) @@ -965,6 +1120,14 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): else: phase2_timeout.append(chapter_category) + if woke_up_by_category: + logger.info( + "event=memoir_phase1_wake_deferred user_id={} categories={} " + "msg=Phase1 新素材唤醒同类目延迟 segment", + user_id, + woke_up_by_category, + ) + db.commit() merge_pipeline_run( @@ -1080,11 +1243,6 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) raise self.retry(exc=e) from e - -# 兼容旧 Celery/文档入口名 -process_memoir_segments = process_memoir_phase1 - - @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ diff --git a/api/docs/本地开发环境配置.md b/api/docs/本地开发环境配置.md index 4655865..f70d054 100644 --- a/api/docs/本地开发环境配置.md +++ b/api/docs/本地开发环境配置.md @@ -251,9 +251,9 @@ asyncio.run(test()) ### 手动触发 Celery 任务 ```python -from app.tasks.memoir_tasks import process_memoir_segments +from app.tasks.memoir_tasks import process_memoir_phase1 # 同步调用(测试) -result = process_memoir_segments.delay("user_id", ["segment_id_1", "segment_id_2"]) +result = process_memoir_phase1.delay("user_id", ["segment_id_1", "segment_id_2"]) print(result.get(timeout=60)) ``` diff --git a/api/tests/test_image_prompt_policy.py b/api/tests/test_image_prompt_policy.py index e337c11..005e103 100644 --- a/api/tests/test_image_prompt_policy.py +++ b/api/tests/test_image_prompt_policy.py @@ -121,3 +121,44 @@ def test_cover_fallback_disabled_requires_excerpt(monkeypatch): chapter_category="family", context_excerpt="", ) + + +def test_image_prompt_orchestrator_provider_failure_uses_fallback(monkeypatch): + from app.agents.image_prompt.orchestrator import get_image_prompt_orchestrator + + class BoomGateway: + def langchain_llm_for(self, *_a, **_kw): # noqa: ANN001 + raise RuntimeError("provider missing") + + monkeypatch.setattr( + "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + False, + ) + monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) + + orch = get_image_prompt_orchestrator() + out = orch.build_cover_prompt( + chapter_title="T", + chapter_category="family", + context_excerpt="mountain lake", + ) + assert "mountain lake" in out["prompt"].lower() + + +def test_image_prompt_orchestrator_provider_failure_raises_when_disabled( + monkeypatch, +): + from app.agents.image_prompt.orchestrator import get_image_prompt_orchestrator + + class BoomGateway: + def langchain_llm_for(self, *_a, **_kw): # noqa: ANN001 + raise RuntimeError("provider missing") + + monkeypatch.setattr( + "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + True, + ) + monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) + + with pytest.raises(RuntimeError, match="provider missing"): + get_image_prompt_orchestrator() diff --git a/api/tests/test_memoir_pipeline_optimization.py b/api/tests/test_memoir_pipeline_optimization.py index 67c1c80..fa1f0dd 100644 --- a/api/tests/test_memoir_pipeline_optimization.py +++ b/api/tests/test_memoir_pipeline_optimization.py @@ -106,7 +106,10 @@ def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> orch._prepare_batches_via_batch_llm = fail_batch orch.extraction_agent.extract = MagicMock( - return_value=ExtractionResult(detected_stage="childhood", slots={"toy": "ball"}) + return_value=ExtractionResult( + detected_stage="childhood", + slots={"place": "潍坊"}, + ) ) orch.classification_agent.classify = MagicMock( return_value=ChapterClassifyResult(category="childhood", llm_said_none=False) @@ -134,6 +137,52 @@ def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> assert "s1" in result.segment_chapter_category +def test_orchestrator_sequential_filters_invalid_slots( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Sequential fallback should match batch path slot validation.""" + monkeypatch.setattr( + "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + False, + ) + + orch = MemoirOrchestrator() + orch.extraction_agent.extract = MagicMock( + return_value=ExtractionResult( + detected_stage="childhood", + slots={"place": "潍坊", "hallucinated": "bad"}, + ) + ) + orch.classification_agent.classify = MagicMock( + return_value=ChapterClassifyResult(category="childhood", llm_said_none=False) + ) + + st = MemoirStateSchema( + stage_order=["childhood"], + current_stage="childhood", + covered_stages=[], + slots={}, + ) + calls: list[tuple] = [] + + class _Seg: + id = "s1" + user_input_text = "我小时候在潍坊。" + + def update_slot(*args): + calls.append(args) + return st + + orch.prepare_batches( + segments=[_Seg()], + llm=MagicMock(), + get_or_create_state=lambda: st, + update_slot=update_slot, + ) + + assert calls == [("childhood", "place", "潍坊", ["s1"])] + + # --------------------------------------------------------------------------- # Memory enrichment decoupled from ingest # --------------------------------------------------------------------------- @@ -216,6 +265,33 @@ def test_resolve_append_target_forced_new_on_overflow() -> None: assert dsrc == "forced_new_due_to_append_limit" +def test_resolve_append_target_does_not_guardrail_route_fallback() -> None: + """No-LLM / parse fallback new_story decisions must not append by recency.""" + from app.features.memoir.story_pipeline_sync import _resolve_append_target + + session = MagicMock() + candidate = MagicMock() + candidate.id = "story-1" + + tid, existing, dsrc = _resolve_append_target( + session, + route_decision="new_story", + route_target_story_id=None, + user_id="u1", + chapter_category="childhood", + oral_norm="short text", + candidate_stories=[candidate], + story_meta={"story-1": {"char_count": 10, "version_count": 1}}, + decision_source="no_llm", + memoir_correlation_id=None, + ) + + assert tid is None + assert existing == "" + assert dsrc == "no_llm" + session.get.assert_not_called() + + # --------------------------------------------------------------------------- # _run_post_pipeline_commit helper # --------------------------------------------------------------------------- diff --git a/api/tests/test_memoir_route_defer.py b/api/tests/test_memoir_route_defer.py new file mode 100644 index 0000000..6fa51d4 --- /dev/null +++ b/api/tests/test_memoir_route_defer.py @@ -0,0 +1,437 @@ +"""Phase2 路由低置信延迟管线:deferred 池 / 唤醒 / 重试上限。""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import DEFAULT, MagicMock, patch + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker + +# 与 alembic/env.py 一致:注册全部 ORM,避免 relationship 解析失败 +from app.agents.memoir.story_route_agent import StoryRouteDecision +from app.agents.state_schema import MemoirStateSchema +from app.core.config import settings +from app.core.db import Base +from app.features.asset import models as _asset_models # noqa: F401 +from app.features.auth import models as _auth_models # noqa: F401 +from app.features.conversation import models as _conv_models # noqa: F401 +from app.features.conversation.models import Conversation, Segment +from app.features.memoir import models as _memoir_models # noqa: F401 +from app.features.memoir.story_pipeline_sync import ( + StoryPipelineResult, + run_story_pipeline_for_category_batch, +) +from app.features.memory import models as _memory_models # noqa: F401 +from app.features.payment import models as _payment_models # noqa: F401 +from app.features.story import models as _story_models # noqa: F401 +from app.features.user import models as _user_models # noqa: F401 +from app.features.user.models import User +from app.tasks.memoir_tasks import ( + _persist_phase2_route_defer, + _wake_deferred_segments_for_category, +) + + +@pytest.fixture +def sqlite_session_factory(): + engine = create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all( + engine, + tables=[ + User.__table__, + Conversation.__table__, + Segment.__table__, + ], + ) + yield sessionmaker(bind=engine, expire_on_commit=False, future=True) + engine.dispose() + + +def _seed_user_segment( + db, + *, + user_id: str, + conversation_id: str, + segment_id: str, + text: str = "我童年的事情很短暂", + topic_category: str = "childhood", +) -> Segment: + if not db.get(User, user_id): + db.add( + User( + id=user_id, + phone=f"p-{user_id[:8]}", + password_hash="x", + nickname="t", + ) + ) + if not db.get(Conversation, conversation_id): + db.add(Conversation(id=conversation_id, user_id=user_id)) + seg = Segment( + id=segment_id, + conversation_id=conversation_id, + user_input_text=text, + topic_category=topic_category, + narrated=False, + skip_narrative=False, + narrative_defer_count=0, + ) + db.add(seg) + db.commit() + return seg + + +def _patch_pipeline(plan_return, decide_return): + """统一 mock pipeline 内的 IO 与 LLM 依赖,便于聚焦路由分支。 + + 返回 ``(context_manager, route_agent_mock)``;进入 context 后由 ``patch.multiple`` + 生成的 mock dict 作为 ``mocks`` 提供给测试用例配置返回值与断言。 + """ + route_agent_mock = MagicMock() + route_agent_mock.plan_batch.return_value = plan_return + route_agent_mock.decide.return_value = decide_return + + return ( + patch.multiple( + "app.features.memoir.story_pipeline_sync", + list_active_stories_for_user_sync=DEFAULT, + StoryRouteAgent=DEFAULT, + NarrativeAgent=DEFAULT, + normalize_oral_for_memoir=DEFAULT, + ensure_chapter_story_link_sync=DEFAULT, + reorder_chapter_story_links_by_life_order_sync=DEFAULT, + mark_chapter_dirty_sync=DEFAULT, + chapter_needs_cover_enqueue=DEFAULT, + MemoirImageSettings=DEFAULT, + refresh_chapter_evidence_snapshot_with_retry_sync=DEFAULT, + create_story_with_version_sync=DEFAULT, + _ensure_chapter_record=DEFAULT, + ), + route_agent_mock, + ) + + +def _configure_pipeline_mocks(mocks: dict, route_agent_mock: MagicMock) -> None: + mocks["list_active_stories_for_user_sync"].return_value = [] + mocks["StoryRouteAgent"].return_value = route_agent_mock + mocks["normalize_oral_for_memoir"].side_effect = lambda text, **_: text + mocks["chapter_needs_cover_enqueue"].return_value = False + mocks["MemoirImageSettings"].from_env.return_value = MagicMock(enabled=False) + + +def _empty_state() -> MemoirStateSchema: + return MemoirStateSchema( + stage_order=["childhood"], + current_stage="childhood", + covered_stages=[], + slots={}, + ) + + +@pytest.mark.parametrize("reason", ["no_llm", "parse_error", "invalid_target"]) +def test_pipeline_defers_on_fallback_route_reason(reason: str) -> None: + """单段路由 fallback 时不写 chapter/story,返回 deferred 结果。""" + seg = SimpleNamespace(id="seg-defer-1", user_input_text="一句简短的口述") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title=None, + reason=reason, + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + session = MagicMock() + exec_result = MagicMock() + exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value = exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-defer", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is True + assert result.chapter is None + assert result.dispatch_ids == set() + assert result.defer_reason == reason + assert result.defer_segment_ids == ["seg-defer-1"] + mocks["_ensure_chapter_record"].assert_not_called() + mocks["create_story_with_version_sync"].assert_not_called() + mocks["mark_chapter_dirty_sync"].assert_not_called() + route_agent_mock.decide.assert_called_once() + + +def test_pipeline_does_not_defer_when_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """关闭开关后,旧行为:直接写 new_story(不再延迟)。""" + monkeypatch.setattr(settings, "memoir_route_defer_enabled", False) + + seg = SimpleNamespace(id="seg-no-defer", user_input_text="一句简短的口述") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title=None, + reason="no_llm", + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + chapter_stub = SimpleNamespace(id="chapter-1") + mocks["_ensure_chapter_record"].return_value = chapter_stub + story_stub = MagicMock() + story_stub.id = "story-x" + story_stub.current_version_id = None + mocks["create_story_with_version_sync"].return_value = story_stub + + # NarrativeAgent.generate_narrative 必须返回有效 JSON + nac_instance = mocks["NarrativeAgent"].return_value + nac_instance.generate_narrative.return_value = ( + '{"paragraphs": [{"content": "叙事正文段落足够长用于测试合并逻辑避免触发过短回退"}]}' + ) + + session = MagicMock() + exec_result = MagicMock() + exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value = exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-no-defer", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is False + assert result.chapter is chapter_stub + mocks["_ensure_chapter_record"].assert_called_once() + + +def test_pipeline_returns_result_object_for_normal_path() -> None: + """决策非 fallback 时,pipeline 仍按原路径执行并返回 StoryPipelineResult。""" + seg = SimpleNamespace(id="seg-ok", user_input_text="一段足够长的童年口述用于测试正常写入路径") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title="一个童年故事的新标题", + reason="ok", + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + chapter_stub = SimpleNamespace(id="chapter-ok") + mocks["_ensure_chapter_record"].return_value = chapter_stub + story_stub = MagicMock() + story_stub.id = "story-ok" + story_stub.current_version_id = None + mocks["create_story_with_version_sync"].return_value = story_stub + + nac_instance = mocks["NarrativeAgent"].return_value + nac_instance.generate_narrative.return_value = ( + '{"paragraphs": [{"content": "叙事正文段落足够长用于测试合并逻辑避免触发过短回退"}]}' + ) + + session = MagicMock() + exec_result = MagicMock() + exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value = exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-ok", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is False + assert result.chapter is chapter_stub + + +def test_persist_phase2_route_defer_marks_segment_and_schedules_next( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """首次延迟:写入 defer 元数据并安排下一次 timeout(未达上限)。""" + monkeypatch.setattr(settings, "memoir_route_defer_seconds", 30.0) + monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 3) + + db = sqlite_session_factory() + seg = _seed_user_segment( + db, + user_id="u-defer-1", + conversation_id=str(uuid.uuid4()), + segment_id="seg-defer-x1", + ) + + with patch( + "app.tasks.memoir_tasks._schedule_phase2_timeout", + return_value="task-id-next", + ) as schedule_mock: + out = _persist_phase2_route_defer( + db, + user_id="u-defer-1", + chapter_category="childhood", + task_id="task-id-current", + memoir_correlation_id="cid-1", + defer_segment_ids=[seg.id], + defer_reason="no_llm", + phase2_started=0.0, + pipeline_elapsed=0.0, + lock_elapsed=0.0, + ) + + assert out["status"] == "deferred" + assert out["segments"] == 1 + assert out["saturated_count"] == 0 + schedule_mock.assert_called_once_with("u-defer-1", "childhood", "cid-1") + + refreshed = db.execute(select(Segment).where(Segment.id == seg.id)).scalar_one() + assert refreshed.narrative_defer_count == 1 + assert refreshed.narrative_defer_reason == "no_llm" + assert refreshed.narrative_deferred_until is not None + assert refreshed.narrative_last_attempt_at is not None + assert refreshed.narrated is False + assert refreshed.processed is False + + +def test_persist_phase2_route_defer_stops_scheduling_at_max_attempts( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """达到 max_attempts 后不再继续派发 timeout,segment 仍保留 defer 元数据。""" + monkeypatch.setattr(settings, "memoir_route_defer_seconds", 30.0) + monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 2) + + db = sqlite_session_factory() + seg = _seed_user_segment( + db, + user_id="u-defer-max", + conversation_id=str(uuid.uuid4()), + segment_id="seg-defer-max-1", + ) + seg.narrative_defer_count = 1 + db.commit() + + with patch( + "app.tasks.memoir_tasks._schedule_phase2_timeout", + return_value="should-not-be-called", + ) as schedule_mock: + out = _persist_phase2_route_defer( + db, + user_id="u-defer-max", + chapter_category="childhood", + task_id="task-id-current", + memoir_correlation_id="cid-2", + defer_segment_ids=[seg.id], + defer_reason="parse_error", + phase2_started=0.0, + pipeline_elapsed=0.0, + lock_elapsed=0.0, + ) + + assert out["status"] == "deferred" + assert out["saturated_count"] == 1 + schedule_mock.assert_not_called() + + refreshed = db.execute(select(Segment).where(Segment.id == seg.id)).scalar_one() + assert refreshed.narrative_defer_count == 2 + # 达上限后不设 deferred_until,需要等待新素材唤醒;此时 segment 仍可被下次 Phase2 消费 + assert refreshed.narrative_deferred_until is None + assert refreshed.narrative_defer_reason == "parse_error" + + +def test_wake_deferred_segments_clears_defer_metadata( + sqlite_session_factory, +) -> None: + """新素材到达时清空同类目下既有 defer 元数据,并保留另一类目不变。""" + db = sqlite_session_factory() + user_id = "u-wake" + conv_id = str(uuid.uuid4()) + seg_a = _seed_user_segment( + db, + user_id=user_id, + conversation_id=conv_id, + segment_id="seg-wake-1", + topic_category="childhood", + ) + seg_other = _seed_user_segment( + db, + user_id=user_id, + conversation_id=conv_id, + segment_id="seg-other", + topic_category="education", + ) + seg_a.narrative_defer_count = 2 + seg_a.narrative_defer_reason = "parse_error" + seg_a.narrative_deferred_until = datetime.now(timezone.utc) + timedelta(minutes=5) + seg_other.narrative_defer_count = 1 + seg_other.narrative_defer_reason = "no_llm" + seg_other.narrative_deferred_until = datetime.now(timezone.utc) + timedelta( + minutes=5 + ) + db.commit() + + woke = _wake_deferred_segments_for_category(db, user_id, "childhood") + db.commit() + + refreshed_a = db.execute( + select(Segment).where(Segment.id == seg_a.id) + ).scalar_one() + refreshed_other = db.execute( + select(Segment).where(Segment.id == seg_other.id) + ).scalar_one() + + assert woke == 1 + assert refreshed_a.narrative_deferred_until is None + assert refreshed_a.narrative_defer_count == 0 + assert refreshed_a.narrative_defer_reason is None + # 其它类目不应被波及 + assert refreshed_other.narrative_deferred_until is not None + assert refreshed_other.narrative_defer_count == 1 + assert refreshed_other.narrative_defer_reason == "no_llm" diff --git a/api/tests/test_memoir_two_phase.py b/api/tests/test_memoir_two_phase.py index 702c1fe..e1a6ef5 100644 --- a/api/tests/test_memoir_two_phase.py +++ b/api/tests/test_memoir_two_phase.py @@ -16,7 +16,7 @@ def test_segment_chapter_category_populated() -> None: orch = MemoirOrchestrator() orch.extraction_agent.extract = MagicMock( return_value=ExtractionResult( - detected_stage="childhood", slots={"toy": "布娃娃"} + detected_stage="childhood", slots={"daily_life": "玩布娃娃"} ) ) orch.classification_agent.classify = MagicMock( diff --git a/api/tests/test_profile_agent_gateway.py b/api/tests/test_profile_agent_gateway.py new file mode 100644 index 0000000..5e7fd71 --- /dev/null +++ b/api/tests/test_profile_agent_gateway.py @@ -0,0 +1,85 @@ +"""ProfileAgent LLM gateway injection regression tests.""" + +from __future__ import annotations + +import json +from types import SimpleNamespace + +import pytest + +from app.agents.chat.profile_agent import ProfileAgent + + +class _Response: + def __init__(self, content: str) -> None: + self.content = content + + +class _BoundJsonLlm: + async def ainvoke(self, _prompt: str) -> _Response: + return _Response( + json.dumps( + { + "birth_year": 1988, + "birth_place": "杭州", + "grew_up_place": "杭州", + "occupation": "工程师", + } + ) + ) + + +class _JsonLlm: + def bind(self, **_kwargs) -> _BoundJsonLlm: # noqa: ANN003 + return _BoundJsonLlm() + + +class _Provider: + langchain_llm = _JsonLlm() + + def __init__(self) -> None: + self.messages: list[dict] = [] + + async def complete(self, messages: list[dict], **_kwargs) -> str: # noqa: ANN003 + self.messages = messages + return "谢谢分享!还能再说说吗?" + + async def stream(self, *_args, **_kwargs): # noqa: ANN003 + if False: + yield "" + + +@pytest.mark.asyncio +async def test_profile_agent_llm_provider_injection_covers_chat_and_json( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def fake_history(*_args, **_kwargs): + return SimpleNamespace(window=[], turn_total=0) + + monkeypatch.setattr( + "app.agents.chat.profile_agent.get_history_with_window", + fake_history, + ) + provider = _Provider() + agent = ProfileAgent(llm_provider=provider) + + extracted = await agent.extract_profile_from_message( + "我是一名工程师,1988 年出生在杭州。", + ["birth_year", "birth_place", "occupation"], + ) + followup = await agent.generate_profile_followup( + conversation_id="c1", + user_message="我在杭州长大。", + missing_fields=["grew_up_place"], + filled_fields={"birth_year": "1988"}, + ) + + assert extracted == { + "birth_year": 1988, + "birth_place": "杭州", + "grew_up_place": "杭州", + "occupation": "工程师", + } + assert followup + assert provider.messages + assert provider.messages[0]["role"] == "system" diff --git a/api/tests/test_stage_validation.py b/api/tests/test_stage_validation.py index 39c471a..11bd76c 100644 --- a/api/tests/test_stage_validation.py +++ b/api/tests/test_stage_validation.py @@ -8,6 +8,7 @@ from app.agents.memoir.extraction_agent import ExtractionAgent from app.agents.memoir.schemas import StateExtractionOutput from app.agents.stage_constants import ( chat_bucket, + filter_stage_slots, normalize_chapter_category, normalize_chat_stage, ) @@ -41,6 +42,13 @@ def test_chat_bucket() -> None: assert chat_bucket("beliefs") == "belief" +def test_filter_stage_slots_uses_canonical_keys() -> None: + assert filter_stage_slots( + "childhood", + {"place": "潍坊", "toy": "ball"}, + ) == {"place": "潍坊"} + + def test_extraction_agent_normalizes_detected_stage( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/api/tests/test_state_service_batch_stage_policy.py b/api/tests/test_state_service_batch_stage_policy.py index 704c8a8..eb6f4ef 100644 --- a/api/tests/test_state_service_batch_stage_policy.py +++ b/api/tests/test_state_service_batch_stage_policy.py @@ -155,7 +155,7 @@ def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( update_slot_sync( uid, "career_achievement", - "peak", + "growth", "won prize", ["s2"], db, @@ -165,7 +165,32 @@ def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( select(MemoirStateModel).where(MemoirStateModel.user_id == uid) ).scalar_one() assert st.current_stage == "career" - assert st.slots.get("career", {}).get("peak") is not None + assert st.slots.get("career", {}).get("growth") is not None + + +def test_update_slot_sync_ignores_invalid_slot_name( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + uid = "u-invalid-slot" + db = sqlite_session_factory() + _add_user_and_state(db, user_id=uid, current_stage="childhood") + + update_slot_sync( + uid, + "childhood", + "made_up_key", + "bad", + ["s-bad"], + db, + memoir_batch=True, + ) + st = db.execute( + select(MemoirStateModel).where(MemoirStateModel.user_id == uid) + ).scalar_one() + assert "made_up_key" not in (st.slots.get("childhood") or {}) + assert st.current_stage == "childhood" def test_update_slot_sync_batch_flag_true_cross_bucket_unchanged( diff --git a/api/tests/test_story_route_prompts_and_behavior.py b/api/tests/test_story_route_prompts_and_behavior.py index 8b7616e..e5944d1 100644 --- a/api/tests/test_story_route_prompts_and_behavior.py +++ b/api/tests/test_story_route_prompts_and_behavior.py @@ -134,7 +134,7 @@ def test_decide_career_mock_llm_new_story_and_prompt_episodic(): assert "经历叙事" in captured["prompt"] -def test_decide_invalid_target_falls_back_to_default_append(): +def test_decide_invalid_target_falls_back_to_new_story(): def fake_llm_json(_llm, _prompt: str, _schema: object, **_kwargs): return StoryRouteDecision( decision="append_story", @@ -163,12 +163,12 @@ def test_decide_invalid_target_falls_back_to_default_append(): valid_story_ids={"good"}, story_meta={"good": {"char_count": 2, "version_count": 1}}, ) - assert d.decision == "append_story" - assert d.target_story_id == "good" - assert d.reason == "invalid_target_default_append" + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "invalid_target" -def test_decide_no_llm_defaults_append_when_candidates_exist(): +def test_decide_no_llm_defaults_new_story_when_candidates_exist(): cand = SimpleNamespace( id="s-default", title="求学", @@ -186,8 +186,39 @@ def test_decide_no_llm_defaults_append_when_candidates_exist(): valid_story_ids={"s-default"}, story_meta={"s-default": {"char_count": 4, "version_count": 1}}, ) - assert d.decision == "append_story" - assert d.target_story_id == "s-default" + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "no_llm" + + +def test_decide_parse_error_fallback_defaults_new_story(): + def fake_llm_json(_llm, _prompt: str, _schema: object, **kwargs): + return kwargs["fallback_factory"]() + + cand = SimpleNamespace( + id="s-default", + title="求学", + summary="y" * 40, + canonical_markdown="本科经历", + updated_at=datetime(2025, 2, 1, tzinfo=timezone.utc), + chapter_links=[], + ) + with patch( + "app.agents.memoir.story_route_agent.llm_json_call", + side_effect=fake_llm_json, + ): + d = StoryRouteAgent().decide( + chapter_category="education", + chapter_title="教育", + batch_transcript="后来又考研。", + candidate_stories=[cand], + llm=MagicMock(), + valid_story_ids={"s-default"}, + story_meta={"s-default": {"char_count": 4, "version_count": 1}}, + ) + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "parse_error" def test_plan_batch_merges_consecutive_new_story_units():