feat: 添加Redis支持和Celery任务处理
- 新增Redis服务模块用于会话状态存储和缓存 - 集成Celery用于后台任务处理 - 更新Docker Compose配置以支持开发环境 - 优化API以支持异步调用和Redis会话存储 - 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
@@ -1,37 +1,54 @@
|
||||
"""
|
||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||
支持异步调用和 Redis 会话存储
|
||||
"""
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from services.llm_service import llm_service
|
||||
from services.redis_service import redis_service
|
||||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||||
from .state_schema import MemoirStateSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationAgent:
|
||||
"""对话 Agent"""
|
||||
"""对话 Agent(支持异步和 Redis 存储)"""
|
||||
|
||||
def __init__(self):
|
||||
# 使用 LLM 服务获取 LLM 实例
|
||||
self.llm = llm_service.get_llm()
|
||||
|
||||
# 对话记忆
|
||||
self.memories: dict[str, ConversationBufferMemory] = {}
|
||||
|
||||
def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
|
||||
"""获取或创建对话记忆"""
|
||||
if conversation_id not in self.memories:
|
||||
self.memories[conversation_id] = ConversationBufferMemory(
|
||||
return_messages=True,
|
||||
memory_key="history"
|
||||
)
|
||||
return self.memories[conversation_id]
|
||||
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
|
||||
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
messages = []
|
||||
for msg in history:
|
||||
if msg["role"] == "human":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "ai":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
return messages
|
||||
|
||||
def generate_response(
|
||||
async def _save_message(self, conversation_id: str, role: str, content: str):
|
||||
"""保存消息到 Redis"""
|
||||
await redis_service.add_message(conversation_id, role, content)
|
||||
|
||||
def _format_history_string(self, messages: List[Any]) -> str:
|
||||
"""将消息列表格式化为字符串(用于 prompt)"""
|
||||
history_parts = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_parts.append(f"Human: {msg.content}")
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_parts.append(f"Assistant: {msg.content}")
|
||||
return "\n\n".join(history_parts)
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
@@ -39,7 +56,7 @@ class ConversationAgent:
|
||||
covered_topics: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成 Agent 回应
|
||||
异步生成 Agent 回应
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
@@ -60,38 +77,39 @@ class ConversationAgent:
|
||||
if not self.llm:
|
||||
return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||||
|
||||
# 获取系统提示词
|
||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||
|
||||
# 获取对话记忆
|
||||
memory = self._get_memory(conversation_id)
|
||||
|
||||
# 创建对话链
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
||||
)
|
||||
|
||||
chain = ConversationChain(
|
||||
llm=self.llm,
|
||||
prompt=prompt_template,
|
||||
memory=memory,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# 生成回应
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
return response
|
||||
try:
|
||||
# 获取系统提示词
|
||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||
|
||||
# 从 Redis 获取对话历史
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||||
|
||||
def generate_response_with_state(
|
||||
async def generate_response_with_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema
|
||||
) -> List[str]:
|
||||
"""
|
||||
基于共享状态生成引导式回复
|
||||
基于共享状态异步生成引导式回复
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
@@ -104,39 +122,44 @@ class ConversationAgent:
|
||||
if not self.llm:
|
||||
return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
|
||||
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
try:
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
)
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
)
|
||||
|
||||
memory = self._get_memory(conversation_id)
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
||||
)
|
||||
|
||||
chain = ConversationChain(
|
||||
llm=self.llm,
|
||||
prompt=prompt_template,
|
||||
memory=memory,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response]
|
||||
# 从 Redis 获取对话历史
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response_text]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
|
||||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||||
"""
|
||||
@@ -168,8 +191,7 @@ class ConversationAgent:
|
||||
# 默认返回当前阶段或童年阶段
|
||||
return ConversationStage.CHILDHOOD
|
||||
|
||||
def clear_memory(self, conversation_id: str):
|
||||
"""清除对话记忆"""
|
||||
if conversation_id in self.memories:
|
||||
del self.memories[conversation_id]
|
||||
async def clear_memory(self, conversation_id: str):
|
||||
"""清除对话记忆(从 Redis)"""
|
||||
await redis_service.clear_conversation_history(conversation_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user