diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 4d6e718..f6c9afa 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -32,6 +32,7 @@ LEGACY_VOICE_SESSION_ID = "legacy" class MessageType(str, Enum): """WebSocket 消息类型""" CONNECT = "connect" + RECORDING_STARTED = "recording_started" # 客户端开始录音,用于服务端 5s 后发「我在认真听」 AUDIO_CHUNK = "audio_chunk" AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传) AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音) @@ -146,6 +147,9 @@ class SegmentStreamState: buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict) consumed_index: int = -1 active_tasks: Set[asyncio.Task] = field(default_factory=set) + # 录音开始约 5s 后只发一次「我在认真听」;若用户提前结束录音则取消待发 + listening_feedback_sent: bool = False + listening_feedback_task: Optional[asyncio.Task] = None def _utc_now() -> datetime: @@ -257,17 +261,21 @@ async def _get_persisted_contiguous_segment_index( return contiguous_index +LISTENING_FEEDBACK_DELAY_SEC = 5.0 +LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。" + + async def _send_segment_transition_feedback( conversation_id: str, segment_index: int, manager: ConnectionManager, ) -> None: - """ASR 处理中先给陪伴式过渡反馈,避免用户感知卡住。""" + """发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。""" await manager.send_message(conversation_id, { "type": MessageType.AGENT_RESPONSE, "conversation_id": conversation_id, "data": { - "text": "我在认真听,你继续说,我会边听边整理重点。", + "text": LISTENING_FEEDBACK_TEXT, "transition": True, "segment_index": segment_index, }, @@ -275,6 +283,22 @@ async def _send_segment_transition_feedback( }) +async def _delayed_listening_feedback( + conversation_id: str, + voice_session_id: str, + manager: ConnectionManager, +) -> None: + """录音开始后延迟 5 秒发送一次「我在认真听」,本会话内只发一次;若用户已结束录音则不再发送。""" + await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC) + state = manager.get_or_create_segment_state(conversation_id, voice_session_id) + async with state.lock: + if state.listening_feedback_sent: + return + state.listening_feedback_sent = True + state.listening_feedback_task = None + await _send_segment_transition_feedback(conversation_id, 0, manager) + + async def _process_audio_segment_async( conversation_id: str, user_id: str, @@ -600,6 +624,28 @@ async def websocket_endpoint( user_message_timestamp=segment.created_at or user_message_timestamp, ) + elif msg_type == MessageType.RECORDING_STARTED: + # 用户点击开始录音:启动 5s 定时器,到时发一次「我在认真听」 + data = message.get("data", {}) + voice_session_id = _normalize_voice_session_id(data.get("voice_session_id")) + segment_state = manager.get_or_create_segment_state( + conversation_id, + voice_session_id, + ) + async with segment_state.lock: + if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done(): + continue # 本会话已有待发任务,不重复 + if segment_state.listening_feedback_sent: + continue + delayed_task = asyncio.create_task( + _delayed_listening_feedback( + conversation_id=conversation_id, + voice_session_id=voice_session_id, + manager=manager, + ) + ) + segment_state.listening_feedback_task = delayed_task + elif msg_type == MessageType.AUDIO_SEGMENT: # 处理分段语音消息(长语音持续上传) data = message.get("data", {}) @@ -680,12 +726,13 @@ async def websocket_endpoint( ) continue - # 先发过渡反馈,减少“等待空白”体感 - await _send_segment_transition_feedback( - conversation_id=conversation_id, - segment_index=segment_index, - manager=manager, - ) + # 若本段是用户结束录音的最后一段,取消尚未发出的「我在认真听」,避免结束后再说 + if is_last: + async with segment_state.lock: + t = segment_state.listening_feedback_task + segment_state.listening_feedback_task = None + if t is not None and not t.done(): + t.cancel() task = asyncio.create_task( _process_audio_segment_async( diff --git a/api/tests/test_websocket_baseline.py b/api/tests/test_websocket_baseline.py index 011c217..cc0839b 100644 --- a/api/tests/test_websocket_baseline.py +++ b/api/tests/test_websocket_baseline.py @@ -752,13 +752,56 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): await ws_router.websocket_endpoint(fake_websocket, "conv-1") await asyncio.sleep(0.05) + # 当前逻辑:仅首个分段且 5s 后发一次「我在认真听」;若本段 is_last 则取消,故此处应为 0 transition_msgs = [ item["message"] for item in fake_manager.sent_messages if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE and item["message"].get("data", {}).get("transition") is True ] - self.assertGreaterEqual(len(transition_msgs), 1) + self.assertEqual(len(transition_msgs), 0) + + async def test_recording_started_sends_listening_feedback_after_delay(self): + """客户端发送 recording_started 后,延迟 5s 发一次「我在认真听」。""" + user = _make_user() + conversation = Conversation(id="conv-1", user_id=user.id, status="active") + fake_db = _FakeAsyncDB(user=user, conversation=conversation) + fake_manager = _FakeManager() + fake_websocket = _FakeWebSocket( + messages=[ + {"type": "recording_started", "data": {"voice_session_id": "session-1"}}, + WebSocketDisconnect(), + ] + ) + + with ExitStack() as stack: + stack.enter_context( + patch.object( + ws_router, + "verify_token", + return_value={"type": "access", "sub": user.id}, + ) + ) + stack.enter_context( + patch.object(ws_router, "get_async_db", _db_provider(fake_db)) + ) + stack.enter_context(patch.object(ws_router, "manager", fake_manager)) + stack.enter_context(_redis_empty_history_patch()) + stack.enter_context( + patch("routers.websocket.LISTENING_FEEDBACK_DELAY_SEC", 0.05) + ) + + await ws_router.websocket_endpoint(fake_websocket, "conv-1") + await asyncio.sleep(0.12) + + transition_msgs = [ + item["message"] + for item in fake_manager.sent_messages + if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE + and item["message"].get("data", {}).get("transition") is True + ] + self.assertEqual(len(transition_msgs), 1) + self.assertIn("我在认真听", transition_msgs[0]["data"].get("text", "")) async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self): user = _make_user() @@ -814,14 +857,14 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): await ws_router.websocket_endpoint(fake_websocket, "conv-1") await asyncio.sleep(0.05) + # 仅一段且 is_last:延迟任务被取消,不应发出 transition transition_msgs = [ item["message"] for item in fake_manager.sent_messages if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE and item["message"].get("data", {}).get("transition") is True ] - self.assertEqual(len(transition_msgs), 1) - self.assertIsNone(transition_msgs[0]["data"].get("is_last")) + self.assertEqual(len(transition_msgs), 0) async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self): user = _make_user() diff --git a/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/adapters/ConversationRealtimeAdapter.kt b/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/adapters/ConversationRealtimeAdapter.kt index c7de86f..e4bf470 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/adapters/ConversationRealtimeAdapter.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/adapters/ConversationRealtimeAdapter.kt @@ -62,6 +62,10 @@ class ConversationRealtimeAdapter( webSocketClient.sendTextMessage(text, conversationId) } + override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) { + webSocketClient.sendRecordingStarted(conversationId, voiceSessionId) + } + override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) { webSocketClient.sendAudioChunk(chunk, conversationId) } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/ports/ConversationRealtimePort.kt b/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/ports/ConversationRealtimePort.kt index 93f299d..a309833 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/ports/ConversationRealtimePort.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/feature/conversation/ports/ConversationRealtimePort.kt @@ -24,6 +24,7 @@ interface ConversationRealtimePort { fun isConnected(): Boolean suspend fun sendText(conversationId: String, text: String) + suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) suspend fun sendAudioSegment(request: AudioSegmentRequest) suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) diff --git a/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketClient.kt b/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketClient.kt index 9cce1e7..725aa13 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketClient.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketClient.kt @@ -242,6 +242,17 @@ open class WebSocketClient( data = buildJsonObject { put("text", JsonPrimitive(text)) } )) } + + suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) { + if (!_isConnected) return + sendMessage(WebSocketMessage( + type = MessageType.recording_started, + conversation_id = conversationId, + data = buildJsonObject { + put("voice_session_id", JsonPrimitive(voiceSessionId)) + } + )) + } suspend fun sendAudioSegment( audioBytes: ByteArray, diff --git a/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketMessage.kt b/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketMessage.kt index 58396b1..c8633b3 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketMessage.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/network/WebSocketMessage.kt @@ -8,6 +8,7 @@ import kotlinx.serialization.json.buildJsonObject @Serializable enum class MessageType { connect, + recording_started, // 用户点击开始录音,服务端据此 5s 后发「我在认真听」 audio_chunk, audio_segment, // 分段语音(长语音边录边传) audio_message, // 完整音频消息(类似微信语音) diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt index c324694..6e3ebbb 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt @@ -357,6 +357,13 @@ class CreateMemoryViewModel( } } Log.d(TAG, "录音已启动: ${result.session.filePath}") + // 通知服务端「开始录音」,服务端 5s 后发「我在认真听」 + viewModelScope.launch { + val cid = conversationId.value + if (cid != null && cid != "new" && ensureConnected()) { + conversationRealtime.sendRecordingStarted(cid, result.session.voiceSessionId) + } + } } is RecordingStartResult.AlreadyRecording -> { Log.w(TAG, "录音已在进行中") diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt index 35d4af8..f6d63f2 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt @@ -661,6 +661,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest { override suspend fun disconnect() = Unit override fun isConnected(): Boolean = connected override suspend fun sendText(conversationId: String, text: String) = Unit + override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioSegment(request: AudioSegmentRequest) { sendAudioSegmentFailure?.let { throw it } @@ -705,6 +706,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest { override suspend fun sendText(conversationId: String, text: String) { sentTexts += text } + override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt index 6f024e5..9f6dd7c 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt @@ -290,6 +290,7 @@ class CreateMemoryViewModelWarmupTest { override fun isConnected(): Boolean = connected override suspend fun sendText(conversationId: String, text: String) = Unit + override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit