diff --git a/api/agents/conversation_agent.py b/api/agents/conversation_agent.py index 6ee361b..909df18 100644 --- a/api/agents/conversation_agent.py +++ b/api/agents/conversation_agent.py @@ -133,8 +133,9 @@ class ConversationAgent: await self._save_message(conversation_id, "ai", response_text) + # 最多 2 段,防止 LLM 自问自答 messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:3] if messages else [response_text] + return messages[:2] if messages else [response_text] except Exception as e: logger.error(f"生成资料收集开场白失败: {e}") return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] @@ -169,8 +170,9 @@ class ConversationAgent: await self._save_message(conversation_id, "ai", response_text) + # 最多 2 段:问候 + 问题,防止 LLM 自问自答 messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:3] if messages else [response_text] + return messages[:2] if messages else [response_text] except Exception as e: logger.error(f"生成开场白失败: {e}", exc_info=True) return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"] diff --git a/api/agents/prompts/conversation_prompts.py b/api/agents/prompts/conversation_prompts.py index b79917d..6e1d41b 100644 --- a/api/agents/prompts/conversation_prompts.py +++ b/api/agents/prompts/conversation_prompts.py @@ -204,7 +204,8 @@ def get_opening_prompt( ## 回复格式 - 可以分成 2 条消息,用 [SPLIT] 分隔:第一条问候,第二条问题;或合并成一条「问候 + 问题」。 -- 禁止输出括号、注释、思考过程。 +- **严禁**输出括号、注释、思考过程。 +- **严禁**模拟或虚构用户的回答。你只能输出「你的问候 + 你的问题」,不能替用户回答,不能自问自答。最多 2 段(问候 + 问题),禁止更多。 示例(仅供参考风格): "你好呀~ 有空的话想听听你的人生故事。你小时候是在哪儿长大的?那边有什么特别让你怀念的?" diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 5a0807d..8f305a8 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -478,31 +478,32 @@ async def websocket_endpoint( return - # 首次连接时检查资料完整性,发送资料收集开场白 - missing_profile = _get_missing_profile_fields(user) - if missing_profile: - try: - greetings = await manager.conversation_agent.generate_profile_greeting( - conversation_id=conversation_id, - missing_fields=missing_profile, - nickname=user.nickname or "", - ) - import asyncio as _asyncio_greet - for i, text in enumerate(greetings): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": text, "index": i, "total": len(greetings)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - if i < len(greetings) - 1: - await _asyncio_greet.sleep(0.5) - except Exception as e: - logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) - else: - # 资料已完整:若为空对话(用户通过「打个招呼」进入),AI 先开口提问 - history = await redis_service.get_conversation_history(conversation_id) - if not history: + # 首次连接时检查:若 Redis 已有历史(用户曾进入过此对话),不再发送开场白,避免重复/自问自答 + history = await redis_service.get_conversation_history(conversation_id) + if not history: + # 空对话:发送开场白(资料收集或正式访谈) + missing_profile = _get_missing_profile_fields(user) + if missing_profile: + try: + greetings = await manager.conversation_agent.generate_profile_greeting( + conversation_id=conversation_id, + missing_fields=missing_profile, + nickname=user.nickname or "", + ) + import asyncio as _asyncio_greet + for i, text in enumerate(greetings): + await manager.send_message(conversation_id, { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": {"text": text, "index": i, "total": len(greetings)}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + if i < len(greetings) - 1: + await _asyncio_greet.sleep(0.5) + except Exception as e: + logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) + else: + # 资料已完整:AI 先开口提问 try: state = await get_or_create_state(user_id, db) user_profile_context = format_user_profile_context( diff --git a/api/tests/test_websocket_baseline.py b/api/tests/test_websocket_baseline.py index a9505e2..a995b70 100644 --- a/api/tests/test_websocket_baseline.py +++ b/api/tests/test_websocket_baseline.py @@ -99,6 +99,7 @@ class _FakeManager: ) self.conversation_agent = SimpleNamespace( generate_profile_greeting=AsyncMock(return_value=[]), + generate_opening_message=AsyncMock(return_value=[]), ) async def connect(self, websocket, conversation_id): @@ -147,6 +148,15 @@ def _db_provider(db): return _provider +def _redis_empty_history_patch(): + """Patch redis to return empty history so websocket sends opening (or skips if mocked).""" + return patch.object( + ws_router.redis_service, + "get_conversation_history", + new=AsyncMock(return_value=[]), + ) + + class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): async def test_invalid_token_closes_connection(self): websocket = _FakeWebSocket(messages=[], token="invalid") @@ -183,6 +193,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -235,6 +246,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -296,6 +308,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch.object(ws_router, "process_user_message", process_user_message_mock) ) @@ -339,6 +352,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch.object( ws_router, @@ -383,6 +397,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch.object(ws_router.asr_service, "transcribe", transcribe_mock) ) @@ -428,6 +443,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -503,6 +519,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -579,6 +596,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -652,6 +670,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -714,6 +733,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -775,6 +795,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -848,6 +869,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) ) @@ -927,6 +949,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): patch.object(ws_router, "get_async_db", _db_provider(fake_db)) ) stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)) )