Merge branch 'feat/improve-agent-prompt'
This commit is contained in:
@@ -28,6 +28,18 @@ class UserProfileResponse(BaseModel):
|
||||
avatar_url: Optional[str]
|
||||
subscription_type: str
|
||||
created_at: str
|
||||
birth_year: Optional[int] = None
|
||||
birth_place: Optional[str] = None
|
||||
grew_up_place: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateUserProfileRequest(BaseModel):
|
||||
"""更新用户基础资料请求"""
|
||||
birth_year: Optional[int] = None
|
||||
birth_place: Optional[str] = None
|
||||
grew_up_place: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
|
||||
|
||||
class TestSubscriptionRequest(BaseModel):
|
||||
@@ -59,7 +71,43 @@ async def get_user_profile(
|
||||
nickname=current_user.nickname,
|
||||
avatar_url=current_user.avatar_url,
|
||||
subscription_type=current_user.subscription_type,
|
||||
created_at=current_user.created_at.isoformat()
|
||||
created_at=current_user.created_at.isoformat(),
|
||||
birth_year=current_user.birth_year,
|
||||
birth_place=current_user.birth_place,
|
||||
grew_up_place=current_user.grew_up_place,
|
||||
occupation=current_user.occupation,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserProfileResponse)
|
||||
async def update_user_profile(
|
||||
body: UpdateUserProfileRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""更新用户基础资料(出生年份、出生地、成长地、职业)"""
|
||||
if body.birth_year is not None:
|
||||
current_user.birth_year = body.birth_year
|
||||
if body.birth_place is not None:
|
||||
current_user.birth_place = body.birth_place
|
||||
if body.grew_up_place is not None:
|
||||
current_user.grew_up_place = body.grew_up_place
|
||||
if body.occupation is not None:
|
||||
current_user.occupation = body.occupation
|
||||
await db.commit()
|
||||
await db.refresh(current_user)
|
||||
return UserProfileResponse(
|
||||
id=current_user.id,
|
||||
phone=current_user.phone,
|
||||
email=current_user.email,
|
||||
nickname=current_user.nickname,
|
||||
avatar_url=current_user.avatar_url,
|
||||
subscription_type=current_user.subscription_type,
|
||||
created_at=current_user.created_at.isoformat(),
|
||||
birth_year=current_user.birth_year,
|
||||
birth_place=current_user.birth_place,
|
||||
grew_up_place=current_user.grew_up_place,
|
||||
occupation=current_user.occupation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -161,6 +161,28 @@ async def websocket_endpoint(
|
||||
return
|
||||
|
||||
|
||||
# 首次连接时检查资料完整性,发送资料收集开场白
|
||||
missing_profile = _get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await manager.conversation_agent.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
import asyncio as _asyncio_greet
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(greetings) - 1:
|
||||
await _asyncio_greet.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
|
||||
|
||||
# 主循环:处理消息
|
||||
while True:
|
||||
try:
|
||||
@@ -206,7 +228,8 @@ async def websocket_endpoint(
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager
|
||||
manager=manager,
|
||||
user=user,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||
@@ -267,7 +290,8 @@ async def websocket_endpoint(
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager
|
||||
manager=manager,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# 转写失败,发送错误消息
|
||||
@@ -380,58 +404,140 @@ async def websocket_endpoint(
|
||||
await manager.disconnect(conversation_id)
|
||||
|
||||
|
||||
def _get_missing_profile_fields(user: UserModel) -> list:
|
||||
"""检查用户缺失的资料字段"""
|
||||
from agents.prompts.profile_prompts import get_missing_profile_fields
|
||||
return get_missing_profile_fields(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
|
||||
|
||||
def _get_filled_profile_fields(user: UserModel) -> dict:
|
||||
"""获取用户已有的资料字段(中文展示)"""
|
||||
from agents.prompts.profile_prompts import PROFILE_FIELD_NAMES
|
||||
filled = {}
|
||||
if user.birth_year:
|
||||
filled["birth_year"] = str(user.birth_year)
|
||||
if user.birth_place:
|
||||
filled["birth_place"] = user.birth_place
|
||||
if user.grew_up_place:
|
||||
filled["grew_up_place"] = user.grew_up_place
|
||||
if user.occupation:
|
||||
filled["occupation"] = user.occupation
|
||||
return filled
|
||||
|
||||
|
||||
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
|
||||
"""将提取到的资料信息保存到用户模型"""
|
||||
changed = False
|
||||
if "birth_year" in extracted and not user.birth_year:
|
||||
user.birth_year = extracted["birth_year"]
|
||||
changed = True
|
||||
if "birth_place" in extracted and not user.birth_place:
|
||||
user.birth_place = extracted["birth_place"]
|
||||
changed = True
|
||||
if "grew_up_place" in extracted and not user.grew_up_place:
|
||||
user.grew_up_place = extracted["grew_up_place"]
|
||||
changed = True
|
||||
if "occupation" in extracted and not user.occupation:
|
||||
user.occupation = extracted["occupation"]
|
||||
changed = True
|
||||
if changed:
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
conversation: Conversation,
|
||||
segment: Segment,
|
||||
db: AsyncSession,
|
||||
manager: ConnectionManager
|
||||
manager: ConnectionManager,
|
||||
user: UserModel = None,
|
||||
) -> None:
|
||||
"""
|
||||
处理用户消息,生成Agent回应(异步版本)
|
||||
|
||||
Args:
|
||||
conversation_id: 对话ID
|
||||
user_message: 用户消息文本
|
||||
conversation: 对话对象
|
||||
segment: 段落对象
|
||||
db: 数据库会话
|
||||
manager: 连接管理器
|
||||
|
||||
Returns:
|
||||
更新后的对话阶段
|
||||
支持资料收集模式和正式访谈模式
|
||||
"""
|
||||
import asyncio as _asyncio
|
||||
|
||||
|
||||
agent = manager.conversation_agent
|
||||
|
||||
# --- 资料收集模式 ---
|
||||
if user:
|
||||
missing = _get_missing_profile_fields(user)
|
||||
if missing:
|
||||
try:
|
||||
extracted = await agent.extract_profile_from_message(user_message, missing)
|
||||
if extracted:
|
||||
await _apply_extracted_profile(user, extracted, db)
|
||||
|
||||
remaining = _get_missing_profile_fields(user)
|
||||
filled = _get_filled_profile_fields(user)
|
||||
responses = await agent.generate_profile_followup(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
missing_fields=remaining,
|
||||
filled_fields=filled,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
await db.commit()
|
||||
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"资料收集处理失败: {e}", exc_info=True)
|
||||
|
||||
# --- 正式访谈模式 ---
|
||||
state = await get_or_create_state(conversation.user_id, db)
|
||||
|
||||
if conversation.conversation_stage != state.current_stage:
|
||||
conversation.conversation_stage = state.current_stage
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题(保留老逻辑用于提示)
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
|
||||
# 构建用户资料上下文
|
||||
user_profile_context = ""
|
||||
if user:
|
||||
from agents.prompts.profile_prompts import format_user_profile_context
|
||||
user_profile_context = format_user_profile_context(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
|
||||
try:
|
||||
# 异步生成回应(可能是多条消息)
|
||||
responses = await agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=state
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
|
||||
# 更新段落的 Agent 回应(存储完整内容)
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应(支持多条消息)
|
||||
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
@@ -439,13 +545,11 @@ async def process_user_message(
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
# 多条消息之间稍作间隔,模拟打字效果
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
@@ -455,8 +559,6 @@ async def process_user_message(
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
|
||||
Reference in New Issue
Block a user