# Listing metadata left over from the code viewer: Python, 187 lines, 6.1 KiB.
"""
|
||
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
||
支持异步调用
|
||
"""
|
||
import json
|
||
import logging
|
||
from typing import List, Dict, Optional
|
||
|
||
from services.llm_service import llm_service
|
||
|
||
from .prompts import (
|
||
get_memory_prompt,
|
||
get_chapter_classification_prompt,
|
||
get_text_rewrite_prompt,
|
||
inject_image_placeholder_template,
|
||
CHAPTER_CATEGORIES,
|
||
STAGE_TO_ORDER,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class MemoryAgent:
    """Memoir-organizing agent (async-capable).

    Classifies transcript segments into chapter categories and asks the LLM
    to rewrite the spoken text of each category into written prose. Every
    public coroutine degrades to a safe default instead of raising when the
    LLM is missing, fails, or answers in an unexpected format.
    """

    # Category returned whenever classification is unavailable or unusable.
    DEFAULT_CATEGORY = "childhood"
    # Title used when a category has no entry in CHAPTER_CATEGORIES.
    DEFAULT_TITLE = "章节"
    # Characters that models commonly wrap around a one-word answer
    # (whitespace, quotes, backticks, ASCII and CJK punctuation).
    _LABEL_NOISE = " \t\r\n\"'`.,;:!?。,;:!?“”‘’「」"

    def __init__(self):
        # Obtain the LLM instance from the shared LLM service.
        self.llm = llm_service.get_llm()

    @staticmethod
    def _response_text(response):
        """Return ``response.content`` when present, else ``str(response)``."""
        return response.content if hasattr(response, "content") else str(response)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (```json / ```) from *text*."""
        cleaned = text.strip()
        # The language tag is matched case-insensitively ("```JSON" also occurs).
        if cleaned[:7].lower() == "```json":
            cleaned = cleaned[7:]
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    @staticmethod
    def _parse_json_object(text: str) -> Optional[Dict]:
        """Parse *text* as a JSON object; return ``None`` if that is impossible.

        If the whole text is not valid JSON, the span from the first ``{`` to
        the last ``}`` is tried as well, so an object wrapped in explanatory
        prose is still recovered. Valid JSON that is not an object (list,
        string, number, ...) yields ``None``.
        """
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if not 0 <= start < end:
                return None
            try:
                parsed = json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                return None
        return parsed if isinstance(parsed, dict) else None

    def _fallback_chapter(self, chapter_category: str, content: str) -> Dict:
        """Build the minimal chapter dict used when no structured result exists."""
        return {
            "title": CHAPTER_CATEGORIES.get(chapter_category, self.DEFAULT_TITLE),
            "content": content,
            "summary": "",
            "image_suggestions": [],
        }

    async def classify_chapter(self, segments_text: str) -> str:
        """Classify a transcript into a chapter category.

        Args:
            segments_text: Transcript text of the conversation segment(s).

        Returns:
            A key of ``CHAPTER_CATEGORIES``. ``"childhood"`` is returned when
            no LLM is configured, the call fails, or the reply is not a known
            category.
        """
        if not self.llm:
            # No LLM configured: use the default category.
            return self.DEFAULT_CATEGORY

        try:
            prompt = get_chapter_classification_prompt(segments_text)
            response = await self.llm.ainvoke(prompt)
            reply = self._response_text(response)

            category = reply.strip().lower()
            if category in CHAPTER_CATEGORIES:
                return category

            # Retry after dropping quotes/punctuation wrapped around the label,
            # e.g. '"childhood"' or 'Childhood.'.
            category = category.strip(self._LABEL_NOISE)
            if category in CHAPTER_CATEGORIES:
                return category

            logger.warning(
                "Unrecognized chapter category %r; falling back to %s",
                reply,
                self.DEFAULT_CATEGORY,
            )
        except Exception as e:
            # Classification is best-effort; keep the traceback for diagnosis.
            logger.exception("分类章节失败: %s", e)

        return self.DEFAULT_CATEGORY

    async def rewrite_to_literary(
        self,
        segments_text: str,
        chapter_category: str,
        existing_content: Optional[str] = None
    ) -> Dict:
        """Rewrite spoken-language text into written prose.

        Args:
            segments_text: Transcript text of the conversation segment(s).
            chapter_category: Chapter category the text belongs to.
            existing_content: Content already present in the chapter (optional).

        Returns:
            The JSON object produced by the LLM, with its ``content`` passed
            through ``inject_image_placeholder_template``. When the reply is
            not a JSON object, a dict with ``title``, ``content`` (the reply
            text without code fences), ``summary`` and ``image_suggestions``
            is returned. When no LLM is configured or anything fails, the same
            dict shape is returned with ``content`` set to ``segments_text``.
        """
        if not self.llm:
            # No LLM configured: return the transcript in the basic structure.
            return self._fallback_chapter(chapter_category, segments_text)

        try:
            prompt = get_text_rewrite_prompt(
                segments_text, chapter_category, existing_content or ""
            )
            response = await self.llm.ainvoke(prompt)
            cleaned = self._strip_code_fence(self._response_text(response))

            result = self._parse_json_object(cleaned)
            if result is None:
                # Not a JSON object: keep the model's prose (without fence
                # markers) rather than discarding the rewrite.
                return self._fallback_chapter(
                    chapter_category, inject_image_placeholder_template(cleaned)
                )

            result["content"] = inject_image_placeholder_template(
                result.get("content") or ""
            )
            return result
        except Exception as e:
            logger.exception("改写文本失败: %s", e)
            return self._fallback_chapter(chapter_category, segments_text)

    async def process_segments(
        self,
        segments: List[Dict],
        existing_chapters: Optional[Dict[str, Dict]] = None
    ) -> Dict[str, Dict]:
        """Classify segments and create or update the matching chapters.

        Args:
            segments: Conversation segments; each dict's ``transcript_text``
                is read. Segments with missing, empty or whitespace-only text
                are skipped.
            existing_chapters: Existing chapters keyed by category (optional).
                The mapping itself is not modified.

        Returns:
            A new dict holding the existing chapters plus one regenerated
            entry per category that received text.
        """
        if existing_chapters is None:
            existing_chapters = {}

        # Group transcript texts by their classified category.
        segments_by_category: Dict[str, List[str]] = {}
        for segment in segments or []:
            text = segment.get("transcript_text", "")
            if not text or (isinstance(text, str) and not text.strip()):
                # Nothing to classify; avoid a pointless LLM round trip.
                continue
            category = await self.classify_chapter(text)
            segments_by_category.setdefault(category, []).append(text)

        # Shallow copy so the caller's mapping is left untouched.
        updated_chapters = existing_chapters.copy()

        for category, texts in segments_by_category.items():
            combined_text = "\n\n".join(texts)
            # `or {}` tolerates a category stored with a None value.
            existing_content = (existing_chapters.get(category) or {}).get("content", "")

            result = await self.rewrite_to_literary(
                combined_text, category, existing_content
            )

            # `or` fallbacks also cover JSON nulls returned by the model.
            updated_chapters[category] = {
                "title": result.get("title")
                or CHAPTER_CATEGORIES.get(category, self.DEFAULT_TITLE),
                "content": result.get("content") or "",
                "summary": result.get("summary") or "",
                "image_suggestions": result.get("image_suggestions") or [],
                "category": category,
                "order_index": STAGE_TO_ORDER.get(category, 999),
            }

        return updated_chapters
|
||
|