feat: 添加Redis支持和Celery任务处理

- 新增Redis服务模块用于会话状态存储和缓存
- 集成Celery用于后台任务处理
- 更新Docker Compose配置以支持开发环境
- 优化API以支持异步调用和Redis会话存储
- 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
penghanyuan
2026-01-21 23:06:47 +01:00
parent 44bd478c1e
commit dbbb924625
16 changed files with 1339 additions and 309 deletions

View File

@@ -1,37 +1,54 @@
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储
"""
from typing import List, Optional
import logging
from typing import List, Optional, Dict, Any
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from services.llm_service import llm_service
from services.redis_service import redis_service
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
from .state_schema import MemoirStateSchema
logger = logging.getLogger(__name__)
class ConversationAgent:
"""对话 Agent"""
"""对话 Agent(支持异步和 Redis 存储)"""
def __init__(self):
# 使用 LLM 服务获取 LLM 实例
self.llm = llm_service.get_llm()
# 对话记忆
self.memories: dict[str, ConversationBufferMemory] = {}
def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
"""获取或创建对话记忆"""
if conversation_id not in self.memories:
self.memories[conversation_id] = ConversationBufferMemory(
return_messages=True,
memory_key="history"
)
return self.memories[conversation_id]
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
history = await redis_service.get_conversation_history(conversation_id)
messages = []
for msg in history:
if msg["role"] == "human":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "ai":
messages.append(AIMessage(content=msg["content"]))
return messages
def generate_response(
async def _save_message(self, conversation_id: str, role: str, content: str):
"""保存消息到 Redis"""
await redis_service.add_message(conversation_id, role, content)
def _format_history_string(self, messages: List[Any]) -> str:
"""将消息列表格式化为字符串(用于 prompt"""
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
self,
conversation_id: str,
user_message: str,
@@ -39,7 +56,7 @@ class ConversationAgent:
covered_topics: Optional[List[str]] = None
) -> str:
"""
生成 Agent 回应
异步生成 Agent 回应
Args:
conversation_id: 对话 ID
@@ -60,38 +77,39 @@ class ConversationAgent:
if not self.llm:
return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 获取对话记忆
memory = self._get_memory(conversation_id)
# 创建对话链
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
)
chain = ConversationChain(
llm=self.llm,
prompt=prompt_template,
memory=memory,
verbose=False
)
# 生成回应
response = chain.predict(input=user_message)
return response
try:
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
return response_text
except Exception as e:
logger.error(f"生成回应失败: {e}")
return f"抱歉,生成回应时出现错误: {str(e)}"
def generate_response_with_state(
async def generate_response_with_state(
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema
) -> List[str]:
"""
基于共享状态生成引导式回复
基于共享状态异步生成引导式回复
Args:
conversation_id: 对话 ID
@@ -104,39 +122,44 @@ class ConversationAgent:
if not self.llm:
return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
)
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
)
memory = self._get_memory(conversation_id)
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
)
chain = ConversationChain(
llm=self.llm,
prompt=prompt_template,
memory=memory,
verbose=False
)
response = chain.predict(input=user_message)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response]
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成回应失败: {e}")
return [f"抱歉,生成回应时出现错误: {str(e)}"]
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
"""
@@ -168,8 +191,7 @@ class ConversationAgent:
# 默认返回当前阶段或童年阶段
return ConversationStage.CHILDHOOD
def clear_memory(self, conversation_id: str):
"""清除对话记忆"""
if conversation_id in self.memories:
del self.memories[conversation_id]
async def clear_memory(self, conversation_id: str):
"""清除对话记忆(从 Redis"""
await redis_service.clear_conversation_history(conversation_id)