"""Baseline tests for ``routers.websocket.websocket_endpoint``.

The endpoint's collaborators reached through ``ws_router`` (token
verification, the DB provider, the connection manager, the ASR service, the
agent/processing entry points) and the ``routers.quota`` checks are replaced
with in-memory fakes or ``AsyncMock`` objects.
"""

import unittest
from collections import deque
from contextlib import ExitStack
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

from starlette.websockets import WebSocketDisconnect, WebSocketState

from database.models import Conversation, Segment
from routers import websocket as ws_router


class _FakeWebSocket:
    """Scripted stand-in for a Starlette ``WebSocket``.

    ``receive_json`` returns queued messages in order. A queued exception
    instance is raised instead of returned. An exhausted queue raises
    ``WebSocketDisconnect``, which presumably ends the endpoint's receive
    loop. Calls to ``close`` are recorded in ``close_calls``.
    """

    def __init__(self, messages, token="valid-token"):
        self.query_params = {}
        if token is not None:
            self.query_params["token"] = token
        # Messages are consumed from the front; deque.popleft is O(1).
        self._messages = deque(messages)
        self.application_state = WebSocketState.CONNECTED
        self.close_calls = []

    async def receive_json(self):
        if not self._messages:
            raise WebSocketDisconnect()
        next_item = self._messages.popleft()
        if isinstance(next_item, BaseException):
            raise next_item
        return next_item

    async def close(self, code=None, reason=None):
        self.close_calls.append({"code": code, "reason": reason})
        self.application_state = WebSocketState.DISCONNECTED


@dataclass
class _ScalarsResult:
    """Mimics the ``.scalars()`` result: ``all()`` returns a copy of the items."""

    _items: list

    def all(self):
        return list(self._items)


@dataclass
class _ExecuteResult:
    """Mimics the ``execute()`` result: ``scalars()`` wraps the same items."""

    _items: list

    def scalars(self):
        return _ScalarsResult(self._items)


class _FakeAsyncDB:
    """In-memory stand-in for the session yielded by ``get_async_db``.

    ``get`` dispatches on the model's class name. ``add`` records every
    object and also tracks ``Conversation``/``Segment`` instances.
    ``commit``/``refresh`` only count calls. ``execute`` ignores the
    statement and returns every stored segment.
    """

    def __init__(self, user, conversation=None, segments=None):
        self.user = user
        self.conversation = conversation
        self.segments = list(segments or [])
        self.added = []
        self.commit_calls = 0
        self.refresh_calls = 0

    async def get(self, model, key):
        # Matching by class name avoids importing the User model here.
        model_name = getattr(model, "__name__", "")
        if model_name == "User":
            return self.user
        if (
            model_name == "Conversation"
            and self.conversation
            and self.conversation.id == key
        ):
            return self.conversation
        return None

    def add(self, obj):
        self.added.append(obj)
        if isinstance(obj, Conversation):
            self.conversation = obj
        if isinstance(obj, Segment):
            self.segments.append(obj)

    async def commit(self):
        self.commit_calls += 1

    async def refresh(self, obj):
        del obj  # unused; kept for signature compatibility
        self.refresh_calls += 1

    async def execute(self, stmt):
        del stmt  # unused; every stored segment is returned
        return _ExecuteResult(self.segments)


class _FakeManager:
    """Connection-manager fake that records calls instead of using sockets.

    ``background_runner`` and ``conversation_agent`` expose ``AsyncMock``
    methods so tests can assert the endpoint awaited them.
    """

    def __init__(self):
        self.active_connections = {}
        self.sent_messages = []
        self.disconnect_calls = []
        self.background_runner = SimpleNamespace(
            queue_message=AsyncMock(),
            flush_pending=AsyncMock(),
        )
        self.conversation_agent = SimpleNamespace(
            generate_profile_greeting=AsyncMock(return_value=[]),
        )

    async def connect(self, websocket, conversation_id):
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id):
        self.disconnect_calls.append(conversation_id)
        self.active_connections.pop(conversation_id, None)

    async def send_message(self, conversation_id, message):
        self.sent_messages.append(
            {"conversation_id": conversation_id, "message": message}
        )


def _make_user():
    # Provide all profile fields to skip greeting/profile-collection branch.
    return SimpleNamespace(
        id="user-1",
        nickname="tester",
        subscription_type="premium",
        birth_year=1990,
        birth_place="A",
        grew_up_place="B",
        occupation="dev",
    )


def _make_conversation(user):
    """Return the active ``conv-1`` conversation owned by *user*."""
    return Conversation(id="conv-1", user_id=user.id, status="active")


def _db_provider(db):
    """Build an async-generator replacement for ``get_async_db`` yielding *db*."""

    async def _provider():
        yield db

    return _provider


def _enter_session_patches(stack, user, fake_db, fake_manager):
    """Patch token verification, the DB provider and the manager on ``ws_router``.

    ``verify_token`` is made to accept the connection as an access token
    for *user*.
    """
    stack.enter_context(
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        )
    )
    stack.enter_context(patch.object(ws_router, "get_async_db", _db_provider(fake_db)))
    stack.enter_context(patch.object(ws_router, "manager", fake_manager))


