Merge branch 'feat/improve-agent-prompt'
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||
支持异步调用和 Redis 会话存储
|
||||
支持用户基础资料收集和时代背景融入
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
@@ -11,6 +13,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from services.llm_service import llm_service
|
||||
from services.redis_service import redis_service
|
||||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||||
from .prompts.profile_prompts import (
|
||||
get_profile_greeting_prompt,
|
||||
get_profile_extraction_prompt,
|
||||
get_profile_followup_prompt,
|
||||
format_user_profile_context,
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from .state_schema import MemoirStateSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,6 +111,87 @@ class ConversationAgent:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||||
|
||||
async def generate_profile_greeting(
|
||||
self,
|
||||
conversation_id: str,
|
||||
missing_fields: List[str],
|
||||
nickname: str = "",
|
||||
) -> List[str]:
|
||||
"""生成资料收集的开场白(首次对话时使用)"""
|
||||
if not self.llm:
|
||||
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
|
||||
|
||||
try:
|
||||
prompt = get_profile_greeting_prompt(missing_fields, nickname)
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成资料收集开场白失败: {e}")
|
||||
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||||
|
||||
async def extract_profile_from_message(self, user_message: str, missing_fields: List[str]) -> Dict[str, Any]:
|
||||
"""从用户消息中提取基础资料信息"""
|
||||
if not self.llm or not missing_fields:
|
||||
return {}
|
||||
|
||||
try:
|
||||
prompt = get_profile_extraction_prompt(user_message, missing_fields)
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
content = response.content.strip()
|
||||
parsed = json.loads(content)
|
||||
result = {}
|
||||
if "birth_year" in parsed and isinstance(parsed["birth_year"], int):
|
||||
result["birth_year"] = parsed["birth_year"]
|
||||
if "birth_place" in parsed and parsed["birth_place"]:
|
||||
result["birth_place"] = str(parsed["birth_place"])
|
||||
if "grew_up_place" in parsed and parsed["grew_up_place"]:
|
||||
result["grew_up_place"] = str(parsed["grew_up_place"])
|
||||
if "occupation" in parsed and parsed["occupation"]:
|
||||
result["occupation"] = str(parsed["occupation"])
|
||||
return result
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"提取资料信息失败: {e}")
|
||||
return {}
|
||||
|
||||
async def generate_profile_followup(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
missing_fields: List[str],
|
||||
filled_fields: Dict[str, str],
|
||||
nickname: str = "",
|
||||
) -> List[str]:
|
||||
"""在资料收集过程中生成跟进回复"""
|
||||
if not self.llm:
|
||||
return ["谢谢!还能告诉我更多吗?"]
|
||||
|
||||
try:
|
||||
prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成资料跟进回复失败: {e}")
|
||||
return ["谢谢分享!能再告诉我一些吗?"]
|
||||
|
||||
def _detect_user_stage(self, user_message: str) -> str:
|
||||
"""
|
||||
通过关键词检测用户当前正在谈论的人生阶段。
|
||||
@@ -126,7 +216,8 @@ class ConversationAgent:
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema
|
||||
memoir_state: MemoirStateSchema,
|
||||
user_profile_context: str = "",
|
||||
) -> List[str]:
|
||||
"""
|
||||
基于共享状态异步生成引导式回复
|
||||
@@ -135,6 +226,7 @@ class ConversationAgent:
|
||||
conversation_id: 对话 ID
|
||||
user_message: 用户消息
|
||||
memoir_state: 共享状态
|
||||
user_profile_context: 用户基础资料上下文
|
||||
|
||||
Returns:
|
||||
Agent 回应文本列表(支持多条消息)
|
||||
@@ -150,18 +242,11 @@ class ConversationAgent:
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
# 检测用户当前正在谈论的阶段
|
||||
detected_user_stage = self._detect_user_stage(user_message)
|
||||
|
||||
# 从 Redis 获取对话历史,用于计算对话轮数
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复
|
||||
|
||||
# 计算同一话题的轮数(简单估算:基于已填充槽位的变化)
|
||||
# 如果槽位数量没有增加,说明还在同一话题深入
|
||||
conversation_turn = len(history_messages) // 2
|
||||
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
|
||||
|
||||
# 获取所有阶段的覆盖情况
|
||||
all_stages_coverage = memoir_state.all_stages_coverage()
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
@@ -173,24 +258,19 @@ class ConversationAgent:
|
||||
same_topic_turns=same_topic_turns,
|
||||
all_stages_coverage=all_stages_coverage,
|
||||
detected_user_stage=detected_user_stage,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response_text]
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user