feat: 添加Redis支持和Celery任务处理
- 新增Redis服务模块用于会话状态存储和缓存 - 集成Celery用于后台任务处理 - 更新Docker Compose配置以支持开发环境 - 优化API以支持异步调用和Redis会话存储 - 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
@@ -1,37 +1,54 @@
|
||||
"""
|
||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||
支持异步调用和 Redis 会话存储
|
||||
"""
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from services.llm_service import llm_service
|
||||
from services.redis_service import redis_service
|
||||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||||
from .state_schema import MemoirStateSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationAgent:
|
||||
"""对话 Agent"""
|
||||
"""对话 Agent(支持异步和 Redis 存储)"""
|
||||
|
||||
def __init__(self):
|
||||
# 使用 LLM 服务获取 LLM 实例
|
||||
self.llm = llm_service.get_llm()
|
||||
|
||||
# 对话记忆
|
||||
self.memories: dict[str, ConversationBufferMemory] = {}
|
||||
|
||||
def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
|
||||
"""获取或创建对话记忆"""
|
||||
if conversation_id not in self.memories:
|
||||
self.memories[conversation_id] = ConversationBufferMemory(
|
||||
return_messages=True,
|
||||
memory_key="history"
|
||||
)
|
||||
return self.memories[conversation_id]
|
||||
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
|
||||
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
messages = []
|
||||
for msg in history:
|
||||
if msg["role"] == "human":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "ai":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
return messages
|
||||
|
||||
def generate_response(
|
||||
async def _save_message(self, conversation_id: str, role: str, content: str):
|
||||
"""保存消息到 Redis"""
|
||||
await redis_service.add_message(conversation_id, role, content)
|
||||
|
||||
def _format_history_string(self, messages: List[Any]) -> str:
|
||||
"""将消息列表格式化为字符串(用于 prompt)"""
|
||||
history_parts = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_parts.append(f"Human: {msg.content}")
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_parts.append(f"Assistant: {msg.content}")
|
||||
return "\n\n".join(history_parts)
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
@@ -39,7 +56,7 @@ class ConversationAgent:
|
||||
covered_topics: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成 Agent 回应
|
||||
异步生成 Agent 回应
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
@@ -60,38 +77,39 @@ class ConversationAgent:
|
||||
if not self.llm:
|
||||
return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||||
|
||||
# 获取系统提示词
|
||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||
|
||||
# 获取对话记忆
|
||||
memory = self._get_memory(conversation_id)
|
||||
|
||||
# 创建对话链
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
||||
)
|
||||
|
||||
chain = ConversationChain(
|
||||
llm=self.llm,
|
||||
prompt=prompt_template,
|
||||
memory=memory,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# 生成回应
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
return response
|
||||
try:
|
||||
# 获取系统提示词
|
||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||
|
||||
# 从 Redis 获取对话历史
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||||
|
||||
def generate_response_with_state(
|
||||
async def generate_response_with_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema
|
||||
) -> List[str]:
|
||||
"""
|
||||
基于共享状态生成引导式回复
|
||||
基于共享状态异步生成引导式回复
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
@@ -104,39 +122,44 @@ class ConversationAgent:
|
||||
if not self.llm:
|
||||
return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
|
||||
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
try:
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
)
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
)
|
||||
|
||||
memory = self._get_memory(conversation_id)
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
||||
)
|
||||
|
||||
chain = ConversationChain(
|
||||
llm=self.llm,
|
||||
prompt=prompt_template,
|
||||
memory=memory,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response]
|
||||
# 从 Redis 获取对话历史
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response_text]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
|
||||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||||
"""
|
||||
@@ -168,8 +191,7 @@ class ConversationAgent:
|
||||
# 默认返回当前阶段或童年阶段
|
||||
return ConversationStage.CHILDHOOD
|
||||
|
||||
def clear_memory(self, conversation_id: str):
|
||||
"""清除对话记忆"""
|
||||
if conversation_id in self.memories:
|
||||
del self.memories[conversation_id]
|
||||
async def clear_memory(self, conversation_id: str):
|
||||
"""清除对话记忆(从 Redis)"""
|
||||
await redis_service.clear_conversation_history(conversation_id)
|
||||
|
||||
|
||||
@@ -6,34 +6,25 @@
|
||||
- 更新回忆录状态(slots)
|
||||
- 生成/更新章节内容
|
||||
- 创建创意章节标题
|
||||
|
||||
使用 Celery 进行后台任务处理,支持可靠的任务队列和重试机制
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, List
|
||||
|
||||
from agents.state_schema import MemoirStateSchema
|
||||
from database.database import AsyncSessionLocal
|
||||
from database.models import Book, Chapter, Segment
|
||||
from services.llm_service import llm_service
|
||||
from services.memoir_state_service import (
|
||||
get_or_create_state,
|
||||
get_empty_slots,
|
||||
mark_stage_complete,
|
||||
switch_stage,
|
||||
update_slot,
|
||||
)
|
||||
from .prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
get_state_extraction_prompt,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STAGE_KEYWORDS = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
||||
@@ -53,7 +44,7 @@ class AnalysisResult:
|
||||
|
||||
|
||||
class ContentAnalyzer:
|
||||
"""对话内容分析"""
|
||||
"""对话内容分析(支持异步)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.llm = llm_service.get_llm()
|
||||
@@ -79,17 +70,15 @@ class ContentAnalyzer:
|
||||
is_new_chapter = False
|
||||
|
||||
if self.llm:
|
||||
prompt = get_state_extraction_prompt(
|
||||
user_message=user_message,
|
||||
current_stage=current_state.current_stage,
|
||||
stage_slots=current_state.slots.get(detected_stage, {}),
|
||||
)
|
||||
# 使用异步调用避免阻塞
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.llm.invoke(prompt)
|
||||
)
|
||||
content = response.content.strip()
|
||||
try:
|
||||
prompt = get_state_extraction_prompt(
|
||||
user_message=user_message,
|
||||
current_stage=current_state.current_stage,
|
||||
stage_slots=current_state.slots.get(detected_stage, {}),
|
||||
)
|
||||
# 使用异步调用
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
content = response.content.strip()
|
||||
parsed = json.loads(content)
|
||||
detected_stage = parsed.get("detected_stage", detected_stage)
|
||||
extracted_slots = parsed.get("slots", {}) or {}
|
||||
@@ -97,6 +86,9 @@ class ContentAnalyzer:
|
||||
is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))
|
||||
except json.JSONDecodeError:
|
||||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||||
except Exception as e:
|
||||
logger.error(f"分析消息失败: {e}")
|
||||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||||
else:
|
||||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||||
|
||||
@@ -109,7 +101,7 @@ class ContentAnalyzer:
|
||||
|
||||
|
||||
class MemoirGenerator:
|
||||
"""回忆录生成与更新"""
|
||||
"""回忆录生成与更新(支持异步)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.llm = llm_service.get_llm()
|
||||
@@ -117,142 +109,87 @@ class MemoirGenerator:
|
||||
async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
|
||||
if not self.llm:
|
||||
return f"{stage} 回忆"
|
||||
prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
|
||||
# 使用异步调用避免阻塞
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.llm.invoke(prompt)
|
||||
)
|
||||
return response.content.strip().strip('"')
|
||||
try:
|
||||
prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
|
||||
# 使用异步调用
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
return response.content.strip().strip('"')
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题失败: {e}")
|
||||
return f"{stage} 回忆"
|
||||
|
||||
async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
|
||||
if not self.llm:
|
||||
if existing_content:
|
||||
return f"{existing_content}\n\n{new_content}"
|
||||
return new_content
|
||||
prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content)
|
||||
# 使用异步调用避免阻塞
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.llm.invoke(prompt)
|
||||
)
|
||||
return response.content.strip()
|
||||
try:
|
||||
prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content)
|
||||
# 使用异步调用
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
return response.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"生成叙事失败: {e}")
|
||||
if existing_content:
|
||||
return f"{existing_content}\n\n{new_content}"
|
||||
return new_content
|
||||
|
||||
|
||||
class BackgroundTaskRunner:
|
||||
"""后台任务调度(去抖)"""
|
||||
"""后台任务调度器(使用 Celery)"""
|
||||
|
||||
def __init__(self, debounce_seconds: int = 5) -> None:
|
||||
self.debounce_seconds = debounce_seconds
|
||||
self.pending_tasks: Dict[str, List[str]] = {}
|
||||
self._scheduled: Dict[str, asyncio.Task] = {}
|
||||
# 内存中的待处理任务(用于去抖)
|
||||
self._pending: Dict[str, List[str]] = {}
|
||||
self._timers: Dict[str, object] = {}
|
||||
self.analyzer = ContentAnalyzer()
|
||||
self.generator = MemoirGenerator()
|
||||
|
||||
async def queue_message(self, user_id: str, segment_id: str) -> None:
|
||||
self.pending_tasks.setdefault(user_id, []).append(segment_id)
|
||||
if user_id in self._scheduled:
|
||||
self._scheduled[user_id].cancel()
|
||||
self._scheduled[user_id] = asyncio.create_task(self._debounced_process(user_id))
|
||||
"""
|
||||
将消息加入处理队列
|
||||
|
||||
使用 Celery 延迟任务实现去抖效果
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 收集待处理的 segment_ids
|
||||
self._pending.setdefault(user_id, []).append(segment_id)
|
||||
|
||||
# 取消之前的定时器
|
||||
if user_id in self._timers:
|
||||
self._timers[user_id].cancel()
|
||||
|
||||
# 创建新的定时器
|
||||
async def delayed_submit():
|
||||
try:
|
||||
await asyncio.sleep(self.debounce_seconds)
|
||||
segment_ids = self._pending.pop(user_id, [])
|
||||
if segment_ids:
|
||||
# 提交到 Celery
|
||||
from tasks.memoir_tasks import process_memoir_segments
|
||||
process_memoir_segments.delay(user_id, segment_ids)
|
||||
logger.info(f"已提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"提交 Celery 任务失败: {e}")
|
||||
|
||||
self._timers[user_id] = asyncio.create_task(delayed_submit())
|
||||
|
||||
async def _debounced_process(self, user_id: str) -> None:
|
||||
try:
|
||||
await asyncio.sleep(self.debounce_seconds)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
async with AsyncSessionLocal() as db:
|
||||
await self.process_pending(user_id, db)
|
||||
|
||||
async def process_pending(self, user_id: str, db) -> None:
|
||||
segment_ids = self.pending_tasks.pop(user_id, [])
|
||||
if not segment_ids:
|
||||
return
|
||||
|
||||
stmt = select(Segment).where(Segment.id.in_(segment_ids))
|
||||
result = await db.execute(stmt)
|
||||
segments = result.scalars().all()
|
||||
if not segments:
|
||||
return
|
||||
|
||||
state = await get_or_create_state(user_id, db)
|
||||
stage_to_segments: Dict[str, List[Segment]] = {}
|
||||
for segment in segments:
|
||||
analysis = await self.analyzer.analyze_message(segment.transcript_text, state)
|
||||
detected_stage = analysis.detected_stage
|
||||
if detected_stage != state.current_stage:
|
||||
state = await switch_stage(user_id, detected_stage, db)
|
||||
|
||||
for slot_name, snippet in analysis.extracted_slots.items():
|
||||
state = await update_slot(
|
||||
user_id=user_id,
|
||||
stage=detected_stage,
|
||||
slot_name=slot_name,
|
||||
snippet=snippet,
|
||||
segment_ids=[segment.id],
|
||||
db=db,
|
||||
)
|
||||
|
||||
stage_to_segments.setdefault(detected_stage, []).append(segment)
|
||||
|
||||
for stage, stage_segments in stage_to_segments.items():
|
||||
segment_texts = [seg.transcript_text for seg in stage_segments]
|
||||
combined_text = "\n\n".join(segment_texts)
|
||||
source_ids = [seg.id for seg in stage_segments]
|
||||
|
||||
stmt_chapter = select(Chapter).where(
|
||||
Chapter.user_id == user_id,
|
||||
Chapter.category == stage,
|
||||
)
|
||||
result_chapter = await db.execute(stmt_chapter)
|
||||
chapter = result_chapter.scalar_one_or_none()
|
||||
slot_snippets = {
|
||||
key: value.snippet for key, value in (state.slots.get(stage, {}) or {}).items() if value.snippet
|
||||
}
|
||||
title = chapter.title if chapter else await self.generator.generate_chapter_title(stage, slot_snippets, "neutral")
|
||||
existing_content = chapter.content if chapter else ""
|
||||
narrative = await self.generator.generate_narrative(stage, slot_snippets, combined_text, existing_content)
|
||||
|
||||
if chapter:
|
||||
chapter.content = narrative
|
||||
chapter.title = title
|
||||
chapter.is_new = True
|
||||
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
|
||||
else:
|
||||
chapter = Chapter(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
content=narrative,
|
||||
order_index=999,
|
||||
status="completed",
|
||||
category=stage,
|
||||
images=[],
|
||||
is_new=True,
|
||||
source_segments=source_ids,
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
await db.flush()
|
||||
|
||||
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
|
||||
result_book = await db.execute(stmt_book)
|
||||
book = result_book.scalar_one_or_none()
|
||||
if not book:
|
||||
book = Book(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title="我的回忆录",
|
||||
total_pages=0,
|
||||
total_words=0,
|
||||
cover_image_url=None,
|
||||
)
|
||||
db.add(book)
|
||||
book.has_update = True
|
||||
book.last_update_chapter_id = chapter.id
|
||||
|
||||
empty_slots = await get_empty_slots(user_id, db)
|
||||
if not empty_slots:
|
||||
await mark_stage_complete(user_id, state.current_stage, db)
|
||||
|
||||
for seg in segments:
|
||||
seg.processed = True
|
||||
|
||||
await db.commit()
|
||||
async def flush_pending(self, user_id: str) -> None:
|
||||
"""
|
||||
立即提交用户的待处理任务(用于对话结束时)
|
||||
"""
|
||||
# 取消定时器
|
||||
if user_id in self._timers:
|
||||
self._timers[user_id].cancel()
|
||||
del self._timers[user_id]
|
||||
|
||||
# 提交待处理任务
|
||||
segment_ids = self._pending.pop(user_id, [])
|
||||
if segment_ids:
|
||||
from tasks.memoir_tasks import process_memoir_segments
|
||||
process_memoir_segments.delay(user_id, segment_ids)
|
||||
logger.info(f"立即提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
||||
支持异步调用
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from services.llm_service import llm_service
|
||||
@@ -14,17 +16,19 @@ from .prompts import (
|
||||
CHAPTER_ORDER
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryAgent:
|
||||
"""回忆录整理 Agent"""
|
||||
"""回忆录整理 Agent(支持异步)"""
|
||||
|
||||
def __init__(self):
|
||||
# 使用 LLM 服务获取 LLM 实例
|
||||
self.llm = llm_service.get_llm()
|
||||
|
||||
def classify_chapter(self, segments_text: str) -> str:
|
||||
async def classify_chapter(self, segments_text: str) -> str:
|
||||
"""
|
||||
分类章节
|
||||
异步分类章节
|
||||
|
||||
Args:
|
||||
segments_text: 对话段落文本
|
||||
@@ -36,28 +40,34 @@ class MemoryAgent:
|
||||
# 如果没有配置 LLM,返回默认类别
|
||||
return "childhood"
|
||||
|
||||
prompt = get_chapter_classification_prompt(segments_text)
|
||||
|
||||
response = self.llm.invoke(prompt)
|
||||
|
||||
# 提取类别
|
||||
category = response.content.strip().lower()
|
||||
|
||||
# 验证类别是否有效
|
||||
if category in CHAPTER_CATEGORIES:
|
||||
return category
|
||||
try:
|
||||
prompt = get_chapter_classification_prompt(segments_text)
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
|
||||
# 提取类别
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
category = content.strip().lower()
|
||||
|
||||
# 验证类别是否有效
|
||||
if category in CHAPTER_CATEGORIES:
|
||||
return category
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分类章节失败: {e}")
|
||||
|
||||
# 默认返回 childhood
|
||||
return "childhood"
|
||||
|
||||
def rewrite_to_literary(
|
||||
async def rewrite_to_literary(
|
||||
self,
|
||||
segments_text: str,
|
||||
chapter_category: str,
|
||||
existing_content: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
将口语改写为书面语
|
||||
异步将口语改写为书面语
|
||||
|
||||
Args:
|
||||
segments_text: 对话段落文本
|
||||
@@ -76,14 +86,16 @@ class MemoryAgent:
|
||||
"image_suggestions": []
|
||||
}
|
||||
|
||||
prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "")
|
||||
|
||||
response = self.llm.invoke(prompt)
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
# 提取 JSON 部分
|
||||
content = response.content.strip()
|
||||
prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "")
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
|
||||
# 尝试解析 JSON
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
content = content.strip()
|
||||
|
||||
# 移除可能的 markdown 代码块标记
|
||||
if content.startswith("```json"):
|
||||
content = content[7:]
|
||||
@@ -95,22 +107,31 @@ class MemoryAgent:
|
||||
|
||||
result = json.loads(content)
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,返回基本结构
|
||||
return {
|
||||
"title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
|
||||
"content": response.content,
|
||||
"content": response.content if hasattr(response, 'content') else str(response),
|
||||
"summary": "",
|
||||
"image_suggestions": []
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"改写文本失败: {e}")
|
||||
return {
|
||||
"title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
|
||||
"content": segments_text,
|
||||
"summary": "",
|
||||
"image_suggestions": []
|
||||
}
|
||||
|
||||
def process_segments(
|
||||
async def process_segments(
|
||||
self,
|
||||
segments: List[Dict],
|
||||
existing_chapters: Optional[Dict[str, Dict]] = None
|
||||
) -> Dict[str, Dict]:
|
||||
"""
|
||||
处理对话段落,生成或更新章节
|
||||
异步处理对话段落,生成或更新章节
|
||||
|
||||
Args:
|
||||
segments: 对话段落列表,每个包含 transcript_text
|
||||
@@ -130,8 +151,8 @@ class MemoryAgent:
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# 分类
|
||||
category = self.classify_chapter(text)
|
||||
# 异步分类
|
||||
category = await self.classify_chapter(text)
|
||||
|
||||
if category not in segments_by_category:
|
||||
segments_by_category[category] = []
|
||||
@@ -145,8 +166,8 @@ class MemoryAgent:
|
||||
combined_text = "\n\n".join(texts)
|
||||
existing_content = existing_chapters.get(category, {}).get("content", "")
|
||||
|
||||
# 改写为书面语
|
||||
result = self.rewrite_to_literary(combined_text, category, existing_content)
|
||||
# 异步改写为书面语
|
||||
result = await self.rewrite_to_literary(combined_text, category, existing_content)
|
||||
|
||||
# 更新章节
|
||||
updated_chapters[category] = {
|
||||
|
||||
Reference in New Issue
Block a user