feat: 生成回忆录agent结构封装

This commit is contained in:
yangshilin
2026-03-19 10:38:11 +08:00
parent b16bb2b96c
commit 4a1d6f0dcc
10 changed files with 881 additions and 227 deletions

View File

@@ -0,0 +1,25 @@
"""回忆录模块：MemoryAgent、BackgroundTaskRunner、MemoirOrchestrator、各 Specialist Agent"""
from app.agents.memoir.memory_agent import MemoryAgent
from app.agents.memoir.processor import (
BackgroundTaskRunner,
ContentAnalyzer,
MemoirGenerator,
)
from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.memoir.classification_agent import ClassificationAgent
from app.agents.memoir.narrative_agent import NarrativeAgent
from app.agents.memoir.placeholder_agent import inject_placeholders
# Public API of the memoir package: every name listed here is one of the
# objects imported at the top of this module.
__all__ = [
"MemoryAgent",
"BackgroundTaskRunner",
"ContentAnalyzer",
"MemoirGenerator",
"MemoirOrchestrator",
"ExtractionAgent",
"ExtractionResult",
"ClassificationAgent",
"NarrativeAgent",
"inject_placeholders",
]

View File

@@ -0,0 +1,77 @@
"""
ClassificationAgent：将内容分类到 8 个章节类别,或判定无价值返回 None。
对应现有逻辑：_classify_chapter_category
"""
from __future__ import annotations
from typing import Any, Optional
from app.core.logging import get_logger
from app.agents.prompts.memory_prompts import (
CHAPTER_CATEGORIES,
get_chapter_classification_prompt,
)
logger = get_logger(__name__)
# Keywords for each of the 5 stages (used as a fallback when the LLM fails).
# Insertion order matters: the first stage with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}

# 5-stage -> default 8-category mapping (fallback when LLM classification fails).
_STAGE_TO_DEFAULT_CATEGORY = {
    "childhood": "childhood",
    "education": "education",
    "career": "career_early",
    "family": "family",
    "belief": "beliefs",
}
def _detect_stage(text: str, fallback_stage: str) -> str:
    """Return the first 5-stage phase whose keywords occur in ``text``.

    Stages are tried in ``STAGE_KEYWORDS`` order. A ``None`` or empty
    ``text`` matches nothing, in which case ``fallback_stage`` is returned
    unchanged.
    """
    haystack = (text or "").lower()
    matching_stages = (
        stage
        for stage, keywords in STAGE_KEYWORDS.items()
        if any(keyword in haystack for keyword in keywords)
    )
    return next(matching_stages, fallback_stage)
class ClassificationAgent:
    """Assign content to one of the 8 chapter categories, or ``None`` if worthless."""

    def classify(
        self,
        text: str,
        fallback_stage: str,
        llm: Any,
    ) -> Optional[str]:
        """Classify ``text`` into a chapter category.

        Args:
            text: Content to classify.
            fallback_stage: 5-stage phase used when neither the LLM nor the
                keyword scan yields a usable answer.
            llm: Object exposing a synchronous ``.invoke(prompt)``; a falsy
                value skips the LLM step entirely.

        Returns:
            A key of ``CHAPTER_CATEGORIES`` when the LLM answers with one,
            ``None`` when the LLM answers ``"none"`` (no memoir value), and
            otherwise the default category mapped from the keyword-detected
            stage, then from ``fallback_stage``, then ``"childhood"``.
        """
        if llm:
            try:
                reply = llm.invoke(get_chapter_classification_prompt(text))
                label = (reply.content or "").strip().lower()
                if label == "none":
                    preview = (text or "")[:80]
                    logger.info("LLM 判定内容无回忆录价值,跳过: %s...", preview)
                    return None
                if label in CHAPTER_CATEGORIES:
                    return label
                # Any other answer falls through to the keyword fallback.
            except Exception as e:
                logger.warning("ClassificationAgent LLM 章节分类失败: %s", e)

        keyword_stage = _detect_stage(text, fallback_stage)
        if keyword_stage in _STAGE_TO_DEFAULT_CATEGORY:
            return _STAGE_TO_DEFAULT_CATEGORY[keyword_stage]
        return _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood")

View File