def _enter_quota_patches(stack):
    """Patch ``routers.quota`` so the message quota check always passes."""
    stack.enter_context(
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
    )
    stack.enter_context(
        patch("routers.quota.check_can_send_message", return_value=(True, ""))
    )


def _sent_messages_of_type(fake_manager, message_type):
    """Return the payloads *fake_manager* sent whose ``type`` equals *message_type*."""
    return [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == message_type
    ]


def _added_segments(fake_db):
    """Return the ``Segment`` objects passed to ``fake_db.add``."""
    return [obj for obj in fake_db.added if isinstance(obj, Segment)]


class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
    """End-to-end checks of ``websocket_endpoint`` against the fakes above."""

    async def test_invalid_token_closes_connection(self):
        websocket = _FakeWebSocket(messages=[], token="invalid")
        with patch.object(ws_router, "verify_token", return_value=None):
            await ws_router.websocket_endpoint(websocket, "conv-1")

        self.assertEqual(len(websocket.close_calls), 1)
        self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")

    async def test_text_message_creates_segment_and_dispatches_agent(self):
        user = _make_user()
        fake_db = _FakeAsyncDB(user=user, conversation=_make_conversation(user))
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[
                {"type": "text", "data": {"text": "你好"}},
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            _enter_quota_patches(stack)
            stack.enter_context(
                patch.object(ws_router, "process_user_message", process_user_message_mock)
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        segments = _added_segments(fake_db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "你好")
        self.assertIsNone(segments[0].audio_url)
        fake_manager.background_runner.queue_message.assert_awaited_once()
        process_user_message_mock.assert_awaited_once()
        message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
        self.assertIn(ws_router.MessageType.CONNECT, message_types)

    async def test_audio_message_transcribes_then_calls_agent(self):
        user = _make_user()
        fake_db = _FakeAsyncDB(user=user, conversation=_make_conversation(user))
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="这是转写结果")

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            _enter_quota_patches(stack)
            stack.enter_context(
                patch.object(ws_router, "process_user_message", process_user_message_mock)
            )
            stack.enter_context(
                patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_awaited_once()
        segments = _added_segments(fake_db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "这是转写结果")
        self.assertEqual(segments[0].audio_url, "audio:12s")
        transcript_msgs = _sent_messages_of_type(
            fake_manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")

    async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
        user = _make_user()
        fake_db = _FakeAsyncDB(user=user, conversation=_make_conversation(user))
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "transcribe_only",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="仅转写文本")

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            stack.enter_context(
                patch.object(ws_router, "process_user_message", process_user_message_mock)
            )
            stack.enter_context(
                patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_not_awaited()
        self.assertEqual(len(_added_segments(fake_db)), 0)
        transcript_msgs = _sent_messages_of_type(
            fake_manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")

    async def test_end_conversation_updates_status_and_triggers_processing(self):
        user = _make_user()
        conversation = _make_conversation(user)
        fake_db = _FakeAsyncDB(user=user, conversation=conversation)
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[{"type": "end_conversation", "conversation_id": "conv-1"}]
        )
        process_conversation_segments_mock = AsyncMock()

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            stack.enter_context(
                patch.object(
                    ws_router,
                    "process_conversation_segments",
                    process_conversation_segments_mock,
                )
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        self.assertEqual(conversation.status, "ended")
        process_conversation_segments_mock.assert_awaited_once()
        end_msgs = _sent_messages_of_type(
            fake_manager, ws_router.MessageType.END_CONVERSATION
        )
        self.assertEqual(len(end_msgs), 1)

    async def test_transcribe_only_missing_audio_returns_error(self):
        user = _make_user()
        fake_db = _FakeAsyncDB(user=user, conversation=_make_conversation(user))
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[
                {"type": "transcribe_only", "data": {}},
                WebSocketDisconnect(),
            ]
        )
        transcribe_mock = AsyncMock(return_value="should-not-be-called")

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            stack.enter_context(
                patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        transcribe_mock.assert_not_awaited()
        error_msgs = _sent_messages_of_type(fake_manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")

    async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
        user = _make_user()
        fake_db = _FakeAsyncDB(user=user, conversation=_make_conversation(user))
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(
            messages=[
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="转写失败: mock error")

        with ExitStack() as stack:
            _enter_session_patches(stack, user, fake_db, fake_manager)
            _enter_quota_patches(stack)
            stack.enter_context(
                patch.object(ws_router, "process_user_message", process_user_message_mock)
            )
            stack.enter_context(
                patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
            )
            await ws_router.websocket_endpoint(fake_websocket, "conv-1")

        process_user_message_mock.assert_not_awaited()
        error_msgs = _sent_messages_of_type(fake_manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(
            error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入"
        )


if __name__ == "__main__":
    unittest.main()