Files
life-echo/api/tests/test_websocket_baseline.py
2026-03-19 14:36:40 +08:00

1499 lines
54 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import unittest
from contextlib import ExitStack
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from starlette.websockets import WebSocketDisconnect, WebSocketState
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws import router as ws_router
class _FakeWebSocket:
def __init__(self, messages, token="valid-token"):
self.query_params = {}
if token is not None:
self.query_params["token"] = token
self._messages = list(messages)
self.application_state = WebSocketState.CONNECTED
self.close_calls = []
async def receive_json(self):
if not self._messages:
raise WebSocketDisconnect()
next_item = self._messages.pop(0)
if isinstance(next_item, BaseException):
raise next_item
return next_item
async def close(self, code=None, reason=None):
self.close_calls.append({"code": code, "reason": reason})
self.application_state = WebSocketState.DISCONNECTED
@dataclass
class _ScalarsResult:
_items: list
def all(self):
return list(self._items)
@dataclass
class _ExecuteResult:
_items: list
def scalars(self):
return _ScalarsResult(self._items)
def scalar_one_or_none(self):
return self._items[0] if len(self._items) == 1 else None
class _FakeAsyncDB:
def __init__(self, user, conversation=None, segments=None):
self.user = user
self.conversation = conversation
self.segments = list(segments or [])
self.added = []
self.commit_calls = 0
self.refresh_calls = 0
self.state_result = None # 若为 Noneget_or_create_state 的查询返回空
async def get(self, model, key):
model_name = getattr(model, "__name__", "")
if model_name == "User":
return self.user
if model_name == "Conversation":
if self.conversation and self.conversation.id == key:
return self.conversation
return None
return None
def add(self, obj):
self.added.append(obj)
if isinstance(obj, Conversation):
self.conversation = obj
if isinstance(obj, Segment):
self.segments.append(obj)
async def commit(self):
self.commit_calls += 1
async def refresh(self, obj):
_ = obj
self.refresh_calls += 1
async def execute(self, stmt):
stmt_str = str(stmt)
if "MemoirState" in stmt_str or "memoir_state" in stmt_str:
return _ExecuteResult(
[self.state_result] if self.state_result is not None else []
)
return _ExecuteResult(self.segments)
class _FakeManager:
def __init__(self):
self.active_connections = {}
self.segment_states = {}
self.sent_messages = []
self.disconnect_calls = []
self.background_runner = SimpleNamespace(
queue_message=AsyncMock(),
flush_pending=AsyncMock(),
)
self.conversation_agent = SimpleNamespace(
generate_profile_greeting=AsyncMock(return_value=[]),
generate_opening_message=AsyncMock(return_value=[]),
)
async def connect(self, websocket, conversation_id):
self.active_connections[conversation_id] = websocket
async def disconnect(self, conversation_id):
self.disconnect_calls.append(conversation_id)
self.active_connections.pop(conversation_id, None)
async def send_message(self, conversation_id, message):
self.sent_messages.append(
{"conversation_id": conversation_id, "message": message}
)
def get_or_create_segment_state(self, conversation_id, voice_session_id):
state_key = (conversation_id, voice_session_id)
if state_key not in self.segment_states:
self.segment_states[state_key] = ws_router.SegmentStreamState()
return self.segment_states[state_key]
def register_segment_task(self, conversation_id, voice_session_id, task):
state = self.get_or_create_segment_state(conversation_id, voice_session_id)
state.active_tasks.add(task)
def _cleanup(done_task):
state.active_tasks.discard(done_task)
task.add_done_callback(_cleanup)
def _make_user():
# Provide all profile fields to skip greeting/profile-collection branch.
return SimpleNamespace(
id="user-1",
nickname="tester",
subscription_type="premium",
birth_year=1990,
birth_place="A",
grew_up_place="B",
occupation="dev",
)
def _db_provider(db):
"""返回可被 patch 到 get_async_db 的异步生成器(旧用法)。"""
async def _provider():
yield db
return _provider
class _FakeSessionCM:
"""模拟 async with AsyncSessionLocal() as db 的上下文管理器。"""
def __init__(self, db):
self._db = db
async def __aenter__(self):
return self._db
async def __aexit__(self, *args):
pass
def _session_local_factory(fake_db):
"""返回可 patch 到 AsyncSessionLocal 的工厂,使 async with AsyncSessionLocal() as db 得到 fake_db。"""
def _factory():
return _FakeSessionCM(fake_db)
return _factory
def _redis_empty_history_patch():
    """Build a patcher that makes the Redis conversation history look empty.

    With an empty history the websocket endpoint sends its opening message
    (or skips it when the opening generator is mocked).
    """
    target = ws_router.redis_service
    return patch.object(
        target,
        "get_conversation_history",
        new=AsyncMock(return_value=[]),
    )
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
async def test_invalid_token_closes_connection(self):
    """A token that fails verification closes the socket once with the auth-failure reason."""
    rejected_socket = _FakeWebSocket(messages=[], token="invalid")
    # verify_token returning None simulates a token that cannot be validated.
    reject_every_token = patch.object(ws_router, "verify_token", return_value=None)
    with reject_every_token:
        await ws_router.websocket_endpoint(rejected_socket, "conv-1")
    recorded_closes = rejected_socket.close_calls
    self.assertEqual(len(recorded_closes), 1)
    self.assertEqual(recorded_closes[0]["reason"], "无效的认证令牌")
