"""Baseline behaviour tests for ``routers.websocket.websocket_endpoint``.

The endpoint's collaborators (token verification, DB session provider,
connection manager and, where a test needs them, quota checks, ASR and agent
dispatch) are replaced with in-memory fakes or mocks, so each test only pins
how the endpoint reacts to one scripted sequence of incoming messages.
"""

import unittest
from collections import deque
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

from starlette.websockets import WebSocketDisconnect, WebSocketState

from database.models import Conversation, Segment
from routers import websocket as ws_router

# Conversation id shared by the fixture and by every endpoint call below.
_CONVERSATION_ID = "conv-1"


class _FakeWebSocket:
    """Scripted stand-in for a Starlette ``WebSocket``.

    ``messages`` is consumed front to back by ``receive_json``. An item that
    is an exception instance is raised instead of returned, and an exhausted
    script raises ``WebSocketDisconnect``.
    """

    def __init__(self, messages, token="valid-token"):
        # ``token=None`` leaves ``query_params`` without a "token" key.
        self.query_params = {} if token is None else {"token": token}
        # deque gives O(1) removal from the front of the script.
        self._messages = deque(messages)
        self.application_state = WebSocketState.CONNECTED
        self.close_calls = []

    async def receive_json(self):
        """Return the next scripted message, or raise the scripted exception."""
        if not self._messages:
            raise WebSocketDisconnect()
        next_item = self._messages.popleft()
        if isinstance(next_item, BaseException):
            raise next_item
        return next_item

    async def close(self, code=None, reason=None):
        """Record the close arguments and mark the socket as disconnected."""
        self.close_calls.append({"code": code, "reason": reason})
        self.application_state = WebSocketState.DISCONNECTED


@dataclass
class _ScalarsResult:
    """Result of ``_ExecuteResult.scalars()``; only ``all()`` is provided."""

    _items: list

    def all(self):
        """Return a fresh list copy of the wrapped items."""
        return list(self._items)


@dataclass
class _ExecuteResult:
    """Result of ``_FakeAsyncDB.execute()``; only ``scalars()`` is provided."""

    _items: list

    def scalars(self):
        """Wrap the items so callers can chain ``.scalars().all()``."""
        return _ScalarsResult(self._items)


class _FakeAsyncDB:
    """In-memory async DB session that records what the endpoint does to it."""

    def __init__(self, user, conversation=None, segments=None):
        self.user = user
        self.conversation = conversation
        self.segments = list(segments or [])
        self.added = []
        self.commit_calls = 0
        self.refresh_calls = 0

    async def get(self, model, key):
        """Look up by model class name.

        ``User`` always yields the configured user regardless of ``key``;
        ``Conversation`` yields the stored conversation only when its id
        equals ``key``; anything else yields ``None``.
        """
        model_name = getattr(model, "__name__", "")
        if model_name == "User":
            return self.user
        if model_name == "Conversation":
            if self.conversation and self.conversation.id == key:
                return self.conversation
            return None
        return None

    def add(self, obj):
        """Record ``obj``; conversations and segments also update fake state."""
        self.added.append(obj)
        if isinstance(obj, Conversation):
            self.conversation = obj
        if isinstance(obj, Segment):
            self.segments.append(obj)

    async def commit(self):
        """Count the commit; nothing is persisted."""
        self.commit_calls += 1

    async def refresh(self, obj):
        """Count the refresh; ``obj`` is left untouched."""
        self.refresh_calls += 1

    async def execute(self, stmt):
        """Ignore ``stmt`` and return a result wrapping all known segments."""
        return _ExecuteResult(self.segments)


class _FakeManager:
    """Connection-manager double that records sends and disconnects."""

    def __init__(self):
        self.active_connections = {}
        self.sent_messages = []
        self.disconnect_calls = []
        self.background_runner = SimpleNamespace(
            queue_message=AsyncMock(),
            flush_pending=AsyncMock(),
        )
        self.conversation_agent = SimpleNamespace(
            generate_profile_greeting=AsyncMock(return_value=[]),
        )

    async def connect(self, websocket, conversation_id):
        """Register ``websocket`` under ``conversation_id``."""
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id):
        """Record the call and drop the connection if it is registered."""
        self.disconnect_calls.append(conversation_id)
        self.active_connections.pop(conversation_id, None)

    async def send_message(self, conversation_id, message):
        """Record the outgoing message instead of sending it."""
        self.sent_messages.append(
            {"conversation_id": conversation_id, "message": message}
        )


def _make_user():
    """Return a user stub with every profile field populated."""
    # Provide all profile fields to skip greeting/profile-collection branch.
    return SimpleNamespace(
        id="user-1",
        nickname="tester",
        subscription_type="premium",
        birth_year=1990,
        birth_place="A",
        grew_up_place="B",
        occupation="dev",
    )


