feat: 添加Redis支持和Celery任务处理

- 新增Redis服务模块用于会话状态存储和缓存
- 集成Celery用于后台任务处理
- 更新Docker Compose配置以支持开发环境
- 优化API以支持异步调用和Redis会话存储
- 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
penghanyuan
2026-01-21 23:06:47 +01:00
parent 44bd478c1e
commit dbbb924625
16 changed files with 1339 additions and 309 deletions

View File

@@ -1,37 +1,54 @@
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储
"""
from typing import List, Optional
import logging
from typing import List, Optional, Dict, Any
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from services.llm_service import llm_service
from services.redis_service import redis_service
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
from .state_schema import MemoirStateSchema
logger = logging.getLogger(__name__)
class ConversationAgent:
"""对话 Agent"""
"""对话 Agent(支持异步和 Redis 存储)"""
def __init__(self):
# 使用 LLM 服务获取 LLM 实例
self.llm = llm_service.get_llm()
# 对话记忆
self.memories: dict[str, ConversationBufferMemory] = {}
def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
"""获取或创建对话记忆"""
if conversation_id not in self.memories:
self.memories[conversation_id] = ConversationBufferMemory(
return_messages=True,
memory_key="history"
)
return self.memories[conversation_id]
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
history = await redis_service.get_conversation_history(conversation_id)
messages = []
for msg in history:
if msg["role"] == "human":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "ai":
messages.append(AIMessage(content=msg["content"]))
return messages
def generate_response(
async def _save_message(self, conversation_id: str, role: str, content: str):
"""保存消息到 Redis"""
await redis_service.add_message(conversation_id, role, content)
def _format_history_string(self, messages: List[Any]) -> str:
"""将消息列表格式化为字符串(用于 prompt"""
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
self,
conversation_id: str,
user_message: str,
@@ -39,7 +56,7 @@ class ConversationAgent:
covered_topics: Optional[List[str]] = None
) -> str:
"""
生成 Agent 回应
异步生成 Agent 回应
Args:
conversation_id: 对话 ID
@@ -60,38 +77,39 @@ class ConversationAgent:
if not self.llm:
return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 获取对话记忆
memory = self._get_memory(conversation_id)
# 创建对话链
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
)
chain = ConversationChain(
llm=self.llm,
prompt=prompt_template,
memory=memory,
verbose=False
)
# 生成回应
response = chain.predict(input=user_message)
return response
try:
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
return response_text
except Exception as e:
logger.error(f"生成回应失败: {e}")
return f"抱歉,生成回应时出现错误: {str(e)}"
def generate_response_with_state(
async def generate_response_with_state(
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema
) -> List[str]:
"""
基于共享状态生成引导式回复
基于共享状态异步生成引导式回复
Args:
conversation_id: 对话 ID
@@ -104,39 +122,44 @@ class ConversationAgent:
if not self.llm:
return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
)
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
)
memory = self._get_memory(conversation_id)
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
)
chain = ConversationChain(
llm=self.llm,
prompt=prompt_template,
memory=memory,
verbose=False
)
response = chain.predict(input=user_message)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response]
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成回应失败: {e}")
return [f"抱歉,生成回应时出现错误: {str(e)}"]
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
"""
@@ -168,8 +191,7 @@ class ConversationAgent:
# 默认返回当前阶段或童年阶段
return ConversationStage.CHILDHOOD
def clear_memory(self, conversation_id: str):
"""清除对话记忆"""
if conversation_id in self.memories:
del self.memories[conversation_id]
async def clear_memory(self, conversation_id: str):
"""清除对话记忆(从 Redis"""
await redis_service.clear_conversation_history(conversation_id)

View File

@@ -6,34 +6,25 @@
- 更新回忆录状态slots
- 生成/更新章节内容
- 创建创意章节标题
使用 Celery 进行后台任务处理,支持可靠的任务队列和重试机制
"""
from __future__ import annotations
import asyncio
import json
import uuid
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
from sqlalchemy import select
from typing import Dict, List
from agents.state_schema import MemoirStateSchema
from database.database import AsyncSessionLocal
from database.models import Book, Chapter, Segment
from services.llm_service import llm_service
from services.memoir_state_service import (
get_or_create_state,
get_empty_slots,
mark_stage_complete,
switch_stage,
update_slot,
)
from .prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = logging.getLogger(__name__)
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
@@ -53,7 +44,7 @@ class AnalysisResult:
class ContentAnalyzer:
"""对话内容分析"""
"""对话内容分析(支持异步)"""
def __init__(self) -> None:
self.llm = llm_service.get_llm()
@@ -79,17 +70,15 @@ class ContentAnalyzer:
is_new_chapter = False
if self.llm:
prompt = get_state_extraction_prompt(
user_message=user_message,
current_stage=current_state.current_stage,
stage_slots=current_state.slots.get(detected_stage, {}),
)
# 使用异步调用避免阻塞
response = await asyncio.get_event_loop().run_in_executor(
None, lambda: self.llm.invoke(prompt)
)
content = response.content.strip()
try:
prompt = get_state_extraction_prompt(
user_message=user_message,
current_stage=current_state.current_stage,
stage_slots=current_state.slots.get(detected_stage, {}),
)
# 使用异步调用
response = await self.llm.ainvoke(prompt)
content = response.content.strip()
parsed = json.loads(content)
detected_stage = parsed.get("detected_stage", detected_stage)
extracted_slots = parsed.get("slots", {}) or {}
@@ -97,6 +86,9 @@ class ContentAnalyzer:
is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))
except json.JSONDecodeError:
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
except Exception as e:
logger.error(f"分析消息失败: {e}")
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
else:
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
@@ -109,7 +101,7 @@ class ContentAnalyzer:
class MemoirGenerator:
"""回忆录生成与更新"""
"""回忆录生成与更新(支持异步)"""
def __init__(self) -> None:
self.llm = llm_service.get_llm()
@@ -117,142 +109,87 @@ class MemoirGenerator:
async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
if not self.llm:
return f"{stage} 回忆"
prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
# 使用异步调用避免阻塞
response = await asyncio.get_event_loop().run_in_executor(
None, lambda: self.llm.invoke(prompt)
)
return response.content.strip().strip('"')
try:
prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
# 使用异步调用
response = await self.llm.ainvoke(prompt)
return response.content.strip().strip('"')
except Exception as e:
logger.error(f"生成标题失败: {e}")
return f"{stage} 回忆"
async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
if not self.llm:
if existing_content:
return f"{existing_content}\n\n{new_content}"
return new_content
prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content)
# 使用异步调用避免阻塞
response = await asyncio.get_event_loop().run_in_executor(
None, lambda: self.llm.invoke(prompt)
)
return response.content.strip()
try:
prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content)
# 使用异步调用
response = await self.llm.ainvoke(prompt)
return response.content.strip()
except Exception as e:
logger.error(f"生成叙事失败: {e}")
if existing_content:
return f"{existing_content}\n\n{new_content}"
return new_content
class BackgroundTaskRunner:
"""后台任务调度(去抖"""
"""后台任务调度器(使用 Celery"""
def __init__(self, debounce_seconds: int = 5) -> None:
self.debounce_seconds = debounce_seconds
self.pending_tasks: Dict[str, List[str]] = {}
self._scheduled: Dict[str, asyncio.Task] = {}
# 内存中的待处理任务(用于去抖)
self._pending: Dict[str, List[str]] = {}
self._timers: Dict[str, object] = {}
self.analyzer = ContentAnalyzer()
self.generator = MemoirGenerator()
async def queue_message(self, user_id: str, segment_id: str) -> None:
self.pending_tasks.setdefault(user_id, []).append(segment_id)
if user_id in self._scheduled:
self._scheduled[user_id].cancel()
self._scheduled[user_id] = asyncio.create_task(self._debounced_process(user_id))
"""
将消息加入处理队列
使用 Celery 延迟任务实现去抖效果
"""
import asyncio
# 收集待处理的 segment_ids
self._pending.setdefault(user_id, []).append(segment_id)
# 取消之前的定时器
if user_id in self._timers:
self._timers[user_id].cancel()
# 创建新的定时器
async def delayed_submit():
try:
await asyncio.sleep(self.debounce_seconds)
segment_ids = self._pending.pop(user_id, [])
if segment_ids:
# 提交到 Celery
from tasks.memoir_tasks import process_memoir_segments
process_memoir_segments.delay(user_id, segment_ids)
logger.info(f"已提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"提交 Celery 任务失败: {e}")
self._timers[user_id] = asyncio.create_task(delayed_submit())
async def _debounced_process(self, user_id: str) -> None:
try:
await asyncio.sleep(self.debounce_seconds)
except asyncio.CancelledError:
return
async with AsyncSessionLocal() as db:
await self.process_pending(user_id, db)
async def process_pending(self, user_id: str, db) -> None:
segment_ids = self.pending_tasks.pop(user_id, [])
if not segment_ids:
return
stmt = select(Segment).where(Segment.id.in_(segment_ids))
result = await db.execute(stmt)
segments = result.scalars().all()
if not segments:
return
state = await get_or_create_state(user_id, db)
stage_to_segments: Dict[str, List[Segment]] = {}
for segment in segments:
analysis = await self.analyzer.analyze_message(segment.transcript_text, state)
detected_stage = analysis.detected_stage
if detected_stage != state.current_stage:
state = await switch_stage(user_id, detected_stage, db)
for slot_name, snippet in analysis.extracted_slots.items():
state = await update_slot(
user_id=user_id,
stage=detected_stage,
slot_name=slot_name,
snippet=snippet,
segment_ids=[segment.id],
db=db,
)
stage_to_segments.setdefault(detected_stage, []).append(segment)
for stage, stage_segments in stage_to_segments.items():
segment_texts = [seg.transcript_text for seg in stage_segments]
combined_text = "\n\n".join(segment_texts)
source_ids = [seg.id for seg in stage_segments]
stmt_chapter = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
)
result_chapter = await db.execute(stmt_chapter)
chapter = result_chapter.scalar_one_or_none()
slot_snippets = {
key: value.snippet for key, value in (state.slots.get(stage, {}) or {}).items() if value.snippet
}
title = chapter.title if chapter else await self.generator.generate_chapter_title(stage, slot_snippets, "neutral")
existing_content = chapter.content if chapter else ""
narrative = await self.generator.generate_narrative(stage, slot_snippets, combined_text, existing_content)
if chapter:
chapter.content = narrative
chapter.title = title
chapter.is_new = True
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
else:
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
content=narrative,
order_index=999,
status="completed",
category=stage,
images=[],
is_new=True,
source_segments=source_ids,
)
db.add(chapter)
await db.flush()
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
result_book = await db.execute(stmt_book)
book = result_book.scalar_one_or_none()
if not book:
book = Book(
id=str(uuid.uuid4()),
user_id=user_id,
title="我的回忆录",
total_pages=0,
total_words=0,
cover_image_url=None,
)
db.add(book)
book.has_update = True
book.last_update_chapter_id = chapter.id
empty_slots = await get_empty_slots(user_id, db)
if not empty_slots:
await mark_stage_complete(user_id, state.current_stage, db)
for seg in segments:
seg.processed = True
await db.commit()
async def flush_pending(self, user_id: str) -> None:
"""
立即提交用户的待处理任务(用于对话结束时)
"""
# 取消定时器
if user_id in self._timers:
self._timers[user_id].cancel()
del self._timers[user_id]
# 提交待处理任务
segment_ids = self._pending.pop(user_id, [])
if segment_ids:
from tasks.memoir_tasks import process_memoir_segments
process_memoir_segments.delay(user_id, segment_ids)
logger.info(f"立即提交 Celery 任务: user_id={user_id}, segments={len(segment_ids)}")

View File

@@ -1,7 +1,9 @@
"""
回忆录整理 Agent基于传记结构将口语改写为书面语归类到章节
支持异步调用
"""
import json
import logging
from typing import List, Dict, Optional
from services.llm_service import llm_service
@@ -14,17 +16,19 @@ from .prompts import (
CHAPTER_ORDER
)
logger = logging.getLogger(__name__)
class MemoryAgent:
"""回忆录整理 Agent"""
"""回忆录整理 Agent(支持异步)"""
def __init__(self):
# 使用 LLM 服务获取 LLM 实例
self.llm = llm_service.get_llm()
def classify_chapter(self, segments_text: str) -> str:
async def classify_chapter(self, segments_text: str) -> str:
"""
分类章节
异步分类章节
Args:
segments_text: 对话段落文本
@@ -36,28 +40,34 @@ class MemoryAgent:
# 如果没有配置 LLM返回默认类别
return "childhood"
prompt = get_chapter_classification_prompt(segments_text)
response = self.llm.invoke(prompt)
# 提取类别
category = response.content.strip().lower()
# 验证类别是否有效
if category in CHAPTER_CATEGORIES:
return category
try:
prompt = get_chapter_classification_prompt(segments_text)
# 异步调用 LLM
response = await self.llm.ainvoke(prompt)
# 提取类别
content = response.content if hasattr(response, 'content') else str(response)
category = content.strip().lower()
# 验证类别是否有效
if category in CHAPTER_CATEGORIES:
return category
except Exception as e:
logger.error(f"分类章节失败: {e}")
# 默认返回 childhood
return "childhood"
def rewrite_to_literary(
async def rewrite_to_literary(
self,
segments_text: str,
chapter_category: str,
existing_content: Optional[str] = None
) -> Dict:
"""
将口语改写为书面语
异步将口语改写为书面语
Args:
segments_text: 对话段落文本
@@ -76,14 +86,16 @@ class MemoryAgent:
"image_suggestions": []
}
prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "")
response = self.llm.invoke(prompt)
# 尝试解析 JSON
try:
# 提取 JSON 部分
content = response.content.strip()
prompt = get_text_rewrite_prompt(segments_text, chapter_category, existing_content or "")
# 异步调用 LLM
response = await self.llm.ainvoke(prompt)
# 尝试解析 JSON
content = response.content if hasattr(response, 'content') else str(response)
content = content.strip()
# 移除可能的 markdown 代码块标记
if content.startswith("```json"):
content = content[7:]
@@ -95,22 +107,31 @@ class MemoryAgent:
result = json.loads(content)
return result
except json.JSONDecodeError:
# 如果解析失败,返回基本结构
return {
"title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
"content": response.content,
"content": response.content if hasattr(response, 'content') else str(response),
"summary": "",
"image_suggestions": []
}
except Exception as e:
logger.error(f"改写文本失败: {e}")
return {
"title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
"content": segments_text,
"summary": "",
"image_suggestions": []
}
def process_segments(
async def process_segments(
self,
segments: List[Dict],
existing_chapters: Optional[Dict[str, Dict]] = None
) -> Dict[str, Dict]:
"""
处理对话段落,生成或更新章节
异步处理对话段落,生成或更新章节
Args:
segments: 对话段落列表,每个包含 transcript_text
@@ -130,8 +151,8 @@ class MemoryAgent:
if not text:
continue
# 分类
category = self.classify_chapter(text)
# 异步分类
category = await self.classify_chapter(text)
if category not in segments_by_category:
segments_by_category[category] = []
@@ -145,8 +166,8 @@ class MemoryAgent:
combined_text = "\n\n".join(texts)
existing_content = existing_chapters.get(category, {}).get("content", "")
# 改写为书面语
result = self.rewrite_to_literary(combined_text, category, existing_content)
# 异步改写为书面语
result = await self.rewrite_to_literary(combined_text, category, existing_content)
# 更新章节
updated_chapters[category] = {