Merge branch 'feat/improve-agent-prompt'

This commit is contained in:
penghanyuan
2026-03-01 10:12:23 +01:00
parent a69d5c625c
commit c1e2fb31a0
11 changed files with 644 additions and 65 deletions

View File

@@ -28,6 +28,18 @@ class UserProfileResponse(BaseModel):
avatar_url: Optional[str]
subscription_type: str
created_at: str
birth_year: Optional[int] = None
birth_place: Optional[str] = None
grew_up_place: Optional[str] = None
occupation: Optional[str] = None
class UpdateUserProfileRequest(BaseModel):
"""更新用户基础资料请求"""
birth_year: Optional[int] = None
birth_place: Optional[str] = None
grew_up_place: Optional[str] = None
occupation: Optional[str] = None
class TestSubscriptionRequest(BaseModel):
@@ -59,7 +71,43 @@ async def get_user_profile(
nickname=current_user.nickname,
avatar_url=current_user.avatar_url,
subscription_type=current_user.subscription_type,
created_at=current_user.created_at.isoformat()
created_at=current_user.created_at.isoformat(),
birth_year=current_user.birth_year,
birth_place=current_user.birth_place,
grew_up_place=current_user.grew_up_place,
occupation=current_user.occupation,
)
@router.put("/profile", response_model=UserProfileResponse)
async def update_user_profile(
body: UpdateUserProfileRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db),
):
"""更新用户基础资料(出生年份、出生地、成长地、职业)"""
if body.birth_year is not None:
current_user.birth_year = body.birth_year
if body.birth_place is not None:
current_user.birth_place = body.birth_place
if body.grew_up_place is not None:
current_user.grew_up_place = body.grew_up_place
if body.occupation is not None:
current_user.occupation = body.occupation
await db.commit()
await db.refresh(current_user)
return UserProfileResponse(
id=current_user.id,
phone=current_user.phone,
email=current_user.email,
nickname=current_user.nickname,
avatar_url=current_user.avatar_url,
subscription_type=current_user.subscription_type,
created_at=current_user.created_at.isoformat(),
birth_year=current_user.birth_year,
birth_place=current_user.birth_place,
grew_up_place=current_user.grew_up_place,
occupation=current_user.occupation,
)

View File

@@ -161,6 +161,28 @@ async def websocket_endpoint(
return
# 首次连接时检查资料完整性,发送资料收集开场白
missing_profile = _get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await manager.conversation_agent.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
import asyncio as _asyncio_greet
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(greetings) - 1:
await _asyncio_greet.sleep(0.5)
except Exception as e:
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
# 主循环:处理消息
while True:
try:
@@ -206,7 +228,8 @@ async def websocket_endpoint(
conversation=conversation,
segment=segment,
db=db,
manager=manager
manager=manager,
user=user,
)
elif msg_type == MessageType.AUDIO_MESSAGE:
@@ -267,7 +290,8 @@ async def websocket_endpoint(
conversation=conversation,
segment=segment,
db=db,
manager=manager
manager=manager,
user=user,
)
else:
# 转写失败,发送错误消息
@@ -380,58 +404,140 @@ async def websocket_endpoint(
await manager.disconnect(conversation_id)
def _get_missing_profile_fields(user: UserModel) -> list:
"""检查用户缺失的资料字段"""
from agents.prompts.profile_prompts import get_missing_profile_fields
return get_missing_profile_fields(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
def _get_filled_profile_fields(user: UserModel) -> dict:
"""获取用户已有的资料字段(中文展示)"""
from agents.prompts.profile_prompts import PROFILE_FIELD_NAMES
filled = {}
if user.birth_year:
filled["birth_year"] = str(user.birth_year)
if user.birth_place:
filled["birth_place"] = user.birth_place
if user.grew_up_place:
filled["grew_up_place"] = user.grew_up_place
if user.occupation:
filled["occupation"] = user.occupation
return filled
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
"""将提取到的资料信息保存到用户模型"""
changed = False
if "birth_year" in extracted and not user.birth_year:
user.birth_year = extracted["birth_year"]
changed = True
if "birth_place" in extracted and not user.birth_place:
user.birth_place = extracted["birth_place"]
changed = True
if "grew_up_place" in extracted and not user.grew_up_place:
user.grew_up_place = extracted["grew_up_place"]
changed = True
if "occupation" in extracted and not user.occupation:
user.occupation = extracted["occupation"]
changed = True
if changed:
await db.commit()
await db.refresh(user)
async def process_user_message(
conversation_id: str,
user_message: str,
conversation: Conversation,
segment: Segment,
db: AsyncSession,
manager: ConnectionManager
manager: ConnectionManager,
user: UserModel = None,
) -> None:
"""
处理用户消息生成Agent回应异步版本
Args:
conversation_id: 对话ID
user_message: 用户消息文本
conversation: 对话对象
segment: 段落对象
db: 数据库会话
manager: 连接管理器
Returns:
更新后的对话阶段
支持资料收集模式和正式访谈模式
"""
import asyncio as _asyncio
agent = manager.conversation_agent
# --- 资料收集模式 ---
if user:
missing = _get_missing_profile_fields(user)
if missing:
try:
extracted = await agent.extract_profile_from_message(user_message, missing)
if extracted:
await _apply_extracted_profile(user, extracted, db)
remaining = _get_missing_profile_fields(user)
filled = _get_filled_profile_fields(user)
responses = await agent.generate_profile_followup(
conversation_id=conversation_id,
user_message=user_message,
missing_fields=remaining,
filled_fields=filled,
nickname=user.nickname or "",
)
segment.agent_response = "\n\n".join(responses)
await db.commit()
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(responses) - 1:
await _asyncio.sleep(0.5)
return
except Exception as e:
logger.error(f"资料收集处理失败: {e}", exc_info=True)
# --- 正式访谈模式 ---
state = await get_or_create_state(conversation.user_id, db)
if conversation.conversation_stage != state.current_stage:
conversation.conversation_stage = state.current_stage
await db.commit()
# 获取已聊话题(保留老逻辑用于提示)
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 构建用户资料上下文
user_profile_context = ""
if user:
from agents.prompts.profile_prompts import format_user_profile_context
user_profile_context = format_user_profile_context(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
try:
# 异步生成回应(可能是多条消息)
responses = await agent.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=state
memoir_state=state,
user_profile_context=user_profile_context,
)
# 更新段落的 Agent 回应(存储完整内容)
segment.agent_response = "\n\n".join(responses)
await db.commit()
# 发送 Agent 回应(支持多条消息)
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
@@ -439,13 +545,11 @@ async def process_user_message(
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 多条消息之间稍作间隔,模拟打字效果
if i < len(responses) - 1:
await _asyncio.sleep(0.5)
except Exception as e:
logger.error(f"处理用户消息失败: {e}", exc_info=True)
# 只在连接仍然活跃时发送错误消息
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
@@ -455,8 +559,6 @@ async def process_user_message(
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
return
async def process_conversation_segments(conversation_id: str, db: AsyncSession):