feat: 支持长语音分段上传与断线补传
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
@@ -89,6 +90,7 @@ class _FakeAsyncDB:
|
||||
class _FakeManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
self.segment_states = {}
|
||||
self.sent_messages = []
|
||||
self.disconnect_calls = []
|
||||
self.background_runner = SimpleNamespace(
|
||||
@@ -109,6 +111,21 @@ class _FakeManager:
|
||||
async def send_message(self, conversation_id, message):
|
||||
self.sent_messages.append({"conversation_id": conversation_id, "message": message})
|
||||
|
||||
def get_or_create_segment_state(self, conversation_id, voice_session_id):
|
||||
state_key = (conversation_id, voice_session_id)
|
||||
if state_key not in self.segment_states:
|
||||
self.segment_states[state_key] = ws_router.SegmentStreamState()
|
||||
return self.segment_states[state_key]
|
||||
|
||||
def register_segment_task(self, conversation_id, voice_session_id, task):
|
||||
state = self.get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
state.active_tasks.add(task)
|
||||
|
||||
def _cleanup(done_task):
|
||||
state.active_tasks.discard(done_task)
|
||||
|
||||
task.add_done_callback(_cleanup)
|
||||
|
||||
|
||||
def _make_user():
|
||||
# Provide all profile fields to skip greeting/profile-collection branch.
|
||||
@@ -435,6 +452,431 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(error_msgs), 1)
|
||||
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
|
||||
|
||||
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1",
|
||||
"segment_index": 1,
|
||||
"duration": 12,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-0",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(
|
||||
side_effect=lambda audio: {
|
||||
"seg-0": "这是第 0 段",
|
||||
"seg-1": "这是第 1 段",
|
||||
}[audio]
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.assertEqual(transcribe_mock.await_count, 2)
|
||||
ordered_messages = [
|
||||
call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
|
||||
]
|
||||
self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
|
||||
|
||||
async def test_audio_segment_duplicate_index_is_idempotent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "dup-seg-0-a",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "dup-seg-0-b",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="重复分段去重测试")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.assertEqual(transcribe_mock.await_count, 1)
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1)
|
||||
|
||||
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "session-a-seg-0",
|
||||
"voice_session_id": "voice-session-a",
|
||||
"client_segment_id": "voice-session-a-0",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "session-b-seg-0",
|
||||
"voice_session_id": "voice-session-b",
|
||||
"client_segment_id": "voice-session-b-0",
|
||||
"segment_index": 0,
|
||||
"duration": 8,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(
|
||||
side_effect=lambda audio: {
|
||||
"session-a-seg-0": "第一轮第 0 段",
|
||||
"session-b-seg-0": "第二轮第 0 段",
|
||||
}[audio]
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
ordered_messages = [
|
||||
call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
|
||||
]
|
||||
self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
|
||||
self.assertEqual(transcribe_mock.await_count, 2)
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
|
||||
|
||||
async def test_audio_segment_sends_transition_feedback_while_processing(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "slow-seg-0",
|
||||
"segment_index": 0,
|
||||
"duration": 20,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _slow_transcribe(_: str) -> str:
|
||||
await asyncio.sleep(0.2)
|
||||
return "慢速转写"
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(side_effect=_slow_transcribe)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
transition_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
|
||||
and item["message"].get("data", {}).get("transition") is True
|
||||
]
|
||||
self.assertGreaterEqual(len(transition_msgs), 1)
|
||||
|
||||
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
existing_segment = Segment(
|
||||
id="seg-existing-0",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的上一段",
|
||||
audio_url="audio-segment:voice-session-1:0",
|
||||
processed=False,
|
||||
)
|
||||
fake_db = _FakeAsyncDB(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
segments=[existing_segment],
|
||||
)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1-after-reconnect",
|
||||
"voice_session_id": "voice-session-1",
|
||||
"client_segment_id": "voice-session-1-1",
|
||||
"segment_index": 1,
|
||||
"duration": 18,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(
|
||||
process_user_message_mock.await_args.kwargs["user_message"],
|
||||
"这是重连后的第 1 段",
|
||||
)
|
||||
|
||||
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
existing_segments = [
|
||||
Segment(
|
||||
id="seg-existing-0",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的第 0 段",
|
||||
audio_url="audio-segment:voice-session-gap:0",
|
||||
processed=False,
|
||||
),
|
||||
Segment(
|
||||
id="seg-existing-2",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的第 2 段",
|
||||
audio_url="audio-segment:voice-session-gap:2",
|
||||
processed=False,
|
||||
),
|
||||
]
|
||||
fake_db = _FakeAsyncDB(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
segments=existing_segments,
|
||||
)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1-gap-retry",
|
||||
"voice_session_id": "voice-session-gap",
|
||||
"client_segment_id": "voice-session-gap-1",
|
||||
"segment_index": 1,
|
||||
"duration": 18,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="补传后的第 1 段")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(
|
||||
process_user_message_mock.await_args.kwargs["user_message"],
|
||||
"补传后的第 1 段",
|
||||
)
|
||||
self.assertEqual(transcribe_mock.await_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user