fix: 修复打个招呼的bug
This commit is contained in:
@@ -133,8 +133,9 @@ class ConversationAgent:
|
||||
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 最多 2 段,防止 LLM 自问自答
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
return messages[:2] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成资料收集开场白失败: {e}")
|
||||
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||||
@@ -169,8 +170,9 @@ class ConversationAgent:
|
||||
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 最多 2 段:问候 + 问题,防止 LLM 自问自答
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
return messages[:2] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成开场白失败: {e}", exc_info=True)
|
||||
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
|
||||
|
||||
@@ -204,7 +204,8 @@ def get_opening_prompt(
|
||||
|
||||
## 回复格式
|
||||
- 可以分成 2 条消息,用 [SPLIT] 分隔:第一条问候,第二条问题;或合并成一条「问候 + 问题」。
|
||||
- 禁止输出括号、注释、思考过程。
|
||||
- **严禁**输出括号、注释、思考过程。
|
||||
- **严禁**模拟或虚构用户的回答。你只能输出「你的问候 + 你的问题」,不能替用户回答,不能自问自答。最多 2 段(问候 + 问题),禁止更多。
|
||||
|
||||
示例(仅供参考风格):
|
||||
"你好呀~ 有空的话想听听你的人生故事。你小时候是在哪儿长大的?那边有什么特别让你怀念的?"
|
||||
|
||||
@@ -478,31 +478,32 @@ async def websocket_endpoint(
|
||||
return
|
||||
|
||||
|
||||
# 首次连接时检查资料完整性,发送资料收集开场白
|
||||
missing_profile = _get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await manager.conversation_agent.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
import asyncio as _asyncio_greet
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(greetings) - 1:
|
||||
await _asyncio_greet.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
|
||||
else:
|
||||
# 资料已完整:若为空对话(用户通过「打个招呼」进入),AI 先开口提问
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
if not history:
|
||||
# 首次连接时检查:若 Redis 已有历史(用户曾进入过此对话),不再发送开场白,避免重复/自问自答
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
if not history:
|
||||
# 空对话:发送开场白(资料收集或正式访谈)
|
||||
missing_profile = _get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await manager.conversation_agent.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
import asyncio as _asyncio_greet
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(greetings) - 1:
|
||||
await _asyncio_greet.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
|
||||
else:
|
||||
# 资料已完整:AI 先开口提问
|
||||
try:
|
||||
state = await get_or_create_state(user_id, db)
|
||||
user_profile_context = format_user_profile_context(
|
||||
|
||||
@@ -99,6 +99,7 @@ class _FakeManager:
|
||||
)
|
||||
self.conversation_agent = SimpleNamespace(
|
||||
generate_profile_greeting=AsyncMock(return_value=[]),
|
||||
generate_opening_message=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
async def connect(self, websocket, conversation_id):
|
||||
@@ -147,6 +148,15 @@ def _db_provider(db):
|
||||
return _provider
|
||||
|
||||
|
||||
def _redis_empty_history_patch():
|
||||
"""Patch redis to return empty history so websocket sends opening (or skips if mocked)."""
|
||||
return patch.object(
|
||||
ws_router.redis_service,
|
||||
"get_conversation_history",
|
||||
new=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
|
||||
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_invalid_token_closes_connection(self):
|
||||
websocket = _FakeWebSocket(messages=[], token="invalid")
|
||||
@@ -183,6 +193,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -235,6 +246,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -296,6 +308,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
@@ -339,6 +352,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
@@ -383,6 +397,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
@@ -428,6 +443,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -503,6 +519,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -579,6 +596,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -652,6 +670,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -714,6 +733,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -775,6 +795,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -848,6 +869,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
@@ -927,6 +949,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user