feat: 生成回忆录agent结构封装
This commit is contained in:
25
api/app/agents/memoir/__init__.py
Normal file
25
api/app/agents/memoir/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""回忆录模块:MemoryAgent、BackgroundTaskRunner、MemoirOrchestrator、各 Specialist Agent"""
|
||||
from app.agents.memoir.memory_agent import MemoryAgent
|
||||
from app.agents.memoir.processor import (
|
||||
BackgroundTaskRunner,
|
||||
ContentAnalyzer,
|
||||
MemoirGenerator,
|
||||
)
|
||||
from app.agents.memoir.orchestrator import MemoirOrchestrator
|
||||
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
|
||||
from app.agents.memoir.classification_agent import ClassificationAgent
|
||||
from app.agents.memoir.narrative_agent import NarrativeAgent
|
||||
from app.agents.memoir.placeholder_agent import inject_placeholders
|
||||
|
||||
# Names re-exported as the public API of this package.
__all__ = [
    # Async rewriting agent
    "MemoryAgent",
    # Background processing helpers
    "BackgroundTaskRunner",
    "ContentAnalyzer",
    "MemoirGenerator",
    # Pipeline orchestrator and its specialist agents
    "MemoirOrchestrator",
    "ExtractionAgent",
    "ExtractionResult",
    "ClassificationAgent",
    "NarrativeAgent",
    "inject_placeholders",
]
|
||||
77
api/app/agents/memoir/classification_agent.py
Normal file
77
api/app/agents/memoir/classification_agent.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
ClassificationAgent:将内容分类到 8 个章节类别,或判定无价值返回 None。
|
||||
对应现有逻辑:_classify_chapter_category
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.prompts.memory_prompts import (
|
||||
CHAPTER_CATEGORIES,
|
||||
get_chapter_classification_prompt,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Keywords for each of the 5 stages (used as a fallback when the LLM fails).
# Dict order matters: the first stage with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": [
        "童年",
        "小时候",
        "出生",
        "家乡",
        "小镇",
    ],
    "education": [
        "上学",
        "学校",
        "老师",
        "同学",
        "教育",
        "大学",
    ],
    "career": [
        "工作",
        "职业",
        "事业",
        "公司",
        "同事",
        "创业",
    ],
    "family": [
        "伴侣",
        "孩子",
        "家庭",
        "家人",
        "结婚",
        "父母",
    ],
    "belief": [
        "信念",
        "价值观",
        "座右铭",
        "坚持",
        "原则",
    ],
}

# 5-stage -> default 8-category mapping (fallback when LLM classification fails).
_STAGE_TO_DEFAULT_CATEGORY = {
    "childhood": "childhood",
    "education": "education",
    "career": "career_early",
    "family": "family",
    "belief": "beliefs",
}
|
||||
|
||||
|
||||
def _detect_stage(text: str, fallback_stage: str) -> str:
    """Detect which of the 5 stages a message belongs to by keyword match.

    Stages are checked in ``STAGE_KEYWORDS`` order and the first stage with
    any keyword contained in the lower-cased text wins. A falsy ``text`` is
    treated as empty. Returns ``fallback_stage`` when nothing matches.
    """
    haystack = (text or "").lower()
    matching_stages = (
        stage
        for stage, keywords in STAGE_KEYWORDS.items()
        if any(keyword in haystack for keyword in keywords)
    )
    return next(matching_stages, fallback_stage)
|
||||
|
||||
|
||||
class ClassificationAgent:
    """Classify content into one of the 8 chapter categories, or None when it has no value."""

    def classify(
        self,
        text: str,
        fallback_stage: str,
        llm: Any,
    ) -> Optional[str]:
        """Return the chapter category for ``text``.

        The LLM is asked first; it must support a synchronous
        ``.invoke(prompt)`` call. When it answers ``none`` the content is
        judged to have no memoir value and None is returned. When there is no
        LLM, the call fails, or the answer is not a known category, the
        category is derived from the keyword-detected stage instead.
        """
        if llm:
            try:
                answer = llm.invoke(get_chapter_classification_prompt(text))
                label = (answer.content or "").strip().lower()
                if label == "none":
                    logger.info("LLM 判定内容无回忆录价值,跳过: %s...", (text or "")[:80])
                    return None
                if label in CHAPTER_CATEGORIES:
                    return label
            except Exception as exc:
                logger.warning("ClassificationAgent LLM 章节分类失败: %s", exc)

        return self._keyword_category(text, fallback_stage)

    @staticmethod
    def _keyword_category(text: str, fallback_stage: str) -> str:
        """Map the keyword-detected stage to its default category."""
        keyword_stage = _detect_stage(text, fallback_stage)
        # If the detected stage is unmapped, fall back to the caller's stage,
        # and finally to "childhood".
        category_for_fallback = _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood")
        return _STAGE_TO_DEFAULT_CATEGORY.get(keyword_stage, category_for_fallback)
|
||||
66
api/app/agents/memoir/extraction_agent.py
Normal file
66
api/app/agents/memoir/extraction_agent.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
ExtractionAgent:从用户消息中提取 5-stage 状态与 slots。
|
||||
对应现有逻辑:get_state_extraction_prompt + JSON 解析
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.features.memoir.memoir_images.json_payload import extract_json_payload
|
||||
|
||||
from app.agents.prompts.memory_prompts import get_state_extraction_prompt
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """Result of one state-extraction pass over a user message."""

    # Stage reported for the message.
    detected_stage: str
    # Slot name -> extracted text snippet.
    slots: Dict[str, str]
|
||||
|
||||
|
||||
class ExtractionAgent:
    """Extract ``detected_stage`` and slot snippets from a user message."""

    def extract(
        self,
        user_message: str,
        current_stage: str,
        stage_slots: Dict[str, Any],
        llm: Any,
    ) -> ExtractionResult:
        """Extract structured information and decide the stage.

        Args:
            user_message: Raw text of the user's message.
            current_stage: Stage to report when the LLM gives no usable one.
            stage_slots: Existing slots for the stage. Values exposing
                ``model_dump()`` are dumped before being put in the prompt.
            llm: Object supporting a synchronous ``.invoke(prompt)`` call
                (used inside Celery tasks). A falsy value disables extraction.

        Returns:
            ExtractionResult. On any LLM or parsing failure the result holds
            ``current_stage`` and no slots. The failure is logged, not raised.
        """
        if not llm:
            return ExtractionResult(detected_stage=current_stage, slots={})

        try:
            prompt = get_state_extraction_prompt(
                user_message=user_message,
                current_stage=current_stage,
                stage_slots={
                    name: slot.model_dump() if hasattr(slot, "model_dump") else slot
                    for name, slot in (stage_slots or {}).items()
                },
            )
            response = llm.invoke(prompt)
            parsed = json.loads(extract_json_payload(response.content))
        except Exception as e:  # best-effort: never let extraction break the pipeline
            logger.warning("ExtractionAgent LLM 解析失败: %s", e)
            return ExtractionResult(detected_stage=current_stage, slots={})

        if not isinstance(parsed, dict):
            logger.warning(
                "ExtractionAgent LLM 解析失败: %s",
                f"expected a JSON object, got {type(parsed).__name__}",
            )
            return ExtractionResult(detected_stage=current_stage, slots={})

        # A JSON null / blank / non-string stage must not replace the current
        # stage: callers use this value as a slot key and as a fallback stage.
        stage = parsed.get("detected_stage")
        detected_stage = (
            stage if isinstance(stage, str) and stage.strip() else current_stage
        )

        raw_slots = parsed.get("slots")
        if not isinstance(raw_slots, dict):
            raw_slots = {}
        # Skip null values so they are not stored as the literal string "None".
        extracted_slots: Dict[str, str] = {
            name: value if isinstance(value, str) else str(value)
            for name, value in raw_slots.items()
            if value is not None
        }
        return ExtractionResult(detected_stage=detected_stage, slots=extracted_slots)
|
||||
130
api/app/agents/memoir/memory_agent.py
Normal file
130
api/app/agents/memoir/memory_agent.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
||||
支持异步调用
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.prompts import (
|
||||
get_chapter_classification_prompt,
|
||||
get_text_rewrite_prompt,
|
||||
inject_image_placeholder_template,
|
||||
CHAPTER_CATEGORIES,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or None.

    Best-effort: any failure while resolving the provider yields None so
    callers can fall back to their non-LLM behaviour. The failure is logged
    instead of being swallowed silently.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        logger.warning("获取 LLM provider 失败,将不使用 LLM: %s", e)
        return None
|
||||
|
||||
|
||||
class MemoryAgent:
    """Memoir-organising agent with an async API.

    Classifies transcript text into chapter categories and rewrites it into
    written prose, using the LangChain LLM resolved at construction time.
    Every method degrades to a plain fallback when no LLM is available.
    """

    def __init__(self):
        self.llm = _get_langchain_llm()

    @staticmethod
    def _response_text(response):
        """Return ``response.content`` when present, otherwise ``str(response)``."""
        if hasattr(response, "content"):
            return response.content
        return str(response)

    @staticmethod
    def _plain_chapter(chapter_category: str, content: str) -> Dict:
        """Build the fallback chapter dict: default title, given content, no extras."""
        return {
            "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
            "content": content,
            "summary": "",
            "image_suggestions": [],
        }

    async def classify_chapter(self, segments_text: str) -> str:
        """Return the chapter category for the text, defaulting to "childhood".

        The default is used when there is no LLM, the call fails, or the
        answer is not a key of ``CHAPTER_CATEGORIES``.
        """
        if self.llm:
            try:
                reply = await self.llm.ainvoke(
                    get_chapter_classification_prompt(segments_text)
                )
                label = self._response_text(reply).strip().lower()
                if label in CHAPTER_CATEGORIES:
                    return label
            except Exception as exc:
                logger.error("分类章节失败: %s", exc)
        return "childhood"

    async def rewrite_to_literary(
        self,
        segments_text: str,
        chapter_category: str,
        existing_content: Optional[str] = None,
    ) -> Dict:
        """Rewrite spoken text into written prose for one chapter.

        Returns the dict parsed from the LLM's JSON answer, with its
        ``content`` passed through ``inject_image_placeholder_template``.
        If the answer is not valid JSON, the raw answer becomes the content of
        a fallback chapter dict. With no LLM, or on any other error, the
        original ``segments_text`` is returned as the content.
        """
        if not self.llm:
            return self._plain_chapter(chapter_category, segments_text)

        try:
            reply = await self.llm.ainvoke(
                get_text_rewrite_prompt(
                    segments_text, chapter_category, existing_content or ""
                )
            )
            raw_text = self._response_text(reply)

            # Strip an optional Markdown code fence around the JSON payload.
            payload = raw_text.strip()
            for opening in ("```json", "```"):
                if payload.startswith(opening):
                    payload = payload[len(opening):]
            if payload.endswith("```"):
                payload = payload[:-3]

            rewritten = json.loads(payload.strip())
            rewritten["content"] = inject_image_placeholder_template(
                rewritten.get("content") or ""
            )
            return rewritten
        except json.JSONDecodeError:
            return self._plain_chapter(
                chapter_category, inject_image_placeholder_template(raw_text)
            )
        except Exception as exc:
            logger.error("改写文本失败: %s", exc)
            return self._plain_chapter(chapter_category, segments_text)

    async def process_segments(
        self,
        segments: List[Dict],
        existing_chapters: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, Dict]:
        """Classify segments, rewrite each category, and merge into the chapters.

        Segments with no ``transcript_text`` are ignored. Returns a shallow
        copy of ``existing_chapters`` in which every category that received
        text is replaced by a freshly built chapter dict.
        """
        chapters_before = {} if existing_chapters is None else existing_chapters

        # Group transcript texts by classified category, keeping input order.
        grouped: Dict[str, List[str]] = {}
        for segment in segments:
            transcript = segment.get("transcript_text", "")
            if transcript:
                label = await self.classify_chapter(transcript)
                grouped.setdefault(label, []).append(transcript)

        merged = chapters_before.copy()
        for label, transcripts in grouped.items():
            previous_content = chapters_before.get(label, {}).get("content", "")
            rewritten = await self.rewrite_to_literary(
                "\n\n".join(transcripts), label, previous_content
            )
            merged[label] = {
                "title": rewritten.get("title", CHAPTER_CATEGORIES.get(label, "章节")),
                "content": rewritten.get("content", ""),
                "summary": rewritten.get("summary", ""),
                "image_suggestions": rewritten.get("image_suggestions", []),
                "category": label,
                "order_index": STAGE_TO_ORDER.get(label, 999),
            }
        return merged
|
||||
78
api/app/agents/memoir/narrative_agent.py
Normal file
78
api/app/agents/memoir/narrative_agent.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
NarrativeAgent:生成创意标题和叙事改写。
|
||||
对应现有逻辑:get_creative_title_prompt、get_narrative_prompt
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class NarrativeAgent:
    """Generate chapter titles and narrative text."""

    def generate_title(
        self,
        stage: str,
        emotion: str,
        slots: Dict[str, str],
        user_profile: str = "",
        birth_year: Optional[int] = None,
        llm: Any = None,
    ) -> str:
        """Generate a creative chapter title.

        ``llm`` must support a synchronous ``.invoke(prompt)`` call. The
        default title ``"<stage> 回忆"`` is returned when there is no LLM,
        the call fails, or the LLM answers with nothing but whitespace or
        quotes, so a chapter never ends up with an empty title.
        """
        default_title = f"{stage} 回忆"
        if not llm:
            return default_title
        try:
            prompt = get_creative_title_prompt(
                stage=stage,
                emotion=emotion,
                slots=slots,
                user_profile=user_profile,
                birth_year=birth_year,
            )
            response = llm.invoke(prompt)
            title = (response.content or "").strip().strip('"').strip()
        except Exception as e:
            logger.warning("NarrativeAgent 生成标题失败: %s", e)
            return default_title
        return title or default_title

    def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str = "",
        user_profile: str = "",
        birth_year: Optional[int] = None,
        llm: Any = None,
    ) -> str:
        """Rewrite new conversation text as narrative.

        When there is no LLM, the call fails, or the LLM returns empty text,
        the result is a plain concatenation: ``existing_content`` and
        ``new_content`` joined by a blank line, or just ``new_content`` when
        nothing exists yet. An empty LLM answer therefore never discards
        ``new_content``.

        NOTE(review): the fallback already contains ``existing_content``, so
        a caller that prepends ``existing_content`` to the returned text again
        would duplicate it. Confirm against callers.
        """
        fallback = (
            f"{existing_content}\n\n{new_content}" if existing_content else new_content
        )
        if not llm:
            return fallback
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
                user_profile=user_profile,
                birth_year=birth_year,
            )
            response = llm.invoke(prompt)
            narrative = (response.content or "").strip()
        except Exception as e:
            logger.warning("NarrativeAgent 生成叙事失败: %s", e)
            return fallback
        return narrative or fallback
|
||||
124
api/app/agents/memoir/orchestrator.py
Normal file
124
api/app/agents/memoir/orchestrator.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
MemoirOrchestrator:按 segment 编排流水线,调用各 Specialist Agent。
|
||||
负责:遍历 segments、按 category 聚合、调用 Specialist、更新 state;
|
||||
持久化与章节生成由 process_category 回调完成。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.features.conversation.models import Segment
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
|
||||
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
|
||||
from app.agents.memoir.classification_agent import (
|
||||
ClassificationAgent,
|
||||
_detect_stage as detect_stage_from_keywords,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoirOrchestrator:
    """Orchestrate memoir generation.

    Walks the segments (ExtractionAgent, then ClassificationAgent), groups
    them by chapter category, then calls ``process_category`` once per
    category while holding that category's lock. This class only calls the
    callbacks it is given and performs no persistence itself.
    """

    def __init__(self) -> None:
        self.extraction_agent = ExtractionAgent()
        self.classification_agent = ClassificationAgent()

    def run(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        user_profile: str = "",
        user_birth_year: Any = None,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[
            [str, str, str, List[str]], MemoirStateSchema
        ],
        acquire_lock: Callable[[str], bool],
        release_lock: Callable[[str], None],
        process_category: Callable[
            [
                str,
                List[Segment],
                MemoirStateSchema,
                str,
                Any,
                Any,
            ],
            Tuple[Any, bool],
        ],
        raise_retry: Callable[[], None],
    ) -> Tuple[Set[str], int]:
        """Run the memoir pipeline.

        Args:
            segments: Segments to process. ``transcript_text`` and ``id``
                are read from each one.
            llm: Passed through to the agents and to ``process_category``.
            user_profile: Passed through to ``process_category``.
            user_birth_year: Passed through to ``process_category``.
            get_or_create_state: Returns the current memoir state.
            update_slot: Called as ``(stage, slot_name, snippet, [segment_id])``
                and returns the updated state.
            acquire_lock: Returns True when the lock for a category was taken.
            release_lock: Releases the lock for a category.
            process_category: Called as ``(category, segments, state,
                user_profile, user_birth_year, llm)`` and returns
                ``(chapter, has_images_to_generate)``.
            raise_retry: Called on lock contention. It is expected to raise
                (e.g. a Celery retry). If it returns, the contended category
                is skipped instead of being processed without the lock.

        Returns:
            ``(chapters_to_enqueue, processed_count)``: the ids of chapters
            that have images to generate, and ``len(segments)``.
        """
        state = get_or_create_state()
        chapters_to_enqueue: Set[str] = set()
        category_to_segments: Dict[str, List[Segment]] = {}

        # 1) Per segment: extract and store slots, then classify and group.
        for segment in segments:
            text = segment.transcript_text or ""
            # Keyword pre-detection picks which stage's slots go in the prompt.
            initial_stage = detect_stage_from_keywords(
                text, state.current_stage or "childhood"
            )
            stage_slots_raw = state.slots.get(initial_stage, {}) or {}

            result: ExtractionResult = self.extraction_agent.extract(
                user_message=text,
                current_stage=state.current_stage or "childhood",
                stage_slots=stage_slots_raw,
                llm=llm,
            )
            detected_stage = result.detected_stage
            for slot_name, snippet in result.slots.items():
                state = update_slot(detected_stage, slot_name, snippet, [segment.id])

            chapter_category = self.classification_agent.classify(
                text=text,
                fallback_stage=detected_stage,
                llm=llm,
            )
            if chapter_category is None:
                logger.info("段落无回忆录价值,跳过: segment_id=%s", segment.id)
                continue
            category_to_segments.setdefault(chapter_category, []).append(segment)

        # 2) Per category: delegate to process_category under the lock.
        for chapter_category, category_segments in category_to_segments.items():
            if not acquire_lock(chapter_category):
                logger.warning(
                    "章节锁竞争: category=%s, 延迟重试",
                    chapter_category,
                )
                raise_retry()
                # Only reached when raise_retry returned instead of raising.
                # Never process, or release, a lock this worker does not hold.
                continue

            try:
                chapter, has_images = process_category(
                    chapter_category,
                    category_segments,
                    state,
                    user_profile,
                    user_birth_year,
                    llm,
                )
                if chapter and has_images:
                    chapters_to_enqueue.add(chapter.id)
            finally:
                release_lock(chapter_category)

        return chapters_to_enqueue, len(segments)
|
||||
14
api/app/agents/memoir/placeholder_agent.py
Normal file
14
api/app/agents/memoir/placeholder_agent.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
PlaceholderInjectAgent:对 narrative 做占位符模板注入。
|
||||
对应现有逻辑:inject_image_placeholder_template
|
||||
纯函数式,无 LLM 调用。
|
||||
"""
|
||||
from app.agents.prompts.memory_prompts import inject_image_placeholder_template
|
||||
|
||||
|
||||
def inject_placeholders(content: str) -> str:
    """Apply placeholder processing to chapter body text.

    Thin delegate to ``inject_image_placeholder_template`` with identical
    behaviour: it matches every image placeholder and attaches the fixed
    template to it.
    """
    processed = inject_image_placeholder_template(content)
    return processed
