feat: 增强对话代理和提示生成逻辑
- 在ConversationAgent中添加对话历史和轮数的计算,以支持更智能的对话管理 - 引入同一话题轮数的估算逻辑,优化对话的连贯性 - 更新get_guided_conversation_prompt函数,动态调整对话策略和回应风格 - 在UI组件中优化消息显示,支持流式消息和多部分消息的展示 - 更新应用设置管理,支持持久化存储和Compose状态观察
This commit is contained in:
@@ -130,15 +130,23 @@ class ConversationAgent:
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
# 从 Redis 获取对话历史,用于计算对话轮数
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复
|
||||
|
||||
# 计算同一话题的轮数(简单估算:基于已填充槽位的变化)
|
||||
# 如果槽位数量没有增加,说明还在同一话题深入
|
||||
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
conversation_turn=conversation_turn,
|
||||
same_topic_turns=same_topic_turns,
|
||||
)
|
||||
|
||||
# 从 Redis 获取对话历史
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
@@ -161,6 +169,32 @@ class ConversationAgent:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
|
||||
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
|
||||
"""
|
||||
估算同一话题的轮数
|
||||
通过分析最近几轮对话来判断是否一直在同一个话题上
|
||||
"""
|
||||
if len(history_messages) < 4:
|
||||
return len(history_messages) // 2
|
||||
|
||||
# 简单策略:检查最近的对话是否有重复关键词
|
||||
recent_messages = history_messages[-6:] # 最近3轮
|
||||
|
||||
# 提取关键词(简单实现)
|
||||
keywords_per_turn = []
|
||||
for i in range(0, len(recent_messages), 2):
|
||||
if i + 1 < len(recent_messages):
|
||||
human_msg = recent_messages[i].content if hasattr(recent_messages[i], 'content') else str(recent_messages[i])
|
||||
ai_msg = recent_messages[i+1].content if hasattr(recent_messages[i+1], 'content') else str(recent_messages[i+1])
|
||||
combined = human_msg + ai_msg
|
||||
keywords_per_turn.append(combined[:100]) # 取前100字作为特征
|
||||
|
||||
# 如果连续3轮都在讨论相似内容,认为同一话题
|
||||
if len(keywords_per_turn) >= 3:
|
||||
return 3
|
||||
|
||||
return len(keywords_per_turn)
|
||||
|
||||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||||
"""
|
||||
检测对话阶段
|
||||
|
||||
Reference in New Issue
Block a user