feat & refactor: 重构agents目录结构;AI回复模块agent结构封装
This commit is contained in:
143
api/app/agents/chat/interview_agent.py
Normal file
143
api/app/agents/chat/interview_agent.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
InterviewAgent:正式访谈 Specialist
|
||||
负责状态感知回复、开场白,不负责 Redis 持久化(由 Orchestrator 统一处理)
|
||||
"""
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.chat.helpers import format_history_string, get_history_messages
|
||||
from app.agents.prompts import get_guided_conversation_prompt, get_opening_prompt
|
||||
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
|
||||
try:
|
||||
provider = get_llm_provider()
|
||||
return getattr(provider, "langchain_llm", None)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
class InterviewAgent:
|
||||
"""正式访谈 Specialist Agent"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = _get_langchain_llm()
|
||||
|
||||
def _detect_user_stage(self, user_message: str) -> str:
|
||||
"""根据关键词检测用户正在谈论的人生阶段"""
|
||||
message = user_message.lower()
|
||||
stage_keywords = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
|
||||
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
|
||||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
|
||||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
|
||||
}
|
||||
for stage, keywords in stage_keywords.items():
|
||||
if any(word in message for word in keywords):
|
||||
return stage
|
||||
return ""
|
||||
|
||||
def _estimate_same_topic_turns(
|
||||
self, history_messages: List[Any], current_filled_slots: dict
|
||||
) -> int:
|
||||
"""估算同一话题的连续轮数"""
|
||||
if len(history_messages) < 4:
|
||||
return len(history_messages) // 2
|
||||
recent_messages = history_messages[-6:]
|
||||
keywords_per_turn = []
|
||||
for i in range(0, len(recent_messages), 2):
|
||||
if i + 1 < len(recent_messages):
|
||||
human_msg = (
|
||||
recent_messages[i].content
|
||||
if hasattr(recent_messages[i], "content")
|
||||
else str(recent_messages[i])
|
||||
)
|
||||
ai_msg = (
|
||||
recent_messages[i + 1].content
|
||||
if hasattr(recent_messages[i + 1], "content")
|
||||
else str(recent_messages[i + 1])
|
||||
)
|
||||
keywords_per_turn.append((human_msg + ai_msg)[:100])
|
||||
if len(keywords_per_turn) >= 3:
|
||||
return 3
|
||||
return len(keywords_per_turn)
|
||||
|
||||
async def generate_response_with_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema,
|
||||
user_profile_context: str = "",
|
||||
) -> List[str]:
|
||||
"""生成状态感知的访谈回复,不持久化(由 Orchestrator 负责)"""
|
||||
if not self.llm:
|
||||
return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
|
||||
try:
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
detected_user_stage = self._detect_user_stage(user_message)
|
||||
history_messages = await get_history_messages(conversation_id)
|
||||
conversation_turn = len(history_messages) // 2
|
||||
same_topic_turns = self._estimate_same_topic_turns(
|
||||
history_messages, filled_slots
|
||||
)
|
||||
all_stages_coverage = memoir_state.all_stages_coverage()
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
conversation_turn=conversation_turn,
|
||||
same_topic_turns=same_topic_turns,
|
||||
all_stages_coverage=all_stages_coverage,
|
||||
detected_user_stage=detected_user_stage,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
history_string = format_history_string(history_messages)
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, "content") else str(response)
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error("生成回应失败: %s", e)
|
||||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
|
||||
async def generate_opening_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
memoir_state: MemoirStateSchema,
|
||||
user_profile_context: str = "",
|
||||
) -> List[str]:
|
||||
"""生成空对话开场白,不持久化(由 Orchestrator 负责)"""
|
||||
if not self.llm:
|
||||
return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
|
||||
try:
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots]
|
||||
if not empty_slots_readable:
|
||||
empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
|
||||
prompt = get_opening_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots_readable=empty_slots_readable,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
full_prompt = f"{prompt}\n\nAssistant:"
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, "content") else str(response)
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:2] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error("生成开场白失败: %s", e, exc_info=True)
|
||||
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
|
||||
Reference in New Issue
Block a user