test: 在重构前补充 WebSocket 基线测试
补齐 WebSocket 关键流程的基线测试,为后续长语音重构提供回归保护。
This commit is contained in:
440
api/tests/test_websocket_baseline.py
Normal file
440
api/tests/test_websocket_baseline.py
Normal file
@@ -0,0 +1,440 @@
|
||||
import unittest
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
||||
|
||||
from database.models import Conversation, Segment
|
||||
from routers import websocket as ws_router
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self, messages, token="valid-token"):
|
||||
self.query_params = {}
|
||||
if token is not None:
|
||||
self.query_params["token"] = token
|
||||
self._messages = list(messages)
|
||||
self.application_state = WebSocketState.CONNECTED
|
||||
self.close_calls = []
|
||||
|
||||
async def receive_json(self):
|
||||
if not self._messages:
|
||||
raise WebSocketDisconnect()
|
||||
next_item = self._messages.pop(0)
|
||||
if isinstance(next_item, BaseException):
|
||||
raise next_item
|
||||
return next_item
|
||||
|
||||
async def close(self, code=None, reason=None):
|
||||
self.close_calls.append({"code": code, "reason": reason})
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ScalarsResult:
|
||||
_items: list
|
||||
|
||||
def all(self):
|
||||
return list(self._items)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExecuteResult:
|
||||
_items: list
|
||||
|
||||
def scalars(self):
|
||||
return _ScalarsResult(self._items)
|
||||
|
||||
|
||||
class _FakeAsyncDB:
|
||||
def __init__(self, user, conversation=None, segments=None):
|
||||
self.user = user
|
||||
self.conversation = conversation
|
||||
self.segments = list(segments or [])
|
||||
self.added = []
|
||||
self.commit_calls = 0
|
||||
self.refresh_calls = 0
|
||||
|
||||
async def get(self, model, key):
|
||||
model_name = getattr(model, "__name__", "")
|
||||
if model_name == "User":
|
||||
return self.user
|
||||
if model_name == "Conversation":
|
||||
if self.conversation and self.conversation.id == key:
|
||||
return self.conversation
|
||||
return None
|
||||
return None
|
||||
|
||||
def add(self, obj):
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, Conversation):
|
||||
self.conversation = obj
|
||||
if isinstance(obj, Segment):
|
||||
self.segments.append(obj)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_calls += 1
|
||||
|
||||
async def refresh(self, obj):
|
||||
_ = obj
|
||||
self.refresh_calls += 1
|
||||
|
||||
async def execute(self, stmt):
|
||||
_ = stmt
|
||||
return _ExecuteResult(self.segments)
|
||||
|
||||
|
||||
class _FakeManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
self.sent_messages = []
|
||||
self.disconnect_calls = []
|
||||
self.background_runner = SimpleNamespace(
|
||||
queue_message=AsyncMock(),
|
||||
flush_pending=AsyncMock(),
|
||||
)
|
||||
self.conversation_agent = SimpleNamespace(
|
||||
generate_profile_greeting=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
async def connect(self, websocket, conversation_id):
|
||||
self.active_connections[conversation_id] = websocket
|
||||
|
||||
async def disconnect(self, conversation_id):
|
||||
self.disconnect_calls.append(conversation_id)
|
||||
self.active_connections.pop(conversation_id, None)
|
||||
|
||||
async def send_message(self, conversation_id, message):
|
||||
self.sent_messages.append({"conversation_id": conversation_id, "message": message})
|
||||
|
||||
|
||||
def _make_user():
|
||||
# Provide all profile fields to skip greeting/profile-collection branch.
|
||||
return SimpleNamespace(
|
||||
id="user-1",
|
||||
nickname="tester",
|
||||
subscription_type="premium",
|
||||
birth_year=1990,
|
||||
birth_place="A",
|
||||
grew_up_place="B",
|
||||
occupation="dev",
|
||||
)
|
||||
|
||||
|
||||
def _db_provider(db):
|
||||
async def _provider():
|
||||
yield db
|
||||
|
||||
return _provider
|
||||
|
||||
|
||||
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_invalid_token_closes_connection(self):
|
||||
websocket = _FakeWebSocket(messages=[], token="invalid")
|
||||
|
||||
with patch.object(ws_router, "verify_token", return_value=None):
|
||||
await ws_router.websocket_endpoint(websocket, "conv-1")
|
||||
|
||||
self.assertEqual(len(websocket.close_calls), 1)
|
||||
self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")
|
||||
|
||||
async def test_text_message_creates_segment_and_dispatches_agent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{"type": "text", "data": {"text": "你好"}},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
|
||||
self.assertEqual(len(segments), 1)
|
||||
self.assertEqual(segments[0].transcript_text, "你好")
|
||||
self.assertIsNone(segments[0].audio_url)
|
||||
fake_manager.background_runner.queue_message.assert_awaited_once()
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
|
||||
message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
|
||||
self.assertIn(ws_router.MessageType.CONNECT, message_types)
|
||||
|
||||
async def test_audio_message_transcribes_then_calls_agent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_message",
|
||||
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="这是转写结果")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
|
||||
segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
|
||||
self.assertEqual(len(segments), 1)
|
||||
self.assertEqual(segments[0].transcript_text, "这是转写结果")
|
||||
self.assertEqual(segments[0].audio_url, "audio:12s")
|
||||
|
||||
transcript_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
|
||||
]
|
||||
self.assertEqual(len(transcript_msgs), 1)
|
||||
self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")
|
||||
|
||||
async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "transcribe_only",
|
||||
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="仅转写文本")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
|
||||
process_user_message_mock.assert_not_awaited()
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0)
|
||||
|
||||
transcript_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
|
||||
]
|
||||
self.assertEqual(len(transcript_msgs), 1)
|
||||
self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")
|
||||
|
||||
async def test_end_conversation_updates_status_and_triggers_processing(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[{"type": "end_conversation", "conversation_id": "conv-1"}]
|
||||
)
|
||||
process_conversation_segments_mock = AsyncMock()
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"process_conversation_segments",
|
||||
process_conversation_segments_mock,
|
||||
)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
self.assertEqual(conversation.status, "ended")
|
||||
process_conversation_segments_mock.assert_awaited_once()
|
||||
end_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.END_CONVERSATION
|
||||
]
|
||||
self.assertEqual(len(end_msgs), 1)
|
||||
|
||||
async def test_transcribe_only_missing_audio_returns_error(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{"type": "transcribe_only", "data": {}},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
transcribe_mock = AsyncMock(return_value="should-not-be-called")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
transcribe_mock.assert_not_awaited()
|
||||
error_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.ERROR
|
||||
]
|
||||
self.assertEqual(len(error_msgs), 1)
|
||||
self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")
|
||||
|
||||
async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_message",
|
||||
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="转写失败: mock error")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
|
||||
process_user_message_mock.assert_not_awaited()
|
||||
error_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.ERROR
|
||||
]
|
||||
self.assertEqual(len(error_msgs), 1)
|
||||
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user