"""WS pipeline regression: a TTS cancel must not drop reply segments.

- Historical bug: when the client sent ``tts_cancel`` in the middle of a
  multi-segment reply, ``process_user_message`` hit a ``break`` in its TTS
  branch, so the ``agent_response`` of every remaining segment was silently
  dropped; the FE lost the following text bubbles and could stay stuck in
  the "replying…" state.
- Expected: a cancel only affects later TTS synthesis; AGENT_RESPONSE must
  still be sent, in full, for every segment.
- Also verified: an ERROR is sent when ``responses`` is empty, and an ERROR
  is sent on the exception path.
"""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.features.conversation.chat_turn import ChatTurnDecision, ChatTurnResult
from app.features.conversation.history_store import HumanAiTurnIds
from app.features.conversation.ws import pipeline as ws_pipeline
from app.features.conversation.ws.message_types import MessageType


class _FakeResult:
    """Placeholder for the value returned by ``_FakeDb.execute``."""

    def __init__(self) -> None:
        self.rowcount = 0


class _FakeDb:
    """Minimal AsyncSession double, just enough to drive process_user_message."""

    def __init__(self) -> None:
        self.execute = AsyncMock(return_value=_FakeResult())
        # commit / rollback are awaitable no-ops that resolve to None.
        for operation in ("commit", "rollback"):
            setattr(self, operation, AsyncMock(return_value=None))


def _make_segment(*, segment_id: str = "seg-1") -> MagicMock:
    """Build a Segment double exposing the fields set here (nothing is persisted)."""
    return MagicMock(
        id=segment_id,
        audio_url=None,
        audio_duration_seconds=None,
        created_at=None,
    )


def _make_conversation(*, conversation_id: str) -> MagicMock:
    """Build a conversation double with the given id and no last-message time."""
    return MagicMock(id=conversation_id, last_message_at=None)


def _make_user(*, user_id: str = "user-1", language: str = "zh") -> SimpleNamespace:
    """Build a lightweight user object carrying an id and a language preference."""
    user = SimpleNamespace()
    user.id = user_id
    user.language_preference = language
    return user


def _ids_for(conversation_id: str) -> HumanAiTurnIds:
    """Derive deterministic human/assistant message ids from the conversation id."""
    human_id = f"{conversation_id}-human"
    assistant_id = f"{conversation_id}-ai"
    return HumanAiTurnIds(
        human_message_id=human_id,
        assistant_message_id=assistant_id,
    )


def _messages_of_type(messages: list[dict], message_type) -> list[dict]:
    """Return, in order, the captured messages whose ``type`` equals *message_type*."""
    return [m for m in messages if m["type"] == message_type]


def _stub_process_turn(monkeypatch: pytest.MonkeyPatch, messages: list[str]) -> None:
    """Make ``chat_turn_service.process_turn`` resolve to a canned ChatTurnResult."""
    canned = ChatTurnResult(
        messages=messages,
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(return_value=canned),
    )


async def _drive_turn(
    conversation_id: str, user_message: str, *, tts_this_turn: bool
) -> None:
    """Run ``process_user_message`` once against fresh doubles for this conversation."""
    session = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message=user_message,
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=session,
        user=_make_user(),
        tts_this_turn=tts_this_turn,
    )


def _patch_common(monkeypatch: pytest.MonkeyPatch) -> tuple[list[dict], MagicMock]:
    """Mock the persistence layer and the inter-segment sleep in one place.

    Returns the list that captures every ``manager.send_message`` payload,
    together with the fake history store.
    """
    captured: list[dict] = []

    async def _capture_send(_conv_id: str, message: dict) -> None:
        captured.append(message)

    # Tests do not need to really wait between segments; stay awaitable.
    # NOTE(review): ws_pipeline.asyncio is presumably the shared asyncio module,
    # so this replaces asyncio.sleep for everything running during the test
    # until monkeypatch undoes it — confirm nothing relies on a real sleep.
    async def _no_sleep(_seconds: float) -> None:  # pragma: no cover - trivial
        return None

    store = MagicMock()
    store.attach_ai_tts_for_turn = AsyncMock(return_value=None)
    store.record_human_ai_turn_with_segment = AsyncMock()

    def _store_factory(_db):
        # Every ConversationHistoryStore built by the pipeline is the same fake.
        return store

    replacements = (
        (ws_pipeline.manager, "send_message", _capture_send),
        (ws_pipeline.manager, "active_connections", {}),
        (ws_pipeline, "ConversationHistoryStore", _store_factory),
        (ws_pipeline.asyncio, "sleep", _no_sleep),
    )
    for owner, attribute, replacement in replacements:
        monkeypatch.setattr(owner, attribute, replacement)

    return captured, store