async def test_text_message_creates_segment_and_dispatches_agent(self):
    """A ``text`` frame yields one audio-less Segment and one agent dispatch.

    Also asserts that the conversation's ``last_message_at`` is set, that the
    background runner queued the message once, and that a CONNECT message was
    sent to the client.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {"type": "text", "data": {"text": "你好"}},
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
    self.assertEqual(len(segments), 1)
    self.assertEqual(segments[0].transcript_text, "你好")
    self.assertIsNone(segments[0].audio_url)
    self.assertIsNotNone(conversation.last_message_at)
    fake_manager.background_runner.queue_message.assert_awaited_once()
    process_user_message_mock.assert_awaited_once()
    message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
    self.assertIn(ws_router.MessageType.CONNECT, message_types)
async def test_audio_message_transcribes_then_calls_agent(self):
    """An ``audio_message`` frame is transcribed, stored, echoed and dispatched.

    Asserts that the ASR stub receives the decoded bytes with format ``"m4a"``,
    that one Segment is added with the transcript and ``audio_url`` of
    ``"audio:12s"``, that one TRANSCRIPT message is sent, and that the agent
    is awaited once.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_message",
                # The payload is base64 for b"fake-audio-b64".
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是转写结果")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    transcribe_mock.assert_awaited_once_with(b"fake-audio-b64", "m4a")
    process_user_message_mock.assert_awaited_once()
    segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
    self.assertEqual(len(segments), 1)
    self.assertEqual(segments[0].transcript_text, "这是转写结果")
    self.assertEqual(segments[0].audio_url, "audio:12s")
    self.assertIsNotNone(conversation.last_message_at)
    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(len(transcript_msgs), 1)
    self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")
async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
    """A ``transcribe_only`` frame returns a transcript and nothing else.

    Asserts that the ASR stub is awaited once, one TRANSCRIPT message is sent,
    no Segment is added and the agent is never awaited.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "transcribe_only",
                # The payload is base64 for b"fake-audio-b64".
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="仅转写文本")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    transcribe_mock.assert_awaited_once_with(b"fake-audio-b64", "m4a")
    process_user_message_mock.assert_not_awaited()
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0
    )
    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(len(transcript_msgs), 1)
    self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")
async def test_end_conversation_updates_status_and_triggers_processing(self):
    """An ``end_conversation`` frame ends the conversation and triggers processing.

    Asserts that the conversation status becomes ``"ended"``, that
    ``process_conversation_segments`` is awaited once, and that exactly one
    END_CONVERSATION message is sent.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    # No explicit disconnect is scripted: the fake socket raises
    # WebSocketDisconnect by itself once its message list is exhausted.
    fake_websocket = _FakeWebSocket(
        messages=[{"type": "end_conversation", "conversation_id": "conv-1"}]
    )
    process_conversation_segments_mock = AsyncMock()
    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(
            ws_router,
            "process_conversation_segments",
            process_conversation_segments_mock,
        ),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    self.assertEqual(conversation.status, "ended")
    process_conversation_segments_mock.assert_awaited_once()
    end_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.END_CONVERSATION
    ]
    self.assertEqual(len(end_msgs), 1)