@@ -0,0 +1,66 @@
"""
ExtractionAgent：从用户消息中提取 5-stage 状态与 slots。
对应现有逻辑：get_state_extraction_prompt + JSON 解析
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict
from app.core.logging import get_logger
from app.features.memoir.memoir_images.json_payload import extract_json_payload
from app.agents.prompts.memory_prompts import get_state_extraction_prompt
logger = get_logger(__name__)
@dataclass
class ExtractionResult:
    """Outcome of one state-extraction pass over a user message."""

    # Stage the message was attributed to.
    detected_stage: str
    # Slot name -> extracted value, stored as a string.
    slots: Dict[str, str]
class ExtractionAgent:
    """Extract ``detected_stage`` and slot values from a user message."""

    def extract(
        self,
        user_message: str,
        current_stage: str,
        stage_slots: Dict[str, Any],
        llm: Any,
    ) -> ExtractionResult:
        """Extract structured slot data and decide which stage the message belongs to.

        Args:
            user_message: Raw user text to analyse.
            current_stage: Stage currently active; also the result when the
                LLM is missing, fails, or reports no stage.
            stage_slots: Existing slots for the stage. Values exposing
                ``model_dump()`` are dumped before being put in the prompt.
            llm: Object exposing a synchronous ``.invoke(prompt)`` (the
                original note says this runs inside a Celery task). A falsy
                value disables extraction.

        Returns:
            An ``ExtractionResult``. On any failure the stage stays
            ``current_stage`` and ``slots`` is empty; this method does not
            raise for LLM or parsing errors.
        """
        if not llm:
            return ExtractionResult(detected_stage=current_stage, slots={})

        detected_stage = current_stage
        extracted_slots: Dict[str, str] = {}
        try:
            prompt = get_state_extraction_prompt(
                user_message=user_message,
                current_stage=current_stage,
                stage_slots={
                    k: v.model_dump() if hasattr(v, "model_dump") else v
                    for k, v in (stage_slots or {}).items()
                },
            )
            response = llm.invoke(prompt)
            parsed = json.loads(extract_json_payload(response.content))
            if not isinstance(parsed, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(parsed).__name__}"
                )
            # ``or`` (not a .get default) so an explicit null/empty stage from
            # the LLM cannot replace the current stage with None.
            detected_stage = parsed.get("detected_stage") or current_stage
            raw_slots = parsed.get("slots") or {}
            # Drop null values instead of storing the literal string "None".
            extracted_slots = {
                k: v if isinstance(v, str) else str(v)
                for k, v in raw_slots.items()
                if v is not None
            }
        except Exception as e:
            # json.JSONDecodeError is already an Exception subclass.
            logger.warning("ExtractionAgent LLM 解析失败: %s", e)
        return ExtractionResult(detected_stage=detected_stage, slots=extracted_slots)

View File

@@ -0,0 +1,130 @@
"""
回忆录整理 Agent：基于传记结构将口语改写为书面语归类到章节
支持异步调用
"""
import json
from typing import Dict, List, Optional
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.agents.prompts import (
get_chapter_classification_prompt,
get_text_rewrite_prompt,
inject_image_placeholder_template,
CHAPTER_CATEGORIES,
STAGE_TO_ORDER,
)
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned both when the provider has no such attribute and
    when obtaining the provider fails. The failure is logged rather than
    swallowed silently, because callers treat ``None`` as "run without an
    LLM" and would otherwise degrade with no trace.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        logger.warning("获取 LangChain LLM 失败: %s", e)
        return None
class MemoryAgent:
    """Memoir-organising agent: classifies transcript text and rewrites it per chapter.

    The LLM-backed methods are coroutines. Each one falls back to a
    deterministic result when no LLM is configured or the call fails.
    """

    def __init__(self):
        self.llm = _get_langchain_llm()

    @staticmethod
    def _response_text(response) -> str:
        """Return the response's ``content`` if present, else ``str(response)``."""
        return response.content if hasattr(response, "content") else str(response)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (three backticks, optionally tagged json)."""
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @staticmethod
    def _fallback_rewrite(chapter_category: str, content: str) -> Dict:
        """Build the rewrite payload used whenever no structured LLM result exists."""
        return {
            "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
            "content": content,
            "summary": "",
            "image_suggestions": [],
        }

    async def classify_chapter(self, segments_text: str) -> str:
        """Return the chapter category for ``segments_text``.

        Falls back to ``"childhood"`` when there is no LLM, the call fails,
        or the answer is not a key of ``CHAPTER_CATEGORIES``.
        """
        if not self.llm:
            return "childhood"
        try:
            prompt = get_chapter_classification_prompt(segments_text)
            response = await self.llm.ainvoke(prompt)
            category = self._response_text(response).strip().lower()
            if category in CHAPTER_CATEGORIES:
                return category
        except Exception as e:
            logger.error("分类章节失败: %s", e)
        return "childhood"

    async def rewrite_to_literary(
        self,
        segments_text: str,
        chapter_category: str,
        existing_content: Optional[str] = None,
    ) -> Dict:
        """Rewrite spoken-style text into written prose for one chapter.

        Args:
            segments_text: Combined transcript text to rewrite.
            chapter_category: Category key; also selects the fallback title.
            existing_content: Current chapter text passed to the prompt
                (``None`` is sent as an empty string).

        Returns:
            The JSON object produced by the LLM, with its ``content`` run
            through ``inject_image_placeholder_template``. If the reply is
            not valid JSON, the raw reply becomes the content of a fallback
            payload. On any other failure, or without an LLM, the fallback
            payload carries ``segments_text`` unchanged.
        """
        if not self.llm:
            return self._fallback_rewrite(chapter_category, segments_text)

        raw = ""
        try:
            prompt = get_text_rewrite_prompt(
                segments_text, chapter_category, existing_content or ""
            )
            response = await self.llm.ainvoke(prompt)
            raw = self._response_text(response)
            result = json.loads(self._strip_code_fence(raw))
            if not isinstance(result, dict):
                # Valid JSON but not an object: treat like any other failure
                # instead of relying on an incidental TypeError/AttributeError.
                raise TypeError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            result["content"] = inject_image_placeholder_template(
                result.get("content") or ""
            )
            return result
        except json.JSONDecodeError as e:
            # Previously silent: record that the structured reply was lost.
            logger.warning("改写结果不是合法 JSON,使用原始输出作为正文: %s", e)
            return self._fallback_rewrite(
                chapter_category, inject_image_placeholder_template(raw)
            )
        except Exception as e:
            logger.error("改写文本失败: %s", e)
            return self._fallback_rewrite(chapter_category, segments_text)

    async def process_segments(
        self,
        segments: List[Dict],
        existing_chapters: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, Dict]:
        """Classify each segment, then rewrite every touched chapter once.

        Args:
            segments: Dicts whose ``"transcript_text"`` holds the text;
                segments with missing or empty text are skipped.
            existing_chapters: Current chapters keyed by category. It is not
                mutated; a shallow copy is updated and returned.

        Returns:
            The chapter mapping, with one rebuilt entry per category that
            received at least one segment.
        """
        if existing_chapters is None:
            existing_chapters = {}

        segments_by_category: Dict[str, List[str]] = {}
        for segment in segments:
            text = segment.get("transcript_text", "")
            if not text:
                continue
            category = await self.classify_chapter(text)
            segments_by_category.setdefault(category, []).append(text)

        updated_chapters = existing_chapters.copy()
        for category, texts in segments_by_category.items():
            combined_text = "\n\n".join(texts)
            existing_content = existing_chapters.get(category, {}).get("content", "")
            result = await self.rewrite_to_literary(
                combined_text, category, existing_content
            )
            updated_chapters[category] = {
                "title": result.get("title", CHAPTER_CATEGORIES.get(category, "章节")),
                "content": result.get("content", ""),
                "summary": result.get("summary", ""),
                "image_suggestions": result.get("image_suggestions", []),
                "category": category,
                "order_index": STAGE_TO_ORDER.get(category, 999),
            }
        return updated_chapters

View File

@@ -0,0 +1,78 @@
"""
NarrativeAgent：生成创意标题和叙事改写。
对应现有逻辑：get_creative_title_prompt、get_narrative_prompt
"""
from __future__ import annotations
from typing import Any, Dict, Optional
from app.core.logging import get_logger
from app.agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
)
logger = get_logger(__name__)
class NarrativeAgent:
    """Generate chapter titles and narrative body text."""

    @staticmethod
    def _join_content(existing_content: str, new_content: str) -> str:
        """Append ``new_content`` to ``existing_content`` with a blank line between."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    def generate_title(
        self,
        stage: str,
        emotion: str,
        slots: Dict[str, str],
        user_profile: str = "",
        birth_year: Optional[int] = None,
        llm: Any = None,
    ) -> str:
        """Generate a creative chapter title.

        Returns the default title ``"<stage> 回忆"`` when no LLM is given,
        when the call fails, or when the LLM reply is blank once whitespace
        and surrounding double quotes are stripped.
        """
        default_title = f"{stage} 回忆"
        if not llm:
            return default_title
        try:
            prompt = get_creative_title_prompt(
                stage=stage,
                emotion=emotion,
                slots=slots,
                user_profile=user_profile,
                birth_year=birth_year,
            )
            response = llm.invoke(prompt)
            title = (response.content or "").strip().strip('"').strip()
        except Exception as e:
            logger.warning("NarrativeAgent 生成标题失败: %s", e)
            return default_title
        # A blank reply would otherwise produce an untitled chapter.
        return title or default_title

    def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str = "",
        user_profile: str = "",
        birth_year: Optional[int] = None,
        llm: Any = None,
    ) -> str:
        """Rewrite new conversation content into narrative prose.

        Returns the LLM narrative when one is produced. Without an LLM, on
        failure, or when the reply is blank, returns ``existing_content``
        and ``new_content`` joined by a blank line, so a bad reply can never
        replace existing text with an empty string.
        """
        fallback = self._join_content(existing_content, new_content)
        if not llm:
            return fallback
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
                user_profile=user_profile,
                birth_year=birth_year,
            )
            response = llm.invoke(prompt)
            narrative = (response.content or "").strip()
        except Exception as e:
            logger.warning("NarrativeAgent 生成叙事失败: %s", e)
            return fallback
        if not narrative:
            logger.warning("NarrativeAgent LLM 返回空叙事,回退为拼接原文")
            return fallback
        return narrative