@pytest.mark.asyncio
async def test_tts_cancel_mid_flight_still_emits_all_agent_response_segments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Key regression: cancel after i=0 finished TTS; i=1/i=2 must still be sent."""
    captured, store = _patch_common(monkeypatch)
    conv_id = "conv-cancel-mid"
    store.record_human_ai_turn_with_segment.return_value = _ids_for(conv_id)
    ws_pipeline.manager.active_connections[conv_id] = object()

    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)
    _stub_process_turn(monkeypatch, ["第一段", "第二段", "第三段"])

    synthesized: list[int] = []

    async def _fake_send_tts_audio(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        synthesized.append(chunk_index)
        if chunk_index != 0:
            return None
        # The client cancels right after segment 0 was synthesized, so the
        # epoch is already a new one when the loop reaches segment 1.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return f"https://cos/{_conv_id}/seg-0.mp3"

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _fake_send_tts_audio)

    await _drive_turn(conv_id, "说说你小时候", tts_this_turn=True)

    replies = _messages_of_type(captured, MessageType.AGENT_RESPONSE)
    assert [m["data"]["text"] for m in replies] == [
        "第一段",
        "第二段",
        "第三段",
    ], "TTS 取消后剩余段的 AGENT_RESPONSE 必须仍然下发"
    assert [m["data"]["index"] for m in replies] == [0, 1, 2]
    assert all(m["data"]["total"] == 3 for m in replies)
    # i=0 was synthesized; after the cancel no TTS for i=1 / i=2 may start.
    assert synthesized == [0]


@pytest.mark.asyncio
async def test_tts_cancel_before_any_segment_still_emits_agent_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Edge case: TTS is cancelled right at segment 0; its AGENT_RESPONSE is still sent."""
    captured, store = _patch_common(monkeypatch)
    conv_id = "conv-cancel-pre"
    store.record_human_ai_turn_with_segment.return_value = _ids_for(conv_id)
    ws_pipeline.manager.active_connections[conv_id] = object()

    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)
    _stub_process_turn(monkeypatch, ["唯一段"])

    async def _fake_tts_then_cancel(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,  # noqa: ARG001
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _fake_tts_then_cancel)

    await _drive_turn(conv_id, "嗯", tts_this_turn=True)

    replies = _messages_of_type(captured, MessageType.AGENT_RESPONSE)
    assert len(replies) == 1
    payload = replies[0]["data"]
    assert payload["text"] == "唯一段"
    assert payload["index"] == 0
    assert payload["total"] == 1


@pytest.mark.asyncio
async def test_empty_responses_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no reply an ERROR must be sent, else the FE stays stuck on "replying…"."""
    captured, store = _patch_common(monkeypatch)
    conv_id = "conv-empty"
    store.record_human_ai_turn_with_segment.return_value = None
    ws_pipeline.manager.active_connections[conv_id] = object()

    _stub_process_turn(monkeypatch, [])

    await _drive_turn(conv_id, "x", tts_this_turn=False)

    errors = _messages_of_type(captured, MessageType.ERROR)
    assert len(errors) == 1


@pytest.mark.asyncio
async def test_process_turn_exception_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the LLM/orchestration raises, an ERROR must be sent to unblock the FE."""
    captured, _ = _patch_common(monkeypatch)
    conv_id = "conv-boom"
    ws_pipeline.manager.active_connections[conv_id] = object()

    async def _boom(*_args, **_kwargs):
        raise RuntimeError("upstream blew up")

    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", _boom)

    await _drive_turn(conv_id, "y", tts_this_turn=False)

    errors = _messages_of_type(captured, MessageType.ERROR)
    assert len(errors) == 1


@pytest.mark.asyncio
async def test_tts_disabled_emits_all_segments_without_tts_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without TTS this turn no synthesis runs, yet every AGENT_RESPONSE is sent."""
    captured, store = _patch_common(monkeypatch)
    conv_id = "conv-text-only"
    store.record_human_ai_turn_with_segment.return_value = _ids_for(conv_id)
    ws_pipeline.manager.active_connections[conv_id] = object()

    _stub_process_turn(monkeypatch, ["A", "B"])

    tts_calls: list[int] = []

    async def _should_not_be_called(*_args, **_kwargs):
        # Record any unexpected synthesis attempt so the final assert catches it.
        tts_calls.append(1)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _should_not_be_called)

    await _drive_turn(conv_id, "hi", tts_this_turn=False)

    replies = _messages_of_type(captured, MessageType.AGENT_RESPONSE)
    assert [m["data"]["text"] for m in replies] == ["A", "B"]
    assert tts_calls == []