Files
life-echo/api/app/agents/chat/profile_agent.py

133 lines
5.6 KiB
Python
Raw Normal View History

"""
ProfileAgent用户资料收集 Specialist
负责提取资料资料追问资料收集开场白不负责 Redis 持久化 Orchestrator 统一处理
"""
import json
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.agents.chat.helpers import format_history_string, get_history_messages
from app.agents.chat.prompts_profile import (
get_profile_extraction_prompt,
get_profile_followup_prompt,
get_profile_greeting_prompt,
)
# Module-level logger used by ProfileAgent's error reporting below.
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the LangChain LLM exposed by the configured provider, or None.

    Best-effort: any failure while resolving the provider (or its
    ``langchain_llm`` attribute) yields ``None`` instead of raising, so callers
    that check for a missing LLM can fall back to canned replies. The failure
    is logged so that a misconfigured provider does not go unnoticed.
    """
    try:
        provider = get_llm_provider()
        # Providers without a LangChain client simply yield None.
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Deliberately non-fatal, but no longer silent.
        logger.warning("获取 LLM 实例失败, ProfileAgent 将使用默认回复: %s", e)
        return None
class ProfileAgent:
    """Profile-collection specialist agent.

    Extracts profile fields from user messages and generates the follow-up and
    greeting replies used while collecting a profile. Nothing here is
    persisted; the Orchestrator is responsible for that. When no LLM is
    available, every public method returns a canned fallback.
    """

    # Free-text profile fields copied (stripped) from the extraction JSON.
    _TEXT_FIELDS = ("birth_place", "grew_up_place", "occupation")
    # Inclusive range accepted for an extracted birth year.
    _MIN_BIRTH_YEAR = 1900
    _MAX_BIRTH_YEAR = 2100
    # Marker on which LLM replies are split into separate chat messages.
    _SPLIT_MARKER = "[SPLIT]"
    # Number of most recent history messages given to the extraction prompt.
    _EXTRACTION_HISTORY_LIMIT = 4
    # Canned replies used when the LLM call fails or yields no usable text.
    _FOLLOWUP_FALLBACK = "谢谢分享!能再告诉我一些吗?"
    _GREETING_FALLBACK = "你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"

    def __init__(self):
        # None when no LangChain LLM could be obtained.
        self.llm = _get_langchain_llm()

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract profile fields from *user_message* without persisting them.

        Args:
            user_message: The latest user message.
            missing_fields: Profile fields still missing; nothing is extracted
                when this is empty.
            conversation_id: When given, the last few human/AI turns of this
                conversation are passed to the prompt as context.

        Returns:
            A dict with any of ``birth_year`` (int) and ``birth_place`` /
            ``grew_up_place`` / ``occupation`` (non-empty str). Empty when
            there is no LLM, nothing is missing, the LLM does not return a
            JSON object, or any step fails.
        """
        if not self.llm or not missing_fields:
            return {}
        try:
            # The history lookup is inside the try (as in the sibling methods)
            # so a storage failure degrades to an empty result.
            recent_dialogue = (
                await self._recent_dialogue(conversation_id) if conversation_id else ""
            )
            prompt = get_profile_extraction_prompt(
                user_message, missing_fields, recent_dialogue=recent_dialogue or None
            )
            response = await self.llm.ainvoke(prompt)
            # LLMs frequently wrap JSON in a Markdown code fence; unwrap it first.
            parsed = json.loads(self._strip_code_fence(self._response_text(response)))
        except Exception as e:
            logger.error("提取资料信息失败: %s", e)
            return {}
        if not isinstance(parsed, dict):
            logger.error(
                "提取资料信息失败: %s",
                f"expected a JSON object, got {type(parsed).__name__}",
            )
            return {}
        return self._parse_profile(parsed)

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to three follow-up messages asking for missing profile fields.

        The conversation history and *user_message* are appended to the
        follow-up prompt. Nothing is persisted (the Orchestrator does that).
        Returns a canned reply when there is no LLM, the call fails, or the
        reply contains no usable text.
        """
        if not self.llm:
            return ["谢谢!还能告诉我更多吗?"]
        try:
            prompt = get_profile_followup_prompt(
                missing_fields, filled_fields, user_message, nickname
            )
            history_string = format_history_string(
                await get_history_messages(conversation_id)
            )
            full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            response = await self.llm.ainvoke(full_prompt)
            messages = self._split_reply(self._response_text(response), limit=3)
        except Exception as e:
            logger.error("生成资料跟进回复失败: %s", e)
            return [self._FOLLOWUP_FALLBACK]
        # An empty or marker-only reply must not reach the user verbatim.
        return messages or [self._FOLLOWUP_FALLBACK]

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to two opening messages for profile collection.

        Any existing conversation history is appended to the greeting prompt.
        Nothing is persisted (the Orchestrator does that). Returns a canned
        greeting when there is no LLM, the call fails, or the reply contains
        no usable text.
        """
        if not self.llm:
            return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
        try:
            prompt = get_profile_greeting_prompt(missing_fields, nickname)
            history_string = format_history_string(
                await get_history_messages(conversation_id)
            )
            # Only append the history section when there is any history.
            full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
            response = await self.llm.ainvoke(full_prompt)
            messages = self._split_reply(self._response_text(response), limit=2)
        except Exception as e:
            logger.error("生成资料收集开场白失败: %s", e)
            return [self._GREETING_FALLBACK]
        # An empty or marker-only reply must not reach the user verbatim.
        return messages or [self._GREETING_FALLBACK]

    @classmethod
    async def _recent_dialogue(cls, conversation_id: str) -> str:
        """Render the latest human/AI turns as "用户: ..." / "助手: ..." lines.

        Messages that are neither HumanMessage nor AIMessage are skipped.
        """
        history_messages = await get_history_messages(conversation_id)
        lines = []
        for msg in history_messages[-cls._EXTRACTION_HISTORY_LIMIT:]:
            if isinstance(msg, HumanMessage):
                lines.append(f"用户: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"助手: {msg.content}")
        return "\n".join(lines)

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return ``response.content`` if present, otherwise ``str(response)``."""
        return response.content if hasattr(response, "content") else str(response)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Strip whitespace and a surrounding Markdown code fence (``` / ```json)."""
        text = text.strip()
        if not text.startswith("```"):
            return text
        body = text[3:]
        # The opening fence line may carry a language tag such as "json".
        newline = body.find("\n")
        if newline != -1:
            body = body[newline + 1:]
        body = body.strip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    @classmethod
    def _normalize_birth_year(cls, raw: Any) -> Optional[int]:
        """Coerce an extracted birth year to an int, or return None.

        Ints are accepted as-is. Decimal digit strings (surrounding whitespace
        ignored) are converted, and two-digit values are expanded with a pivot
        at 50 (50-99 -> 19xx, 00-49 -> 20xx). The result must lie within
        [_MIN_BIRTH_YEAR, _MAX_BIRTH_YEAR].
        """
        if isinstance(raw, int):
            year = raw
        elif isinstance(raw, str) and raw.strip().isdecimal():
            # isdecimal() rather than isdigit(): isdigit() accepts characters
            # such as "²" on which int() raises ValueError.
            year = int(raw.strip())
            if year < 100:
                year = 1900 + year if year >= 50 else 2000 + year
        else:
            return None
        if cls._MIN_BIRTH_YEAR <= year <= cls._MAX_BIRTH_YEAR:
            return year
        return None

    @classmethod
    def _parse_profile(cls, parsed: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only recognised, well-formed profile fields from the LLM's JSON object."""
        result: Dict[str, Any] = {}
        birth_year = cls._normalize_birth_year(parsed.get("birth_year"))
        if birth_year is not None:
            result["birth_year"] = birth_year
        for field in cls._TEXT_FIELDS:
            value = parsed.get(field)
            if not value:
                continue
            text = str(value).strip()
            # Whitespace-only values carry no information.
            if text:
                result[field] = text
        return result

    @classmethod
    def _split_reply(cls, response_text: str, limit: int) -> List[str]:
        """Split *response_text* on the split marker into at most *limit* non-empty messages."""
        pieces = (piece.strip() for piece in response_text.split(cls._SPLIT_MARKER))
        return [piece for piece in pieces if piece][:limit]