View File

@@ -0,0 +1,124 @@
"""
MemoirOrchestrator：按 segment 编排流水线,调用各 Specialist Agent。
负责:遍历 segments、按 category 聚合、调用 Specialist、更新 state
持久化与章节生成由 process_category 回调完成。
"""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Set, Tuple
from app.core.logging import get_logger
from app.features.conversation.models import Segment
from app.agents.state_schema import MemoirStateSchema
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.memoir.classification_agent import (
ClassificationAgent,
_detect_stage as detect_stage_from_keywords,
)
logger = get_logger(__name__)
class MemoirOrchestrator:
    """Coordinator for memoir generation.

    Walks the segments through ExtractionAgent and ClassificationAgent,
    groups them by chapter category, then hands each group to the
    ``process_category`` callback, which produces the narrative and persists it.
    """

    def __init__(self) -> None:
        self.extraction_agent = ExtractionAgent()
        self.classification_agent = ClassificationAgent()

    def run(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        user_profile: str = "",
        user_birth_year: Any = None,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[
            [str, str, str, List[str]], MemoirStateSchema
        ],
        acquire_lock: Callable[[str], bool],
        release_lock: Callable[[str], None],
        process_category: Callable[
            [
                str,
                List[Segment],
                MemoirStateSchema,
                str,
                Any,
                Any,
            ],
            Tuple[Any, bool],
        ],
        raise_retry: Callable[[], None],
    ) -> Tuple[Set[str], int]:
        """Run the memoir pipeline over ``segments``.

        ``process_category(category, segments, state, user_profile,
        user_birth_year, llm)`` must return ``(chapter, has_images_to_generate)``.
        ``raise_retry`` is invoked on lock contention and is expected to
        raise (a Celery retry, per the original note).

        Returns:
            ``(chapters_to_enqueue, processed_count)``: the ids of chapters
            that reported images to generate, and ``len(segments)``.
        """
        state = get_or_create_state()
        grouped: Dict[str, List[Segment]] = {}

        # Phase 1: per segment, extract slots, update state, classify, group.
        for seg in segments:
            transcript = seg.transcript_text or ""
            active_stage = state.current_stage or "childhood"
            # Keyword pre-detection only selects which stage's slots are
            # shown to the extractor (kept as in the original logic).
            keyword_stage = detect_stage_from_keywords(transcript, active_stage)
            known_slots = state.slots.get(keyword_stage, {}) or {}
            extraction: ExtractionResult = self.extraction_agent.extract(
                user_message=transcript,
                current_stage=active_stage,
                stage_slots=known_slots,
                llm=llm,
            )
            stage = extraction.detected_stage
            for slot_name, snippet in extraction.slots.items():
                state = update_slot(stage, slot_name, snippet, [seg.id])

            category = self.classification_agent.classify(
                text=transcript,
                fallback_stage=stage,
                llm=llm,
            )
            if category is None:
                logger.info("段落无回忆录价值,跳过: segment_id=%s", seg.id)
                continue
            if category not in grouped:
                grouped[category] = []
            grouped[category].append(seg)

        # Phase 2: one process_category call per category, under its lock.
        image_chapter_ids: Set[str] = set()
        for category, members in grouped.items():
            lock_acquired = acquire_lock(category)
            if not lock_acquired:
                logger.warning(
                    "章节锁竞争: category=%s, 延迟重试",
                    category,
                )
                raise_retry()
            try:
                outcome = process_category(
                    category,
                    members,
                    state,
                    user_profile,
                    user_birth_year,
                    llm,
                )
                chapter, has_images = outcome
                if chapter and has_images:
                    image_chapter_ids.add(chapter.id)
            finally:
                release_lock(category)

        return image_chapter_ids, len(segments)