def _db_provider(db):
    """Return an async-generator function that yields ``db`` once."""

    async def _provider():
        yield db

    return _provider


def _make_session():
    """Build the fixture shared by every authenticated test.

    Returns a namespace with ``user``, ``conversation`` (active, id
    ``_CONVERSATION_ID``, owned by ``user``), ``db`` and ``manager``.
    """
    user = _make_user()
    conversation = Conversation(
        id=_CONVERSATION_ID, user_id=user.id, status="active"
    )
    return SimpleNamespace(
        user=user,
        conversation=conversation,
        db=_FakeAsyncDB(user=user, conversation=conversation),
        manager=_FakeManager(),
    )


def _quota_patchers():
    """Return patchers that make the quota module allow sending a message."""
    return (
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
    )


@contextmanager
def _patched_router(session, *extra_patchers):
    """Patch the router's auth, DB provider and manager for one test.

    ``verify_token`` is made to return an access payload for ``session.user``,
    ``get_async_db`` yields ``session.db`` and ``manager`` becomes
    ``session.manager``. ``extra_patchers`` are entered afterwards, in order,
    and everything is undone when the ``with`` block exits.
    """
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": session.user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(session.db))
        )
        stack.enter_context(patch.object(ws_router, "manager", session.manager))
        for patcher in extra_patchers:
            stack.enter_context(patcher)
        yield


def _sent_messages_of_type(manager, message_type):
    """Return the payloads recorded by ``manager`` whose type matches."""
    return [
        item["message"]
        for item in manager.sent_messages
        if item["message"]["type"] == message_type
    ]


def _added_segments(db):
    """Return the ``Segment`` objects passed to ``db.add``."""
    return [obj for obj in db.added if isinstance(obj, Segment)]


class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
    """Pins the observable behaviour of ``ws_router.websocket_endpoint``."""

    async def test_invalid_token_closes_connection(self):
        websocket = _FakeWebSocket(messages=[], token="invalid")

        with patch.object(ws_router, "verify_token", return_value=None):
            await ws_router.websocket_endpoint(websocket, _CONVERSATION_ID)

        self.assertEqual(len(websocket.close_calls), 1)
        self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")

    async def test_text_message_creates_segment_and_dispatches_agent(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {"type": "text", "data": {"text": "你好"}},
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()

        with _patched_router(
            session,
            *_quota_patchers(),
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            ),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        segments = _added_segments(session.db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "你好")
        self.assertIsNone(segments[0].audio_url)
        session.manager.background_runner.queue_message.assert_awaited_once()
        process_user_message_mock.assert_awaited_once()

        message_types = [
            item["message"]["type"] for item in session.manager.sent_messages
        ]
        self.assertIn(ws_router.MessageType.CONNECT, message_types)

    async def test_audio_message_transcribes_then_calls_agent(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="这是转写结果")

        with _patched_router(
            session,
            *_quota_patchers(),
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            ),
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_awaited_once()

        segments = _added_segments(session.db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "这是转写结果")
        self.assertEqual(segments[0].audio_url, "audio:12s")

        transcript_msgs = _sent_messages_of_type(
            session.manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")

    async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "transcribe_only",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="仅转写文本")

        with _patched_router(
            session,
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            ),
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_not_awaited()
        self.assertEqual(len(_added_segments(session.db)), 0)

        transcript_msgs = _sent_messages_of_type(
            session.manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")

    async def test_end_conversation_updates_status_and_triggers_processing(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {"type": "end_conversation", "conversation_id": _CONVERSATION_ID}
            ]
        )
        process_conversation_segments_mock = AsyncMock()

        with _patched_router(
            session,
            patch.object(
                ws_router,
                "process_conversation_segments",
                process_conversation_segments_mock,
            ),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        self.assertEqual(session.conversation.status, "ended")
        process_conversation_segments_mock.assert_awaited_once()
        end_msgs = _sent_messages_of_type(
            session.manager, ws_router.MessageType.END_CONVERSATION
        )
        self.assertEqual(len(end_msgs), 1)

    async def test_transcribe_only_missing_audio_returns_error(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {"type": "transcribe_only", "data": {}},
                WebSocketDisconnect(),
            ]
        )
        transcribe_mock = AsyncMock(return_value="should-not-be-called")

        with _patched_router(
            session,
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        transcribe_mock.assert_not_awaited()
        error_msgs = _sent_messages_of_type(
            session.manager, ws_router.MessageType.ERROR
        )
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")

    async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
        session = _make_session()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="转写失败: mock error")

        with _patched_router(
            session,
            *_quota_patchers(),
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            ),
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
        ):
            await ws_router.websocket_endpoint(fake_websocket, _CONVERSATION_ID)

        process_user_message_mock.assert_not_awaited()
        error_msgs = _sent_messages_of_type(
            session.manager, ws_router.MessageType.ERROR
        )
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(
            error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入"
        )


if __name__ == "__main__":
    unittest.main()