feat(memoir): 回忆录分段两阶段管线(Phase1 分类 / Phase2 叙事)与配置、测试

This commit is contained in:
Kevin
2026-04-02 16:37:14 +08:00
parent 3ae39838c0
commit 6b930808a3
27 changed files with 1550 additions and 430 deletions

View File

@@ -14,8 +14,11 @@ from dataclasses import dataclass
from typing import Any
from app.agents.memoir.prompts import get_chapter_classification_json_prompt
from app.agents.stage_constants import CHAPTER_CATEGORIES
from app.agents.stage_constants import STAGE_TO_DEFAULT_CATEGORY
from app.agents.stage_constants import (
CHAPTER_CATEGORIES,
STAGE_KEYWORD_WEIGHTS,
STAGE_TO_DEFAULT_CATEGORY,
)
from app.core.json_utils import extract_json_payload
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
@@ -40,21 +43,12 @@ _SHORT_HUKOU_STYLE = re.compile(
re.UNICODE,
)
# Keywords for the 5 stages, used as a fallback when the LLM fails. Do not add
# generic words that tend to co-occur with "year-only" sentences, otherwise the
# category may be inferred incorrectly.
# NOTE(review): `_detect_stage` also iterates STAGE_KEYWORD_WEIGHTS (imported
# from app.agents.stage_constants); this table looks like the pre-change side of
# the diff — confirm whether it is still needed.
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
def _detect_stage(text: str, fallback_stage: str) -> str:
    """Detect which of the 5 stages a message belongs to by keyword lookup.

    Keywords are read from ``STAGE_KEYWORD_WEIGHTS`` (imported from
    ``app.agents.stage_constants``), which is iterated as a mapping of
    stage -> iterable of ``(keyword, weight)`` pairs. Matching is plain
    substring containment against the lower-cased message; the weights are
    ignored here, and the first stage in mapping order with any hit wins.

    Args:
        text: Message text. ``None`` or an empty string is treated as "".
        fallback_stage: Stage returned when no keyword matches.

    Returns:
        The key of the first stage whose keyword occurs in the message,
        otherwise ``fallback_stage``.
    """
    message = (text or "").lower()
    for stage, pairs in STAGE_KEYWORD_WEIGHTS.items():
        # Only the keyword half of each pair matters: this is a yes/no
        # substring check, not a weighted score.
        if any(word in message for word, _weight in pairs):
            return stage
    return fallback_stage

View File

@@ -10,6 +10,7 @@ from dataclasses import dataclass
from typing import Any, Dict
from app.agents.memoir.prompts import get_state_extraction_prompt
from app.agents.stage_constants import normalize_chat_stage
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
from app.core.json_utils import extract_json_payload
@@ -63,7 +64,11 @@ class ExtractionAgent:
agent="ExtractionAgent.extract",
)
parsed = json.loads(extract_json_payload(raw))
detected_stage = parsed.get("detected_stage", detected_stage)
raw_detected = parsed.get("detected_stage", detected_stage)
detected_stage = normalize_chat_stage(
str(raw_detected) if raw_detected is not None else None,
fallback=current_stage,
)
raw_slots = parsed.get("slots", {}) or {}
extracted_slots = {
k: v if isinstance(v, str) else str(v) for k, v in raw_slots.items()

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import json
from typing import Any, Dict, Optional
from app.agents.stage_constants import CHAPTER_CATEGORIES
from app.agents.memoir.prompts import (
get_creative_title_json_prompt,
get_narrative_json_prompt,
@@ -34,7 +35,7 @@ class NarrativeAgent:
) -> str:
"""生成创意标题。若无 LLM 则返回默认标题"""
if not llm:
return f"{stage} 回忆"
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
try:
prompt = get_creative_title_json_prompt(
stage=stage,
@@ -53,10 +54,10 @@ class NarrativeAgent:
title = (data.get("title") or "").strip() if isinstance(data, dict) else ""
if title:
return title.strip('"')
return f"{stage} 回忆"
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
except Exception as e:
logger.warning("NarrativeAgent 生成标题失败: {}", e)
return f"{stage} 回忆"
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
def generate_narrative(
self,

View File

@@ -33,6 +33,8 @@ class PreparedMemoirBatches:
category_to_segments: Dict[str, List[Segment]]
#: segment id 在「LLM 判 none 且 extraction slots 为空」时加入;batch 级短路见 memoir_tasks
segment_skip_story_ids: Set[str]
#: 每个 segment → Phase 1 分类 chapter_category,持久化到 Segment.topic_category
segment_chapter_category: Dict[str, str]
class MemoirOrchestrator:
@@ -64,6 +66,7 @@ class MemoirOrchestrator:
state = get_or_create_state()
category_to_segments: Dict[str, List[Segment]] = {}
segment_skip_story_ids: Set[str] = set()
segment_chapter_category: Dict[str, str] = {}
classify_extract_llm = llm_fast if llm_fast is not None else llm
for segment in segments:
@@ -103,6 +106,7 @@ class MemoirOrchestrator:
chapter_category = classify_result.category
if (not result.slots) and classify_result.llm_said_none:
segment_skip_story_ids.add(str(segment.id))
segment_chapter_category[str(segment.id)] = chapter_category
if agent_summary_enabled():
logger.info(
@@ -126,6 +130,7 @@ class MemoirOrchestrator:
state=state,
category_to_segments=category_to_segments,
segment_skip_story_ids=segment_skip_story_ids,
segment_chapter_category=segment_chapter_category,
)
def run(