View File

@@ -0,0 +1,14 @@
"""
PlaceholderInjectAgent：对 narrative 做占位符模板注入。
对应现有逻辑：inject_image_placeholder_template
纯函数式,无 LLM 调用。
"""
from app.agents.prompts.memory_prompts import inject_image_placeholder_template
def inject_placeholders(content: str) -> str:
    """Run chapter body text through the image-placeholder template.

    Thin delegate: the result is exactly what
    ``inject_image_placeholder_template`` returns for ``content``.
    """
    rendered = inject_image_placeholder_template(content)
    return rendered

View File

@@ -0,0 +1,212 @@
"""
回忆录后台处理器:分析对话、更新状态、生成章节、创意标题
使用 Celery 进行后台任务处理
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Dict, List
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.core.task_tracker import task_tracker
from app.agents.state_schema import MemoirStateSchema
from app.agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = get_logger(__name__)
# Keywords that attribute a message to one of the 5 stages; the first stage
# (in insertion order) with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned both when the provider has no such attribute and
    when obtaining the provider fails. The failure is logged rather than
    swallowed silently, because callers treat ``None`` as "run without an
    LLM" and would otherwise degrade with no trace.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        logger.warning("获取 LangChain LLM 失败: %s", e)
        return None
@dataclass
class AnalysisResult:
    """Result of analysing a single user message."""

    # Stage the message was attributed to.
    detected_stage: str
    # Slot name -> value extracted from the message.
    extracted_slots: Dict[str, str]
    # Emotion label for the message.
    emotion: str
    # Whether the message was flagged as starting a new chapter.
    is_new_chapter: bool
class ContentAnalyzer:
    """Analyse a user message: detect its stage and extract slot values."""

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keywords occur in the message.

        A ``None`` or empty message matches nothing and yields
        ``fallback_stage``.
        """
        message = (user_message or "").lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(
        self, state: MemoirStateSchema, stage: str, user_message: str
    ) -> Dict[str, str]:
        """Put the message (trimmed, max 200 chars) into the first empty slot.

        Returns an empty dict when the stage has no slots or none of them
        has an empty ``snippet``.
        """
        stage_slots = state.slots.get(stage, {}) or {}
        for key, value in stage_slots.items():
            if not value.snippet:
                return {key: (user_message or "").strip()[:200]}
        return {}

    async def analyze_message(
        self, user_message: str, current_state: MemoirStateSchema
    ) -> AnalysisResult:
        """Analyse ``user_message`` against ``current_state``.

        With an LLM, its JSON reply supplies stage, slots, emotion and the
        new-chapter flag. Without one, or when the call or parsing fails,
        the stage comes from keyword detection and the slots from
        ``_fallback_slots``; emotion stays ``"neutral"`` and the new-chapter
        flag stays ``False``.
        """
        keyword_stage = self._detect_stage(
            user_message, current_state.current_stage
        )
        detected_stage = keyword_stage
        extracted_slots: Dict[str, str] = {}
        emotion = "neutral"
        is_new_chapter = False
        use_fallback = not self.llm

        if self.llm:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(keyword_stage, {}),
                )
                response = await self.llm.ainvoke(prompt)
                parsed = json.loads(response.content.strip())
                if not isinstance(parsed, dict):
                    raise TypeError(
                        f"expected a JSON object, got {type(parsed).__name__}"
                    )
                # ``or`` keeps the defaults when the LLM sends explicit nulls.
                detected_stage = parsed.get("detected_stage") or keyword_stage
                extracted_slots = parsed.get("slots") or {}
                emotion = parsed.get("emotion") or emotion
                is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))
            except json.JSONDecodeError as e:
                # Previously silent: record that the structured reply was lost.
                logger.warning("分析结果不是合法 JSON,使用兜底 slots: %s", e)
                use_fallback = True
            except Exception as e:
                logger.error("分析消息失败: %s", e)
                use_fallback = True

        if use_fallback:
            extracted_slots = self._fallback_slots(
                current_state, detected_stage, user_message
            )

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
class MemoirGenerator:
    """Generate chapter titles and narrative text with the async LLM API."""

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    @staticmethod
    def _join_content(existing_content: str, new_content: str) -> str:
        """Append ``new_content`` to ``existing_content`` with a blank line between."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    async def generate_chapter_title(
        self, stage: str, slots: Dict[str, str], emotion: str
    ) -> str:
        """Generate a creative chapter title.

        Returns the default title ``"<stage> 回忆"`` when there is no LLM,
        the call fails, or the reply is blank once whitespace and
        surrounding double quotes are stripped.
        """
        default_title = f"{stage} 回忆"
        if not self.llm:
            return default_title
        try:
            prompt = get_creative_title_prompt(
                stage=stage, emotion=emotion, slots=slots
            )
            response = await self.llm.ainvoke(prompt)
            title = (response.content or "").strip().strip('"').strip()
        except Exception as e:
            logger.error("生成标题失败: %s", e)
            return default_title
        # A blank reply would otherwise produce an untitled chapter.
        return title or default_title

    async def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str,
    ) -> str:
        """Rewrite new content into narrative prose.

        Returns the LLM narrative when one is produced. Without an LLM, on
        failure, or when the reply is blank, returns ``existing_content``
        and ``new_content`` joined by a blank line, so a bad reply can never
        replace existing text with an empty string.
        """
        fallback = self._join_content(existing_content, new_content)
        if not self.llm:
            return fallback
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            response = await self.llm.ainvoke(prompt)
            narrative = (response.content or "").strip()
        except Exception as e:
            logger.error("生成叙事失败: %s", e)
            return fallback
        if not narrative:
            logger.warning("LLM 返回空叙事,回退为拼接原文")
            return fallback
        return narrative
class BackgroundTaskRunner:
    """Debounce per-user segment ids and submit them as one Celery task.

    Segment ids accumulate in ``_pending``. Each new message restarts that
    user's debounce timer; when the timer fires, the accumulated ids are
    submitted together.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be submitted.
        self._pending: Dict[str, List[str]] = {}
        # user_id -> the asyncio task currently running the debounce delay.
        self._timers: Dict[str, object] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Enqueue the Celery task and register it with the task tracker.

        Returns the Celery task id, or ``None`` when submission fails (the
        failure is logged and not re-raised).
        """
        try:
            # Imported here rather than at module level, as in the original.
            from app.tasks.memoir_tasks import process_memoir_segments

            result = process_memoir_segments.delay(user_id, segment_ids)
            task_id = result.id
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: %s", e)
            return None

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue ``segment_id`` for ``user_id`` and restart the debounce timer."""
        import asyncio

        self._pending.setdefault(user_id, []).append(segment_id)
        previous = self._timers.get(user_id)
        if previous is not None:
            previous.cancel()

        async def delayed_submit() -> None:
            try:
                await asyncio.sleep(self.debounce_seconds)
                segment_ids = self._pending.pop(user_id, [])
                if segment_ids:
                    await self._submit_task(user_id, segment_ids)
            except asyncio.CancelledError:
                # Superseded by a newer message or by flush_pending.
                pass
            except Exception as e:
                logger.error("延迟提交任务失败: %s", e)
            finally:
                # Drop our own entry once done so finished timers do not
                # accumulate. The identity check leaves a newer timer that
                # replaced us untouched.
                if self._timers.get(user_id) is asyncio.current_task():
                    del self._timers[user_id]

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Cancel the user's timer and submit pending segments immediately.

        Returns the Celery task id, or ``None`` when nothing was pending or
        submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()
        segment_ids = self._pending.pop(user_id, [])
        if segment_ids:
            return await self._submit_task(user_id, segment_ids)
        return None