async def test_transcribe_only_missing_audio_returns_error(self):
    """A ``transcribe_only`` frame without ``audio_base64`` yields one ERROR message.

    Asserts that the ASR stub is never awaited and that the error text is
    exactly ``"缺少 audio_base64"``.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {"type": "transcribe_only", "data": {}},
            WebSocketDisconnect(),
        ]
    )
    transcribe_mock = AsyncMock(return_value="should-not-be-called")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    transcribe_mock.assert_not_awaited()
    error_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.ERROR
    ]
    self.assertEqual(len(error_msgs), 1)
    self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")
async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
    """A failed transcription yields one ERROR message and no agent dispatch.

    The ASR stub returns a string starting with ``"转写失败"``; the test asserts
    that the client receives the fixed retry hint and that the agent is never
    awaited.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_message",
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="转写失败: mock error")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    process_user_message_mock.assert_not_awaited()
    error_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.ERROR
    ]
    self.assertEqual(len(error_msgs), 1)
    self.assertEqual(
        error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入"
    )
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
    """Segments arriving as index 1 then 0 reach the agent in index order.

    Asserts that both segments are transcribed, that the agent receives the
    index-0 text before the index-1 text, that two Segments are added, and that
    both TRANSCRIPT messages carry the voice session id.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1",
                    "voice_session_id": "voice-session-1",
                    "segment_index": 1,
                    "duration": 12,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-0",
                    "voice_session_id": "voice-session-1",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    # Results are handed out in call order: first call -> index 1's text.
    transcribe_mock = AsyncMock(side_effect=["这是第 1 段", "这是第 0 段"])

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to any still-pending segment tasks while the patches are active.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 2)
    ordered_messages = [
        call.kwargs["user_message"]
        for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2
    )
    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(
        [msg["data"]["voice_session_id"] for msg in transcript_msgs],
        ["voice-session-1", "voice-session-1"],
    )
async def test_audio_segment_duplicate_index_is_idempotent(self):
    """Two segments sharing ``segment_index`` 0 are processed only once.

    Neither frame carries a ``voice_session_id``. Asserts one transcription,
    one agent dispatch and one added Segment.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-a",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-b",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="重复分段去重测试")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to any still-pending segment tasks while the patches are active.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 1)
    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1
    )
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(
    self,
):
    """Index 0 may repeat when the segments belong to different voice sessions.

    Asserts that both segments are transcribed, dispatched to the agent in
    arrival order, and stored as two Segments.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-a-seg-0",
                    "voice_session_id": "voice-session-a",
                    "client_segment_id": "voice-session-a-0",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-b-seg-0",
                    "voice_session_id": "voice-session-b",
                    "client_segment_id": "voice-session-b-0",
                    "segment_index": 0,
                    "duration": 8,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=["第一轮第 0 段", "第二轮第 0 段"])

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to any still-pending segment tasks while the patches are active.
        await asyncio.sleep(0.05)

    ordered_messages = [
        call.kwargs["user_message"]
        for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
    self.assertEqual(transcribe_mock.await_count, 2)
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2
    )
async def test_audio_segment_sends_transition_feedback_while_processing(self):
    """A single ``is_last`` segment with slow ASR emits no transition message.

    The ASR stub sleeps 0.2s before answering. The test asserts that no
    AGENT_RESPONSE flagged ``transition`` is sent.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "slow-seg-0",
                    "segment_index": 0,
                    "duration": 20,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    async def _slow_transcribe(*_args, **_kwargs) -> str:
        # Accept any call shape: sibling tests show transcribe() being awaited
        # with two positional arguments, which a one-positional signature would
        # reject with TypeError instead of simulating slow ASR.
        await asyncio.sleep(0.2)
        return "慢速转写"

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_slow_transcribe)

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to any still-pending segment tasks while the patches are active.
        await asyncio.sleep(0.05)

    # Per the original author's note: the "我在认真听" feedback is sent once, 5s
    # after the first segment only; when that segment is is_last the delayed
    # send is cancelled, so zero transition messages are expected here.
    transition_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and item["message"].get("data", {}).get("transition") is True
    ]
    self.assertEqual(len(transition_msgs), 0)