|
||||
212
api/app/agents/memoir/processor.py
Normal file
212
api/app/agents/memoir/processor.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
回忆录后台处理器:分析对话、更新状态、生成章节、创意标题
|
||||
使用 Celery 进行后台任务处理
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
from app.core.task_tracker import task_tracker
|
||||
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
from app.agents.prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
get_state_extraction_prompt,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Keywords that identify each of the 5 stages in a message.
# Dict order matters: the first stage with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": [
        "童年",
        "小时候",
        "出生",
        "家乡",
        "小镇",
    ],
    "education": [
        "上学",
        "学校",
        "老师",
        "同学",
        "教育",
        "大学",
    ],
    "career": [
        "工作",
        "职业",
        "事业",
        "公司",
        "同事",
        "创业",
    ],
    "family": [
        "伴侣",
        "孩子",
        "家庭",
        "家人",
        "结婚",
        "父母",
    ],
    "belief": [
        "信念",
        "价值观",
        "座右铭",
        "坚持",
        "原则",
    ],
}
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or None.

    Best-effort: any failure while resolving the provider yields None so
    callers can fall back to their non-LLM behaviour. The failure is logged
    instead of being swallowed silently.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        logger.warning("获取 LLM provider 失败,将不使用 LLM: %s", e)
        return None
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """Outcome of analysing one user message."""

    # Stage the message was assigned to.
    detected_stage: str
    # Slot name -> extracted text snippet.
    extracted_slots: Dict[str, str]
    # Emotion label for the message.
    emotion: str
    # Whether the message was flagged as starting a new chapter.
    is_new_chapter: bool
