feat(memoir): 回忆录分段两阶段管线(Phase1 分类 / Phase2 叙事)与配置、测试
This commit is contained in:
@@ -14,8 +14,11 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.agents.memoir.prompts import get_chapter_classification_json_prompt
|
||||
from app.agents.stage_constants import CHAPTER_CATEGORIES
|
||||
from app.agents.stage_constants import STAGE_TO_DEFAULT_CATEGORY
|
||||
from app.agents.stage_constants import (
|
||||
CHAPTER_CATEGORIES,
|
||||
STAGE_KEYWORD_WEIGHTS,
|
||||
STAGE_TO_DEFAULT_CATEGORY,
|
||||
)
|
||||
from app.core.json_utils import extract_json_payload
|
||||
from app.core.langchain_llm import invoke_json_object
|
||||
from app.core.logging import get_logger
|
||||
@@ -40,21 +43,12 @@ _SHORT_HUKOU_STYLE = re.compile(
|
||||
re.UNICODE,
|
||||
)
|
||||
|
||||
# 5-stage 关键词(用于 LLM 失败时的兜底);注意勿含易与「仅年份句」共现的泛词,以免误推类别
|
||||
STAGE_KEYWORDS = {
|
||||
"childhood": ["童年", "小时候", "家乡", "小镇"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
||||
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
|
||||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
|
||||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
|
||||
}
|
||||
|
||||
|
||||
def _detect_stage(text: str, fallback_stage: str) -> str:
|
||||
"""根据关键词检测消息所属的 5-stage 阶段"""
|
||||
"""根据关键词检测消息所属的 5-stage 阶段(与 stage_constants.STAGE_KEYWORD_WEIGHTS 同源;匹配方式为子串,非加权)。"""
|
||||
message = (text or "").lower()
|
||||
for stage, keywords in STAGE_KEYWORDS.items():
|
||||
if any(word in message for word in keywords):
|
||||
for stage, pairs in STAGE_KEYWORD_WEIGHTS.items():
|
||||
if any(word in message for word, _w in pairs):
|
||||
return stage
|
||||
return fallback_stage
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.agents.memoir.prompts import get_state_extraction_prompt
|
||||
from app.agents.stage_constants import normalize_chat_stage
|
||||
from app.core.langchain_llm import invoke_json_object
|
||||
from app.core.logging import get_logger
|
||||
from app.core.json_utils import extract_json_payload
|
||||
@@ -63,7 +64,11 @@ class ExtractionAgent:
|
||||
agent="ExtractionAgent.extract",
|
||||
)
|
||||
parsed = json.loads(extract_json_payload(raw))
|
||||
detected_stage = parsed.get("detected_stage", detected_stage)
|
||||
raw_detected = parsed.get("detected_stage", detected_stage)
|
||||
detected_stage = normalize_chat_stage(
|
||||
str(raw_detected) if raw_detected is not None else None,
|
||||
fallback=current_stage,
|
||||
)
|
||||
raw_slots = parsed.get("slots", {}) or {}
|
||||
extracted_slots = {
|
||||
k: v if isinstance(v, str) else str(v) for k, v in raw_slots.items()
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.agents.stage_constants import CHAPTER_CATEGORIES
|
||||
from app.agents.memoir.prompts import (
|
||||
get_creative_title_json_prompt,
|
||||
get_narrative_json_prompt,
|
||||
@@ -34,7 +35,7 @@ class NarrativeAgent:
|
||||
) -> str:
|
||||
"""生成创意标题。若无 LLM 则返回默认标题"""
|
||||
if not llm:
|
||||
return f"{stage} 回忆"
|
||||
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
|
||||
try:
|
||||
prompt = get_creative_title_json_prompt(
|
||||
stage=stage,
|
||||
@@ -53,10 +54,10 @@ class NarrativeAgent:
|
||||
title = (data.get("title") or "").strip() if isinstance(data, dict) else ""
|
||||
if title:
|
||||
return title.strip('"')
|
||||
return f"{stage} 回忆"
|
||||
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
|
||||
except Exception as e:
|
||||
logger.warning("NarrativeAgent 生成标题失败: {}", e)
|
||||
return f"{stage} 回忆"
|
||||
return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆"
|
||||
|
||||
def generate_narrative(
|
||||
self,
|
||||
|
||||
@@ -33,6 +33,8 @@ class PreparedMemoirBatches:
|
||||
category_to_segments: Dict[str, List[Segment]]
|
||||
#: segment id 在「LLM 判 none 且 extraction slots 为空」时加入;batch 级短路见 memoir_tasks
|
||||
segment_skip_story_ids: Set[str]
|
||||
#: 每个 segment → Phase 1 分类 chapter_category(持久化到 Segment.topic_category)
|
||||
segment_chapter_category: Dict[str, str]
|
||||
|
||||
|
||||
class MemoirOrchestrator:
|
||||
@@ -64,6 +66,7 @@ class MemoirOrchestrator:
|
||||
state = get_or_create_state()
|
||||
category_to_segments: Dict[str, List[Segment]] = {}
|
||||
segment_skip_story_ids: Set[str] = set()
|
||||
segment_chapter_category: Dict[str, str] = {}
|
||||
classify_extract_llm = llm_fast if llm_fast is not None else llm
|
||||
|
||||
for segment in segments:
|
||||
@@ -103,6 +106,7 @@ class MemoirOrchestrator:
|
||||
chapter_category = classify_result.category
|
||||
if (not result.slots) and classify_result.llm_said_none:
|
||||
segment_skip_story_ids.add(str(segment.id))
|
||||
segment_chapter_category[str(segment.id)] = chapter_category
|
||||
|
||||
if agent_summary_enabled():
|
||||
logger.info(
|
||||
@@ -126,6 +130,7 @@ class MemoirOrchestrator:
|
||||
state=state,
|
||||
category_to_segments=category_to_segments,
|
||||
segment_skip_story_ids=segment_skip_story_ids,
|
||||
segment_chapter_category=segment_chapter_category,
|
||||
)
|
||||
|
||||
def run(
|
||||
|
||||
Reference in New Issue
Block a user