async def test_recording_started_sends_listening_feedback_after_delay(self):
    """After the client sends recording_started, one "我在认真听" feedback follows a 5s delay."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    scripted_frames = [
        {
            "type": "recording_started",
            "data": {"voice_session_id": "session-1"},
        },
        WebSocketDisconnect(),
    ]
    fake_websocket = _FakeWebSocket(messages=scripted_frames)
    # Patchers are entered in list order; the last one shortens the feedback
    # delay to 0.05s so the test does not have to wait for the real value.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        _redis_empty_history_patch(),
        patch(
            "app.features.conversation.ws.pipeline.LISTENING_FEEDBACK_DELAY_SEC",
            0.05,
        ),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Wait past the shortened delay while the patches are still active.
        await asyncio.sleep(0.12)

    transition_msgs = []
    for item in fake_manager.sent_messages:
        message = item["message"]
        if (
            message["type"] == ws_router.MessageType.AGENT_RESPONSE
            and message.get("data", {}).get("transition") is True
        ):
            transition_msgs.append(message)
    self.assertEqual(len(transition_msgs), 1)
    self.assertIn("我在认真听", transition_msgs[0]["data"].get("text", ""))
async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self):
    """A lone segment flagged ``is_last`` produces no transition message.

    Asserts that no AGENT_RESPONSE flagged ``transition`` is sent.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "last-seg-0",
                    "voice_session_id": "voice-session-last",
                    "client_segment_id": "voice-session-last-0",
                    "segment_index": 0,
                    "duration": 15,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="最后一段转写")

    def _asr_provider():
        # Stand-in for get_asr_provider(): exposes only ``transcribe``.
        return SimpleNamespace(transcribe=transcribe_mock)

    # Patchers are entered in list order. The empty-history Redis patch is
    # entered once; a second entry would only stack a redundant mock on top.
    patchers = [
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch("app.features.conversation.ws.pipeline.get_asr_provider", _asr_provider),
    ]
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to any still-pending segment tasks while the patches are active.
        await asyncio.sleep(0.05)

    # Per the original author's note: with a single segment that is is_last,
    # the delayed feedback task is cancelled, so no transition should be sent.
    transition_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and item["message"].get("data", {}).get("transition") is True
    ]
    self.assertEqual(len(transition_msgs), 0)
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(
    self,
):
    """A segment sent after a reconnect is transcribed and forwarded once.

    Setup: the fake DB already contains one unprocessed segment whose
    ``audio_url`` is ``audio-segment:voice-session-1:0``. The fake socket
    then delivers ``segment_index`` 1 of the same voice session, flagged
    ``is_last``, followed by a disconnect.

    Expectation: ``process_user_message`` is awaited exactly once and its
    ``user_message`` keyword equals the transcript returned by the stubbed
    ASR provider.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    previous_segment = Segment(
        id="seg-existing-0",
        conversation_id="conv-1",
        transcript_text="已存在的上一段",
        audio_url="audio-segment:voice-session-1:0",
        processed=False,
    )
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=[previous_segment],
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-after-reconnect",
                    "voice_session_id": "voice-session-1",
                    "client_segment_id": "voice-session-1-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")

    def fake_asr_provider():
        # Stand-in for get_asr_provider(): only ``transcribe`` is stubbed.
        return SimpleNamespace(transcribe=transcribe_mock)

    with ExitStack() as stack:
        activate = stack.enter_context
        activate(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        # Each collaborator is replaced both on ws_router and at its
        # "app.features.conversation.ws.pipeline" path.
        activate(
            patch.object(
                ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.pipeline.AsyncSessionLocal",
                _session_local_factory(fake_db),
            )
        )
        activate(patch.object(ws_router, "manager", fake_manager))
        # Registered once: a context entered on the stack stays active
        # until the stack exits, so a second registration adds nothing.
        activate(_redis_empty_history_patch())
        activate(
            patch("app.features.conversation.ws.pipeline.manager", fake_manager)
        )
        activate(
            patch.object(
                ws_router, "background_runner", fake_manager.background_runner
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.router.check_ws_quota",
                new=AsyncMock(return_value=(True, "")),
            )
        )
        activate(
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.pipeline.process_user_message",
                process_user_message_mock,
            )
        )
        activate(patch.object(ws_router, "get_asr_provider", fake_asr_provider))
        activate(
            patch(
                "app.features.conversation.ws.pipeline.get_asr_provider",
                fake_asr_provider,
            )
        )

        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to the event loop while the patches are still active so any
        # work scheduled by the endpoint can finish before asserting.
        await asyncio.sleep(0.05)

    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "这是重连后的第 1 段",
    )
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
    """A re-sent missing segment is transcribed and forwarded exactly once.

    Setup: the fake DB holds unprocessed segments whose ``audio_url`` values
    end in ``:0`` and ``:2`` for ``voice-session-gap``, leaving index 1
    missing. The fake socket delivers ``segment_index`` 1 (``is_last`` is
    False) for that session and then disconnects.

    Expectation: the stubbed ASR ``transcribe`` is awaited once, and
    ``process_user_message`` is awaited once with the new transcript as its
    ``user_message`` keyword.

    NOTE(review): per the test name, this presumably guards the handler
    resuming from the contiguous prefix of stored segments rather than the
    highest stored index — confirm against the handler implementation.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    stored_segments = [
        Segment(
            id="seg-existing-0",
            conversation_id="conv-1",
            transcript_text="已存在的第 0 段",
            audio_url="audio-segment:voice-session-gap:0",
            processed=False,
        ),
        Segment(
            id="seg-existing-2",
            conversation_id="conv-1",
            transcript_text="已存在的第 2 段",
            audio_url="audio-segment:voice-session-gap:2",
            processed=False,
        ),
    ]
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=stored_segments,
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-gap-retry",
                    "voice_session_id": "voice-session-gap",
                    "client_segment_id": "voice-session-gap-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="补传后的第 1 段")

    def fake_asr_provider():
        # Stand-in for get_asr_provider(): only ``transcribe`` is stubbed.
        return SimpleNamespace(transcribe=transcribe_mock)

    with ExitStack() as stack:
        activate = stack.enter_context
        activate(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        # Each collaborator is replaced both on ws_router and at its
        # "app.features.conversation.ws.pipeline" path.
        activate(
            patch.object(
                ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.pipeline.AsyncSessionLocal",
                _session_local_factory(fake_db),
            )
        )
        activate(patch.object(ws_router, "manager", fake_manager))
        # Registered once: a context entered on the stack stays active
        # until the stack exits, so a second registration adds nothing.
        activate(_redis_empty_history_patch())
        activate(
            patch("app.features.conversation.ws.pipeline.manager", fake_manager)
        )
        activate(
            patch.object(
                ws_router, "background_runner", fake_manager.background_runner
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.router.check_ws_quota",
                new=AsyncMock(return_value=(True, "")),
            )
        )
        activate(
            patch.object(
                ws_router, "process_user_message", process_user_message_mock
            )
        )
        activate(
            patch(
                "app.features.conversation.ws.pipeline.process_user_message",
                process_user_message_mock,
            )
        )
        activate(patch.object(ws_router, "get_asr_provider", fake_asr_provider))
        activate(
            patch(
                "app.features.conversation.ws.pipeline.get_asr_provider",
                fake_asr_provider,
            )
        )

        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Yield to the event loop while the patches are still active so any
        # work scheduled by the endpoint can finish before asserting.
        await asyncio.sleep(0.05)

    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "补传后的第 1 段",
    )
    # The new segment must be sent to ASR exactly once.
    transcribe_mock.assert_awaited_once()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()