Files
life-echo/api/app/agents/chat/profile_agent.py
2026-03-19 14:36:40 +08:00

153 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ProfileAgent: user-profile collection specialist.

Responsible for extracting profile details, asking profile follow-up
questions, and producing the profile-collection greeting. It does not handle
Redis persistence; that is done centrally by the Orchestrator.
"""
import json
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.agents.chat.helpers import format_history_string, get_history_messages
from app.agents.chat.prompts_profile import (
get_profile_extraction_prompt,
get_profile_followup_prompt,
get_profile_greeting_prompt,
)
from app.features.memoir.memoir_images.json_payload import extract_json_payload
# Module logger, obtained from the project's app.core.logging helper.
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the LangChain chat model exposed by the configured LLM provider.

    Returns ``None`` when the provider cannot be obtained or has no
    ``langchain_llm`` attribute. ``ProfileAgent`` treats a ``None`` model as
    "no LLM" and answers with canned replies instead of calling a model.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Best-effort: a broken provider must not prevent the agent from being
        # constructed, but the failure should not be invisible either.
        logger.warning("获取 LLM provider 失败, ProfileAgent 将使用默认回复: %s", e)
        return None
class ProfileAgent:
    """Specialist agent that collects the user's basic profile.

    Handles three tasks: extracting profile fields from a user message,
    generating follow-up questions for fields that are still missing, and
    generating the opening greeting for profile collection. Nothing is
    persisted here; persistence is left to the Orchestrator.
    """

    # Number of most recent history messages given to the extraction prompt.
    _EXTRACTION_HISTORY_WINDOW = 4
    # Free-text profile fields copied (whitespace-trimmed) from the LLM's JSON.
    _TEXT_FIELDS = ("birth_place", "grew_up_place", "occupation")

    def __init__(self):
        # May be None (no provider / no LangChain model); every public method
        # then returns early with a canned reply or an empty result.
        self.llm = _get_langchain_llm()

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract profile fields from ``user_message`` without persisting them.

        Args:
            user_message: The latest message from the user.
            missing_fields: Fields still to be collected; passed to the
                extraction prompt. Nothing is extracted when this is empty.
            conversation_id: When given, the last few history messages are
                added to the prompt as context.

        Returns:
            A dict containing any of ``birth_year`` (int in 1900-2100),
            ``birth_place``, ``grew_up_place`` and ``occupation`` (non-empty
            str). Empty when no LLM is configured, nothing usable was
            extracted, or extraction failed.
        """
        if not self.llm or not missing_fields:
            return {}
        recent_dialogue = (
            await self._format_recent_dialogue(conversation_id)
            if conversation_id
            else ""
        )
        try:
            prompt = get_profile_extraction_prompt(
                user_message, missing_fields, recent_dialogue=recent_dialogue or None
            )
            # Ask the model for a JSON object so the reply can be parsed directly.
            json_llm = self.llm.bind(
                model_kwargs={"response_format": {"type": "json_object"}},
                max_tokens=512,
            )
            response = await json_llm.ainvoke(prompt)
            content = self._response_text(response).strip()
            parsed = json.loads(extract_json_payload(content))
        except Exception as e:
            # json.JSONDecodeError is an Exception subclass, so this single
            # clause covers both model/transport errors and malformed JSON.
            logger.error("提取资料信息失败: %s", e)
            return {}
        if not isinstance(parsed, dict):
            logger.error(
                "提取资料信息失败: 期望 JSON 对象, 实际为 %s", type(parsed).__name__
            )
            return {}

        result: Dict[str, Any] = {}
        birth_year = self._normalize_birth_year(parsed.get("birth_year"))
        if birth_year is not None:
            result["birth_year"] = birth_year
        for field in self._TEXT_FIELDS:
            value = self._clean_text_field(parsed.get(field))
            if value is not None:
                result[field] = value
        return result

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to three follow-up messages asking for missing profile data.

        Nothing is persisted; the Orchestrator is responsible for that. The
        model's reply is split on the literal token ``[SPLIT]`` into separate
        chat messages. A canned reply is returned when no LLM is configured,
        when generation fails, or when the model's reply is blank.
        """
        if not self.llm:
            return ["谢谢!还能告诉我更多吗?"]
        fallback = ["谢谢分享!能再告诉我一些吗?"]
        try:
            prompt = get_profile_followup_prompt(
                missing_fields, filled_fields, user_message, nickname
            )
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            full_prompt = (
                f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            )
            response = await self.llm.ainvoke(full_prompt)
            messages = self._split_reply(self._response_text(response), 3)
        except Exception as e:
            logger.error("生成资料跟进回复失败: %s", e)
            return fallback
        # A blank or separator-only reply must not become an empty chat bubble.
        return messages or fallback

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to two opening messages that start profile collection.

        Nothing is persisted; the Orchestrator is responsible for that. The
        model's reply is split on ``[SPLIT]``. A canned greeting is returned
        when no LLM is configured, when generation fails, or when the reply is
        blank.
        """
        if not self.llm:
            return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
        fallback = [
            "你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"
        ]
        try:
            prompt = get_profile_greeting_prompt(missing_fields, nickname)
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
            response = await self.llm.ainvoke(full_prompt)
            messages = self._split_reply(self._response_text(response), 2)
        except Exception as e:
            logger.error("生成资料收集开场白失败: %s", e)
            return fallback
        return messages or fallback

    async def _format_recent_dialogue(self, conversation_id: str) -> str:
        """Render the last few history turns as speaker-prefixed lines.

        Only HumanMessage and AIMessage entries are included. History is
        optional context for extraction, so a failure to load it is logged and
        treated as "no history" instead of aborting the extraction.
        """
        try:
            history_messages = await get_history_messages(conversation_id)
        except Exception as e:
            logger.warning("获取对话历史失败, 将不带上下文提取资料: %s", e)
            return ""
        lines = []
        # A negative slice already copes with histories shorter than the window.
        for msg in (history_messages or [])[-self._EXTRACTION_HISTORY_WINDOW:]:
            if isinstance(msg, HumanMessage):
                lines.append(f"用户: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"助手: {msg.content}")
        return "\n".join(lines)

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the text of an LLM response.

        Uses ``response.content`` when the attribute exists, otherwise
        ``str(response)``.

        Raises:
            TypeError: if ``content`` exists but is not a string (e.g. a list
                of content blocks), so callers fall back to a canned reply
                instead of emitting a repr.
        """
        if not hasattr(response, "content"):
            return str(response)
        content = response.content
        if not isinstance(content, str):
            raise TypeError(
                f"unexpected LLM response content type: {type(content).__name__}"
            )
        return content

    @staticmethod
    def _split_reply(text: str, limit: int) -> List[str]:
        """Split ``text`` on ``[SPLIT]``, drop blank parts, keep at most ``limit``."""
        parts = [part.strip() for part in text.split("[SPLIT]")]
        return [part for part in parts if part][:limit]

    @staticmethod
    def _normalize_birth_year(raw: Any) -> Optional[int]:
        """Coerce an LLM-supplied birth year into a year within 1900-2100.

        Accepts ints and decimal-digit strings (surrounding whitespace
        ignored). Two-digit values are expanded: 50-99 become 19xx and 00-49
        become 20xx. Returns ``None`` for anything else or for years outside
        the range.
        """
        if isinstance(raw, bool):
            # bool is an int subclass; True/False are never a real year.
            return None
        if isinstance(raw, int):
            year = raw
        elif isinstance(raw, str) and raw.strip().isdecimal():
            # isdecimal rather than isdigit: isdigit accepts characters such as
            # "²" that int() rejects, which would raise ValueError here.
            year = int(raw.strip())
        else:
            return None
        if 0 <= year < 100:
            year = 1900 + year if year >= 50 else 2000 + year
        return year if 1900 <= year <= 2100 else None

    @staticmethod
    def _clean_text_field(raw: Any) -> Optional[str]:
        """Return ``raw`` as a trimmed string, or ``None`` if it is falsy or blank."""
        if not raw:
            return None
        text = str(raw).strip()
        return text or None