|
||||
|
||||
|
||||
class ContentAnalyzer:
    """Analyse a user message: detect its stage and extract slot snippets."""

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keywords occur in the message.

        A falsy message is treated as empty instead of raising. Returns
        ``fallback_stage`` when no keyword matches.
        """
        message = (user_message or "").lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(
        self, state: MemoirStateSchema, stage: str, user_message: str
    ) -> Dict[str, str]:
        """Put the message into the first slot of ``stage`` that has no snippet.

        Slot values may be objects with a ``snippet`` attribute or plain
        dicts with a ``"snippet"`` key. Returns ``{}`` when the stage has no
        slots or all of them are filled. The stored text is the stripped
        message truncated to 200 characters.
        """
        stage_slots = state.slots.get(stage, {}) or {}
        for key, value in stage_slots.items():
            if isinstance(value, dict):
                snippet = value.get("snippet")
            else:
                snippet = getattr(value, "snippet", None)
            if not snippet:
                return {key: (user_message or "").strip()[:200]}
        return {}

    async def analyze_message(
        self, user_message: str, current_state: MemoirStateSchema
    ) -> AnalysisResult:
        """Analyse one message against the current memoir state.

        The stage is first detected by keywords. When an LLM is available its
        JSON answer may override the stage and supply slots, an emotion and
        the new-chapter flag. JSON ``null`` or missing fields keep the
        defaults. With no LLM, on failure, or when the answer is not a JSON
        object, slots come from ``_fallback_slots``.
        """
        detected_stage = self._detect_stage(
            user_message, current_state.current_stage
        )
        emotion = "neutral"
        is_new_chapter = False

        parsed = None
        if self.llm:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(detected_stage, {}),
                )
                response = await self.llm.ainvoke(prompt)
                parsed = json.loads((response.content or "").strip())
            except json.JSONDecodeError:
                # A non-JSON answer is an expected case: use the fallback below.
                parsed = None
            except Exception as e:
                logger.error("分析消息失败: %s", e)
                parsed = None

        if not isinstance(parsed, dict):
            extracted_slots = self._fallback_slots(
                current_state, detected_stage, user_message
            )
        else:
            # "or" keeps the default when the LLM sends null or an empty value.
            detected_stage = parsed.get("detected_stage") or detected_stage
            slots = parsed.get("slots")
            extracted_slots = slots if isinstance(slots, dict) else {}
            emotion = parsed.get("emotion") or emotion
            is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
|
||||
|
||||
|
||||
class MemoirGenerator:
    """Generate chapter titles and narrative text through the async LLM API."""

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    async def generate_chapter_title(
        self, stage: str, slots: Dict[str, str], emotion: str
    ) -> str:
        """Generate a creative chapter title.

        The default title ``"<stage> 回忆"`` is returned when there is no
        LLM, the call fails, or the LLM answers with nothing but whitespace
        or quotes, so a chapter never ends up with an empty title.
        """
        default_title = f"{stage} 回忆"
        if not self.llm:
            return default_title
        try:
            prompt = get_creative_title_prompt(
                stage=stage, emotion=emotion, slots=slots
            )
            response = await self.llm.ainvoke(prompt)
            title = (response.content or "").strip().strip('"').strip()
        except Exception as e:
            logger.error("生成标题失败: %s", e)
            return default_title
        return title or default_title

    async def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str,
    ) -> str:
        """Rewrite new content as narrative.

        When there is no LLM, the call fails, or the LLM returns empty text,
        the result is a plain concatenation: ``existing_content`` and
        ``new_content`` joined by a blank line, or just ``new_content`` when
        nothing exists yet. An empty LLM answer therefore never discards
        ``new_content``.
        """
        fallback = (
            f"{existing_content}\n\n{new_content}" if existing_content else new_content
        )
        if not self.llm:
            return fallback
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            response = await self.llm.ainvoke(prompt)
            narrative = (response.content or "").strip()
        except Exception as e:
            logger.error("生成叙事失败: %s", e)
            return fallback
        return narrative or fallback
|
||||
|
||||
|
||||
class BackgroundTaskRunner:
    """Debounce per-user segment ids and submit them as one Celery task.

    Ids queued for a user within ``debounce_seconds`` of each other are
    batched: every new id restarts that user's timer, and when the timer
    fires the whole batch is submitted at once.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be submitted
        self._pending: Dict[str, List[str]] = {}
        # user_id -> asyncio task currently waiting out the debounce delay
        self._timers: Dict[str, object] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Submit the Celery task and register it with the task tracker.

        Returns the Celery task id, or None when submission or tracking
        failed. The failure is logged, not raised.
        """
        try:
            # Imported here rather than at module import time.
            from app.tasks.memoir_tasks import process_memoir_segments

            result = process_memoir_segments.delay(user_id, segment_ids)
            task_id = result.id
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: %s", e)
            return None

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue a segment id and (re)start the user's debounce timer."""
        import asyncio

        self._pending.setdefault(user_id, []).append(segment_id)
        if user_id in self._timers:
            self._timers[user_id].cancel()

        async def delayed_submit():
            try:
                await asyncio.sleep(self.debounce_seconds)
                # The delay is over and this task now owns the batch, so
                # unregister it. A later queue_message/flush_pending must not
                # cancel a submit that is already in flight, and finished
                # tasks must not accumulate in _timers. The identity check
                # guards against removing a newer timer for the same user.
                if self._timers.get(user_id) is asyncio.current_task():
                    del self._timers[user_id]
                segment_ids = self._pending.pop(user_id, [])
                if segment_ids:
                    await self._submit_task(user_id, segment_ids)
            except asyncio.CancelledError:
                # Cancelled during the delay: the replacement timer or
                # flush_pending takes over the pending ids.
                pass
            except Exception as e:
                logger.error("延迟提交任务失败: %s", e)

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Cancel the user's timer and submit whatever is pending right now.

        Returns the submitted task id, or None when nothing was pending or
        submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()
        segment_ids = self._pending.pop(user_id, [])
        if segment_ids:
            return await self._submit_task(user_id, segment_ids)
        return None
