fix: 修复打个招呼的bug

This commit is contained in:
yangshilin
2026-03-13 14:50:04 +08:00
parent 2eb066dbec
commit ab4fb46c0d
4 changed files with 55 additions and 28 deletions

View File

@@ -133,8 +133,9 @@ class ConversationAgent:
await self._save_message(conversation_id, "ai", response_text)
# 最多 2 段,防止 LLM 自问自答
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
return messages[:2] if messages else [response_text]
except Exception as e:
logger.error(f"生成资料收集开场白失败: {e}")
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
@@ -169,8 +170,9 @@ class ConversationAgent:
await self._save_message(conversation_id, "ai", response_text)
# 最多 2 段:问候 + 问题,防止 LLM 自问自答
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
return messages[:2] if messages else [response_text]
except Exception as e:
logger.error(f"生成开场白失败: {e}", exc_info=True)
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]

View File

@@ -204,7 +204,8 @@ def get_opening_prompt(
## 回复格式
- 可以分成 2 条消息,用 [SPLIT] 分隔:第一条问候,第二条问题;或合并成一条「问候 + 问题」。
- 禁止输出括号、注释、思考过程。
- **严禁**输出括号、注释、思考过程。
- **严禁**模拟或虚构用户的回答。你只能输出「你的问候 + 你的问题」,不能替用户回答,不能自问自答。最多 2 段(问候 + 问题),禁止更多。
示例(仅供参考风格):
"你好呀~ 有空的话想听听你的人生故事。你小时候是在哪儿长大的?那边有什么特别让你怀念的?"

View File

@@ -478,31 +478,32 @@ async def websocket_endpoint(
return
# 首次连接时检查资料完整性,发送资料收集开场白
missing_profile = _get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await manager.conversation_agent.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
import asyncio as _asyncio_greet
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(greetings) - 1:
await _asyncio_greet.sleep(0.5)
except Exception as e:
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
else:
# 资料已完整若为空对话用户通过「打个招呼」进入AI 先开口提问
history = await redis_service.get_conversation_history(conversation_id)
if not history:
# 首次连接时检查:若 Redis 已有历史(用户曾进入过此对话),不再发送开场白,避免重复/自问自答
history = await redis_service.get_conversation_history(conversation_id)
if not history:
# 空对话:发送开场白(资料收集或正式访谈)
missing_profile = _get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await manager.conversation_agent.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
import asyncio as _asyncio_greet
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(greetings) - 1:
await _asyncio_greet.sleep(0.5)
except Exception as e:
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
else:
# 资料已完整AI 先开口提问
try:
state = await get_or_create_state(user_id, db)
user_profile_context = format_user_profile_context(

View File

@@ -99,6 +99,7 @@ class _FakeManager:
)
self.conversation_agent = SimpleNamespace(
generate_profile_greeting=AsyncMock(return_value=[]),
generate_opening_message=AsyncMock(return_value=[]),
)
async def connect(self, websocket, conversation_id):
@@ -147,6 +148,15 @@ def _db_provider(db):
return _provider
def _redis_empty_history_patch():
"""Patch redis to return empty history so websocket sends opening (or skips if mocked)."""
return patch.object(
ws_router.redis_service,
"get_conversation_history",
new=AsyncMock(return_value=[]),
)
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
async def test_invalid_token_closes_connection(self):
websocket = _FakeWebSocket(messages=[], token="invalid")
@@ -183,6 +193,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -235,6 +246,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -296,6 +308,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock)
)
@@ -339,6 +352,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch.object(
ws_router,
@@ -383,6 +397,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
)
@@ -428,6 +443,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -503,6 +519,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -579,6 +596,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -652,6 +670,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -714,6 +733,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -775,6 +795,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -848,6 +869,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
@@ -927,6 +949,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)