fix: 用户开始录音5s后ai反馈“我在认真听”
This commit is contained in:
@@ -32,6 +32,7 @@ LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
CONNECT = "connect"
|
||||
RECORDING_STARTED = "recording_started" # 客户端开始录音,用于服务端 5s 后发「我在认真听」
|
||||
AUDIO_CHUNK = "audio_chunk"
|
||||
AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传)
|
||||
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
|
||||
@@ -146,6 +147,9 @@ class SegmentStreamState:
|
||||
buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
|
||||
consumed_index: int = -1
|
||||
active_tasks: Set[asyncio.Task] = field(default_factory=set)
|
||||
# 录音开始约 5s 后只发一次「我在认真听」;若用户提前结束录音则取消待发
|
||||
listening_feedback_sent: bool = False
|
||||
listening_feedback_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
@@ -257,17 +261,21 @@ async def _get_persisted_contiguous_segment_index(
|
||||
return contiguous_index
|
||||
|
||||
|
||||
LISTENING_FEEDBACK_DELAY_SEC = 5.0
|
||||
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
|
||||
|
||||
|
||||
async def _send_segment_transition_feedback(
|
||||
conversation_id: str,
|
||||
segment_index: int,
|
||||
manager: ConnectionManager,
|
||||
) -> None:
|
||||
"""ASR 处理中先给陪伴式过渡反馈,避免用户感知卡住。"""
|
||||
"""发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。"""
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": "我在认真听,你继续说,我会边听边整理重点。",
|
||||
"text": LISTENING_FEEDBACK_TEXT,
|
||||
"transition": True,
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
@@ -275,6 +283,22 @@ async def _send_segment_transition_feedback(
|
||||
})
|
||||
|
||||
|
||||
async def _delayed_listening_feedback(
|
||||
conversation_id: str,
|
||||
voice_session_id: str,
|
||||
manager: ConnectionManager,
|
||||
) -> None:
|
||||
"""录音开始后延迟 5 秒发送一次「我在认真听」,本会话内只发一次;若用户已结束录音则不再发送。"""
|
||||
await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
|
||||
state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
async with state.lock:
|
||||
if state.listening_feedback_sent:
|
||||
return
|
||||
state.listening_feedback_sent = True
|
||||
state.listening_feedback_task = None
|
||||
await _send_segment_transition_feedback(conversation_id, 0, manager)
|
||||
|
||||
|
||||
async def _process_audio_segment_async(
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
@@ -600,6 +624,28 @@ async def websocket_endpoint(
|
||||
user_message_timestamp=segment.created_at or user_message_timestamp,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.RECORDING_STARTED:
|
||||
# 用户点击开始录音:启动 5s 定时器,到时发一次「我在认真听」
|
||||
data = message.get("data", {})
|
||||
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
|
||||
segment_state = manager.get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
async with segment_state.lock:
|
||||
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
|
||||
continue # 本会话已有待发任务,不重复
|
||||
if segment_state.listening_feedback_sent:
|
||||
continue
|
||||
delayed_task = asyncio.create_task(
|
||||
_delayed_listening_feedback(
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
manager=manager,
|
||||
)
|
||||
)
|
||||
segment_state.listening_feedback_task = delayed_task
|
||||
|
||||
elif msg_type == MessageType.AUDIO_SEGMENT:
|
||||
# 处理分段语音消息(长语音持续上传)
|
||||
data = message.get("data", {})
|
||||
@@ -680,12 +726,13 @@ async def websocket_endpoint(
|
||||
)
|
||||
continue
|
||||
|
||||
# 先发过渡反馈,减少“等待空白”体感
|
||||
await _send_segment_transition_feedback(
|
||||
conversation_id=conversation_id,
|
||||
segment_index=segment_index,
|
||||
manager=manager,
|
||||
)
|
||||
# 若本段是用户结束录音的最后一段,取消尚未发出的「我在认真听」,避免结束后再说
|
||||
if is_last:
|
||||
async with segment_state.lock:
|
||||
t = segment_state.listening_feedback_task
|
||||
segment_state.listening_feedback_task = None
|
||||
if t is not None and not t.done():
|
||||
t.cancel()
|
||||
|
||||
task = asyncio.create_task(
|
||||
_process_audio_segment_async(
|
||||
|
||||
@@ -752,13 +752,56 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# 当前逻辑:仅首个分段且 5s 后发一次「我在认真听」;若本段 is_last 则取消,故此处应为 0
|
||||
transition_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
|
||||
and item["message"].get("data", {}).get("transition") is True
|
||||
]
|
||||
self.assertGreaterEqual(len(transition_msgs), 1)
|
||||
self.assertEqual(len(transition_msgs), 0)
|
||||
|
||||
async def test_recording_started_sends_listening_feedback_after_delay(self):
|
||||
"""客户端发送 recording_started 后,延迟 5s 发一次「我在认真听」。"""
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{"type": "recording_started", "data": {"voice_session_id": "session-1"}},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(_redis_empty_history_patch())
|
||||
stack.enter_context(
|
||||
patch("routers.websocket.LISTENING_FEEDBACK_DELAY_SEC", 0.05)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.12)
|
||||
|
||||
transition_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
|
||||
and item["message"].get("data", {}).get("transition") is True
|
||||
]
|
||||
self.assertEqual(len(transition_msgs), 1)
|
||||
self.assertIn("我在认真听", transition_msgs[0]["data"].get("text", ""))
|
||||
|
||||
async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self):
|
||||
user = _make_user()
|
||||
@@ -814,14 +857,14 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# 仅一段且 is_last:延迟任务被取消,不应发出 transition
|
||||
transition_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
|
||||
and item["message"].get("data", {}).get("transition") is True
|
||||
]
|
||||
self.assertEqual(len(transition_msgs), 1)
|
||||
self.assertIsNone(transition_msgs[0]["data"].get("is_last"))
|
||||
self.assertEqual(len(transition_msgs), 0)
|
||||
|
||||
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
|
||||
user = _make_user()
|
||||
|
||||
@@ -62,6 +62,10 @@ class ConversationRealtimeAdapter(
|
||||
webSocketClient.sendTextMessage(text, conversationId)
|
||||
}
|
||||
|
||||
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) {
|
||||
webSocketClient.sendRecordingStarted(conversationId, voiceSessionId)
|
||||
}
|
||||
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) {
|
||||
webSocketClient.sendAudioChunk(chunk, conversationId)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ interface ConversationRealtimePort {
|
||||
fun isConnected(): Boolean
|
||||
|
||||
suspend fun sendText(conversationId: String, text: String)
|
||||
suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String)
|
||||
suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String)
|
||||
suspend fun sendAudioSegment(request: AudioSegmentRequest)
|
||||
suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int)
|
||||
|
||||
@@ -243,6 +243,17 @@ open class WebSocketClient(
|
||||
))
|
||||
}
|
||||
|
||||
suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) {
|
||||
if (!_isConnected) return
|
||||
sendMessage(WebSocketMessage(
|
||||
type = MessageType.recording_started,
|
||||
conversation_id = conversationId,
|
||||
data = buildJsonObject {
|
||||
put("voice_session_id", JsonPrimitive(voiceSessionId))
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
suspend fun sendAudioSegment(
|
||||
audioBytes: ByteArray,
|
||||
conversationId: String,
|
||||
|
||||
@@ -8,6 +8,7 @@ import kotlinx.serialization.json.buildJsonObject
|
||||
@Serializable
|
||||
enum class MessageType {
|
||||
connect,
|
||||
recording_started, // 用户点击开始录音,服务端据此 5s 后发「我在认真听」
|
||||
audio_chunk,
|
||||
audio_segment, // 分段语音(长语音边录边传)
|
||||
audio_message, // 完整音频消息(类似微信语音)
|
||||
|
||||
@@ -357,6 +357,13 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
Log.d(TAG, "录音已启动: ${result.session.filePath}")
|
||||
// 通知服务端「开始录音」,服务端 5s 后发「我在认真听」
|
||||
viewModelScope.launch {
|
||||
val cid = conversationId.value
|
||||
if (cid != null && cid != "new" && ensureConnected()) {
|
||||
conversationRealtime.sendRecordingStarted(cid, result.session.voiceSessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
is RecordingStartResult.AlreadyRecording -> {
|
||||
Log.w(TAG, "录音已在进行中")
|
||||
|
||||
@@ -661,6 +661,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
|
||||
override suspend fun disconnect() = Unit
|
||||
override fun isConnected(): Boolean = connected
|
||||
override suspend fun sendText(conversationId: String, text: String) = Unit
|
||||
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendAudioSegment(request: AudioSegmentRequest) {
|
||||
sendAudioSegmentFailure?.let { throw it }
|
||||
@@ -705,6 +706,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
|
||||
override suspend fun sendText(conversationId: String, text: String) {
|
||||
sentTexts += text
|
||||
}
|
||||
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
|
||||
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
|
||||
|
||||
@@ -290,6 +290,7 @@ class CreateMemoryViewModelWarmupTest {
|
||||
|
||||
override fun isConnected(): Boolean = connected
|
||||
override suspend fun sendText(conversationId: String, text: String) = Unit
|
||||
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
|
||||
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
|
||||
|
||||
Reference in New Issue
Block a user