|
||||
@@ -27,23 +27,20 @@ from app.features.user.models import User
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||||
from app.agents.prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
get_state_extraction_prompt,
|
||||
get_chapter_classification_prompt,
|
||||
inject_image_placeholder_template,
|
||||
STAGE_TO_ORDER,
|
||||
CHAPTER_CATEGORIES,
|
||||
get_narrative_prompt,
|
||||
inject_image_placeholder_template,
|
||||
)
|
||||
from app.agents.memoir import MemoirOrchestrator
|
||||
from app.agents.memoir.narrative_agent import NarrativeAgent
|
||||
from app.agents.memoir.placeholder_agent import inject_placeholders
|
||||
from app.agents.prompts.profile_prompts import format_user_profile_context
|
||||
import hashlib
|
||||
|
||||
from app.features.memoir.memoir_images.parser import (
|
||||
build_initial_image_assets,
|
||||
parse_image_placeholders,
|
||||
split_narrative_to_sections,
|
||||
)
|
||||
from app.features.memoir.memoir_images.json_payload import extract_json_payload
|
||||
import hashlib
|
||||
from app.core.dependencies import get_image_generator
|
||||
from app.features.memoir.memoir_images.prompting import MemoirImagePromptService
|
||||
from app.features.memoir.memoir_images.schema import (
|
||||
@@ -469,56 +466,6 @@ def _normalize_image_bytes_for_storage(image_bytes: bytes) -> bytes:
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
# Keywords that identify each of the 5 stages in a message.
# Dict order matters: the first stage with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": [
        "童年",
        "小时候",
        "出生",
        "家乡",
        "小镇",
    ],
    "education": [
        "上学",
        "学校",
        "老师",
        "同学",
        "教育",
        "大学",
    ],
    "career": [
        "工作",
        "职业",
        "事业",
        "公司",
        "同事",
        "创业",
    ],
    "family": [
        "伴侣",
        "孩子",
        "家庭",
        "家人",
        "结婚",
        "父母",
    ],
    "belief": [
        "信念",
        "价值观",
        "座右铭",
        "坚持",
        "原则",
    ],
}

