fix: 修复打个招呼的bug

This commit is contained in:
yangshilin
2026-03-13 14:50:04 +08:00
parent 2eb066dbec
commit ab4fb46c0d
4 changed files with 55 additions and 28 deletions

View File

@@ -133,8 +133,9 @@ class ConversationAgent:
await self._save_message(conversation_id, "ai", response_text) await self._save_message(conversation_id, "ai", response_text)
# 最多 2 段,防止 LLM 自问自答
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text] return messages[:2] if messages else [response_text]
except Exception as e: except Exception as e:
logger.error(f"生成资料收集开场白失败: {e}") logger.error(f"生成资料收集开场白失败: {e}")
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
@@ -169,8 +170,9 @@ class ConversationAgent:
await self._save_message(conversation_id, "ai", response_text) await self._save_message(conversation_id, "ai", response_text)
# 最多 2 段:问候 + 问题,防止 LLM 自问自答
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text] return messages[:2] if messages else [response_text]
except Exception as e: except Exception as e:
logger.error(f"生成开场白失败: {e}", exc_info=True) logger.error(f"生成开场白失败: {e}", exc_info=True)
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"] return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]

View File

@@ -204,7 +204,8 @@ def get_opening_prompt(
## 回复格式 ## 回复格式
- 可以分成 2 条消息,用 [SPLIT] 分隔:第一条问候,第二条问题;或合并成一条「问候 + 问题」。 - 可以分成 2 条消息,用 [SPLIT] 分隔:第一条问候,第二条问题;或合并成一条「问候 + 问题」。
- 禁止输出括号、注释、思考过程。 - **严禁**输出括号、注释、思考过程。
- **严禁**模拟或虚构用户的回答。你只能输出「你的问候 + 你的问题」,不能替用户回答,不能自问自答。最多 2 段(问候 + 问题),禁止更多。
示例(仅供参考风格): 示例(仅供参考风格):
"你好呀~ 有空的话想听听你的人生故事。你小时候是在哪儿长大的?那边有什么特别让你怀念的?" "你好呀~ 有空的话想听听你的人生故事。你小时候是在哪儿长大的?那边有什么特别让你怀念的?"

View File

@@ -478,31 +478,32 @@ async def websocket_endpoint(
return return
# 首次连接时检查资料完整性,发送资料收集开场白 # 首次连接时检查:若 Redis 已有历史(用户曾进入过此对话),不再发送开场白,避免重复/自问自答
missing_profile = _get_missing_profile_fields(user) history = await redis_service.get_conversation_history(conversation_id)
if missing_profile: if not history:
try: # 空对话:发送开场白(资料收集或正式访谈)
greetings = await manager.conversation_agent.generate_profile_greeting( missing_profile = _get_missing_profile_fields(user)
conversation_id=conversation_id, if missing_profile:
missing_fields=missing_profile, try:
nickname=user.nickname or "", greetings = await manager.conversation_agent.generate_profile_greeting(
) conversation_id=conversation_id,
import asyncio as _asyncio_greet missing_fields=missing_profile,
for i, text in enumerate(greetings): nickname=user.nickname or "",
await manager.send_message(conversation_id, { )
"type": MessageType.AGENT_RESPONSE, import asyncio as _asyncio_greet
"conversation_id": conversation_id, for i, text in enumerate(greetings):
"data": {"text": text, "index": i, "total": len(greetings)}, await manager.send_message(conversation_id, {
"timestamp": datetime.now(timezone.utc).isoformat() "type": MessageType.AGENT_RESPONSE,
}) "conversation_id": conversation_id,
if i < len(greetings) - 1: "data": {"text": text, "index": i, "total": len(greetings)},
await _asyncio_greet.sleep(0.5) "timestamp": datetime.now(timezone.utc).isoformat()
except Exception as e: })
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) if i < len(greetings) - 1:
else: await _asyncio_greet.sleep(0.5)
# 资料已完整若为空对话用户通过「打个招呼」进入AI 先开口提问 except Exception as e:
history = await redis_service.get_conversation_history(conversation_id) logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
if not history: else:
# 资料已完整AI 先开口提问
try: try:
state = await get_or_create_state(user_id, db) state = await get_or_create_state(user_id, db)
user_profile_context = format_user_profile_context( user_profile_context = format_user_profile_context(

View File

@@ -99,6 +99,7 @@ class _FakeManager:
) )
self.conversation_agent = SimpleNamespace( self.conversation_agent = SimpleNamespace(
generate_profile_greeting=AsyncMock(return_value=[]), generate_profile_greeting=AsyncMock(return_value=[]),
generate_opening_message=AsyncMock(return_value=[]),
) )
async def connect(self, websocket, conversation_id): async def connect(self, websocket, conversation_id):
@@ -147,6 +148,15 @@ def _db_provider(db):
return _provider return _provider
def _redis_empty_history_patch():
"""Patch redis to return empty history so websocket sends opening (or skips if mocked)."""
return patch.object(
ws_router.redis_service,
"get_conversation_history",
new=AsyncMock(return_value=[]),
)
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
async def test_invalid_token_closes_connection(self): async def test_invalid_token_closes_connection(self):
websocket = _FakeWebSocket(messages=[], token="invalid") websocket = _FakeWebSocket(messages=[], token="invalid")
@@ -183,6 +193,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -235,6 +246,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -296,6 +308,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock) patch.object(ws_router, "process_user_message", process_user_message_mock)
) )
@@ -339,6 +352,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch.object( patch.object(
ws_router, ws_router,
@@ -383,6 +397,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock) patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
) )
@@ -428,6 +443,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -503,6 +519,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -579,6 +596,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -652,6 +670,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -714,6 +733,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -775,6 +795,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -848,6 +869,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )
@@ -927,6 +949,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db)) patch.object(ws_router, "get_async_db", _db_provider(fake_db))
) )
stack.enter_context(patch.object(ws_router, "manager", fake_manager)) stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context( stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
) )