# 5-stage -> default 8-category mapping (fallback when LLM classification fails).
_STAGE_TO_DEFAULT_CATEGORY = {
    "childhood": "childhood",
    "education": "education",
    "career": "career_early",
    "family": "family",
    "belief": "beliefs",
}
|
||||
|
||||
|
||||
def _detect_stage(user_message: str, fallback_stage: str) -> str:
    """Detect which of the 5 stages a message belongs to (for state tracking).

    Stages are checked in ``STAGE_KEYWORDS`` order and the first stage with
    any keyword contained in the lower-cased message wins. A falsy message
    (e.g. a segment without a transcript) is treated as empty instead of
    raising. Returns ``fallback_stage`` when no keyword matches.
    """
    message = (user_message or "").lower()
    for stage, keywords in STAGE_KEYWORDS.items():
        if any(word in message for word in keywords):
            return stage
    return fallback_stage
|
||||
|
||||
|
||||
def _classify_chapter_category(text: str, fallback_stage: str, llm=None) -> str | None:
    """Classify content into one of the 8 chapter categories.

    The LLM is tried first (synchronous ``.invoke(prompt)``). When it answers
    ``none`` the content is judged to have no memoir value and None is
    returned. When there is no LLM, the call fails, or the answer is not a
    known category, the keyword-detected 5-stage value is mapped to its
    default category. A missing ``text`` or empty response content is treated
    as an empty string rather than raising.
    """
    if llm:
        try:
            prompt = get_chapter_classification_prompt(text)
            response = llm.invoke(prompt)
            category = (response.content or "").strip().lower()
            if category == "none":
                logger.info("LLM 判定内容无回忆录价值,跳过: %s...", (text or "")[:80])
                return None
            if category in CHAPTER_CATEGORIES:
                return category
        except Exception as e:
            logger.warning("LLM 章节分类失败: %s", e)

    stage = _detect_stage(text, fallback_stage)
    # If the detected stage is unmapped, fall back to the caller's stage,
    # and finally to "childhood".
    return _STAGE_TO_DEFAULT_CATEGORY.get(
        stage,
        _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood"),
    )
|
||||
|
||||
|
||||
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
|
||||
"""将数据库模型转换为 Schema"""
|
||||
return MemoirStateSchema.model_validate(
|
||||
@@ -628,174 +575,141 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
grew_up_place=user_obj.grew_up_place,
|
||||
occupation=user_obj.occupation,
|
||||
)
|
||||
|
||||
# 分两步处理:
|
||||
# 1) 5-stage 状态跟踪(slots)
|
||||
# 2) 8-category 章节分类(chapter creation)
|
||||
category_to_segments: Dict[str, List[Segment]] = {}
|
||||
|
||||
for segment in segments:
|
||||
text = segment.transcript_text
|
||||
detected_stage = _detect_stage(text, state.current_stage)
|
||||
narrative_agent = NarrativeAgent()
|
||||
|
||||
# 提取 slots(5-stage 状态跟踪)
|
||||
extracted_slots = {}
|
||||
if llm:
|
||||
try:
|
||||
prompt = get_state_extraction_prompt(
|
||||
user_message=text,
|
||||
current_stage=state.current_stage,
|
||||
stage_slots=state.slots.get(detected_stage, {}),
|
||||
)
|
||||
response = llm.invoke(prompt)
|
||||
parsed = json.loads(extract_json_payload(response.content))
|
||||
detected_stage = parsed.get("detected_stage", detected_stage)
|
||||
extracted_slots = parsed.get("slots", {}) or {}
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.warning(f"LLM 解析失败: {e}")
|
||||
def _process_category(
|
||||
chapter_category: str,
|
||||
category_segments: List,
|
||||
state: MemoirStateSchema,
|
||||
profile: str,
|
||||
birth_year,
|
||||
llm,
|
||||
):
|
||||
"""单章节处理:NarrativeAgent 生成标题+叙事,PlaceholderInjectAgent 注入,持久化"""
|
||||
segment_texts = [seg.transcript_text or "" for seg in category_segments]
|
||||
combined_text = "\n\n".join(segment_texts)
|
||||
source_ids = [seg.id for seg in category_segments]
|
||||
|
||||
for slot_name, snippet in extracted_slots.items():
|
||||
state = _update_slot_sync(
|
||||
stmt_chapter = (
|
||||
select(Chapter)
|
||||
.where(
|
||||
Chapter.user_id == user_id,
|
||||
Chapter.category == chapter_category,
|
||||
Chapter.is_active == True,
|
||||
)
|
||||
.options(
|
||||
joinedload(Chapter.sections).joinedload(ChapterSection.image_record),
|
||||
joinedload(Chapter.images),
|
||||
)
|
||||
)
|
||||
result_chapter = db.execute(stmt_chapter)
|
||||
chapter = result_chapter.unique().scalar_one_or_none()
|
||||
|
||||
slot_snippets = {}
|
||||
stage_slots = state.slots.get(chapter_category, {}) or {}
|
||||
for key, value in stage_slots.items():
|
||||
snip = getattr(value, "snippet", None) or (value.get("snippet") if isinstance(value, dict) else None)
|
||||
if snip:
|
||||
slot_snippets[key] = snip
|
||||
|
||||
title = chapter.title if chapter else f"{chapter_category} 回忆"
|
||||
existing_content = ""
|
||||
if chapter and getattr(chapter, "sections", None):
|
||||
existing_content = "\n\n".join(
|
||||
s.content for s in sorted(chapter.sections, key=lambda x: x.order_index) if (s.content or "").strip()
|
||||
)
|
||||
narrative = combined_text
|
||||
|
||||
if not chapter:
|
||||
title = narrative_agent.generate_title(
|
||||
stage=chapter_category,
|
||||
emotion="neutral",
|
||||
slots=slot_snippets,
|
||||
user_profile=profile,
|
||||
birth_year=birth_year,
|
||||
llm=llm,
|
||||
)
|
||||
new_narrative = narrative_agent.generate_narrative(
|
||||
stage=chapter_category,
|
||||
slots=slot_snippets,
|
||||
new_content=combined_text,
|
||||
existing_content=existing_content,
|
||||
user_profile=profile,
|
||||
birth_year=birth_year,
|
||||
llm=llm,
|
||||
)
|
||||
if existing_content:
|
||||
narrative = f"{existing_content}\n\n{new_narrative}"
|
||||
else:
|
||||
narrative = new_narrative
|
||||
|
||||
if existing_content and len(narrative) < len(existing_content) * 0.8:
|
||||
logger.warning(
|
||||
"内容长度异常: existing=%d, new=%d, category=%s. 回退为追加模式",
|
||||
len(existing_content),
|
||||
len(narrative),
|
||||
chapter_category,
|
||||
)
|
||||
narrative = f"{existing_content}\n\n{combined_text}"
|
||||
|
||||
narrative = inject_placeholders(narrative)
|
||||
calculated_order_index = STAGE_TO_ORDER.get(chapter_category, 999)
|
||||
|
||||
chapter = _save_narrative_to_sections(
|
||||
db,
|
||||
chapter,
|
||||
narrative,
|
||||
title=title,
|
||||
category=chapter_category,
|
||||
order_index=calculated_order_index,
|
||||
source_segments=source_ids,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.flush()
|
||||
db.refresh(chapter)
|
||||
|
||||
has_images = image_settings.enabled and (
|
||||
_chapter_has_any_section_images_to_generate(chapter)
|
||||
or _chapter_has_cover_to_generate(chapter)
|
||||
)
|
||||
|
||||
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
|
||||
result_book = db.execute(stmt_book)
|
||||
book = result_book.scalar_one_or_none()
|
||||
if not book:
|
||||
book = Book(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
stage=detected_stage,
|
||||
slot_name=slot_name,
|
||||
snippet=snippet,
|
||||
segment_ids=[segment.id],
|
||||
db=db,
|
||||
title="我的回忆录",
|
||||
total_pages=0,
|
||||
total_words=0,
|
||||
cover_image_url=None,
|
||||
)
|
||||
db.add(book)
|
||||
book.has_update = True
|
||||
book.last_update_chapter_id = chapter.id
|
||||
|
||||
# 8-category 章节分类
|
||||
chapter_category = _classify_chapter_category(text, detected_stage, llm)
|
||||
if chapter_category is None:
|
||||
logger.info(f"段落无回忆录价值,跳过: segment_id={segment.id}")
|
||||
continue
|
||||
category_to_segments.setdefault(chapter_category, []).append(segment)
|
||||
return chapter, has_images
|
||||
|
||||
# 按 8 分类生成章节内容
|
||||
for chapter_category, category_segments in category_to_segments.items():
|
||||
if not _acquire_chapter_lock(user_id, chapter_category):
|
||||
logger.warning(f"章节锁竞争: user={user_id}, category={chapter_category}, 延迟重试")
|
||||
raise self.retry(countdown=10)
|
||||
try:
|
||||
segment_texts = [seg.transcript_text for seg in category_segments]
|
||||
combined_text = "\n\n".join(segment_texts)
|
||||
source_ids = [seg.id for seg in category_segments]
|
||||
def _raise_retry():
|
||||
raise self.retry(countdown=10)
|
||||
|
||||
# 查找 active 章节(被清除的章节不继续更新,而是创建新的),并预加载 sections、images
|
||||
stmt_chapter = (
|
||||
select(Chapter)
|
||||
.where(
|
||||
Chapter.user_id == user_id,
|
||||
Chapter.category == chapter_category,
|
||||
Chapter.is_active == True,
|
||||
)
|
||||
.options(
|
||||
joinedload(Chapter.sections).joinedload(ChapterSection.image_record),
|
||||
joinedload(Chapter.images),
|
||||
)
|
||||
)
|
||||
result_chapter = db.execute(stmt_chapter)
|
||||
chapter = result_chapter.unique().scalar_one_or_none()
|
||||
|
||||
# 获取 slot snippets
|
||||
slot_snippets = {
|
||||
key: value.snippet
|
||||
for key, value in (state.slots.get(chapter_category, {}) or {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
# 生成标题和内容;已有章节的正文从 sections 拼接
|
||||
title = chapter.title if chapter else f"{chapter_category} 回忆"
|
||||
existing_content = ""
|
||||
if chapter and getattr(chapter, "sections", None):
|
||||
existing_content = "\n\n".join(
|
||||
s.content for s in sorted(chapter.sections, key=lambda x: x.order_index) if (s.content or "").strip()
|
||||
)
|
||||
narrative = combined_text
|
||||
|
||||
if llm:
|
||||
try:
|
||||
if not chapter:
|
||||
title_prompt = get_creative_title_prompt(
|
||||
stage=chapter_category,
|
||||
emotion="neutral",
|
||||
slots=slot_snippets,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
title_response = llm.invoke(title_prompt)
|
||||
title = title_response.content.strip().strip('"')
|
||||
|
||||
narrative_prompt = get_narrative_prompt(
|
||||
stage=chapter_category,
|
||||
slots=slot_snippets,
|
||||
new_content=combined_text,
|
||||
existing_content=existing_content,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
narrative_response = llm.invoke(narrative_prompt)
|
||||
new_narrative = narrative_response.content.strip()
|
||||
|
||||
# 追加而非替换
|
||||
if existing_content:
|
||||
narrative = f"{existing_content}\n\n{new_narrative}"
|
||||
else:
|
||||
narrative = new_narrative
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM 生成失败: {e}")
|
||||
if existing_content:
|
||||
narrative = f"{existing_content}\n\n{combined_text}"
|
||||
|
||||
# 安全检查:新内容不应比旧内容短
|
||||
if existing_content and len(narrative) < len(existing_content) * 0.8:
|
||||
logger.warning(
|
||||
f"内容长度异常: existing={len(existing_content)}, "
|
||||
f"new={len(narrative)}, category={chapter_category}. 回退为追加模式"
|
||||
)
|
||||
narrative = f"{existing_content}\n\n{combined_text}"
|
||||
|
||||
# 入库前:占位符位置用正则匹配后拼上固定模板
|
||||
narrative = inject_image_placeholder_template(narrative)
|
||||
calculated_order_index = STAGE_TO_ORDER.get(chapter_category, 999)
|
||||
|
||||
# 写入 sections(拆段 + 每段配图占位),新建或覆盖该章下所有 sections
|
||||
chapter = _save_narrative_to_sections(
|
||||
db,
|
||||
chapter,
|
||||
narrative,
|
||||
title=title,
|
||||
category=chapter_category,
|
||||
order_index=calculated_order_index,
|
||||
source_segments=source_ids,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.flush()
|
||||
db.refresh(chapter)
|
||||
if image_settings.enabled and (
|
||||
_chapter_has_any_section_images_to_generate(chapter)
|
||||
or _chapter_has_cover_to_generate(chapter)
|
||||
):
|
||||
chapters_to_enqueue.add(chapter.id)
|
||||
|
||||
# 更新 Book
|
||||
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
|
||||
result_book = db.execute(stmt_book)
|
||||
book = result_book.scalar_one_or_none()
|
||||
if not book:
|
||||
book = Book(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title="我的回忆录",
|
||||
total_pages=0,
|
||||
total_words=0,
|
||||
cover_image_url=None,
|
||||
)
|
||||
db.add(book)
|
||||
book.has_update = True
|
||||
book.last_update_chapter_id = chapter.id
|
||||
finally:
|
||||
_release_chapter_lock(user_id, chapter_category)
|
||||
memoir_orchestrator = MemoirOrchestrator()
|
||||
chapters_to_enqueue, _ = memoir_orchestrator.run(
|
||||
segments=segments,
|
||||
llm=llm,
|
||||
user_profile=user_profile,
|
||||
user_birth_year=user_birth_year,
|
||||
get_or_create_state=lambda: _get_or_create_state_sync(user_id, db),
|
||||
update_slot=lambda stage, slot_name, snippet, seg_ids: _update_slot_sync(
|
||||
user_id, stage, slot_name, snippet, seg_ids, db
|
||||
),
|
||||
acquire_lock=lambda stage: _acquire_chapter_lock(user_id, stage),
|
||||
release_lock=lambda stage: _release_chapter_lock(user_id, stage),
|
||||
process_category=_process_category,
|
||||
raise_retry=_raise_retry,
|
||||
)
|
||||
|
||||
# 标记段落为已处理
|
||||
for seg in segments:
|
||||
|
||||
Reference in New Issue
Block a user