# life-echo/api/tests/test_pipeline_tts_cancel_emits_all_segments.py
"""WS pipeline 回归：TTS 取消不能丢段，AGENT_RESPONSE 必须为每段下发。

- 历史 bug：客户端在多段回复中途发送 ``tts_cancel`` 时，
  ``process_user_message`` 的 TTS 分支直接 ``break``，导致剩余段的 ``agent_response``
  被静默丢弃，FE 失去后续文本气泡，并可能停留在 "正在回复…" 状态。
- 期望：取消仅影响后续 TTS 合成，AGENT_RESPONSE 必须为每段完整下发。
- 同时校验：``responses`` 为空时下发 ERROR；异常路径同样下发 ERROR。
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.features.conversation.chat_turn import ChatTurnDecision, ChatTurnResult
from app.features.conversation.history_store import HumanAiTurnIds
from app.features.conversation.ws import pipeline as ws_pipeline
from app.features.conversation.ws.message_types import MessageType
class _FakeResult:
    """Minimal stand-in for the object returned by a DB ``execute()`` call.

    Exposes a single ``rowcount`` attribute, initialised to 0.
    """

    rowcount: int

    def __init__(self) -> None:
        self.rowcount = 0
class _FakeDb:
    """Smallest AsyncSession stand-in able to drive ``process_user_message``.

    ``execute``, ``commit`` and ``rollback`` are ``AsyncMock`` instances, so they
    can be awaited and their calls inspected afterwards. Every ``execute`` call
    resolves to the same ``_FakeResult`` object; ``commit``/``rollback`` resolve
    to ``None``.
    """

    def __init__(self) -> None:
        shared_result = _FakeResult()
        self.execute = AsyncMock(return_value=shared_result)
        self.commit = AsyncMock(return_value=None)
        self.rollback = AsyncMock(return_value=None)
def _make_segment(*, segment_id: str = "seg-1") -> MagicMock:
    """Build a Segment double exposing the fields the pipeline accesses.

    Nothing is persisted. ``id`` is ``segment_id``; the audio fields and
    ``created_at`` all start as ``None``.
    """
    # Mock(**kwargs) forwards to configure_mock, i.e. plain attribute assignment.
    return MagicMock(
        id=segment_id,
        audio_url=None,
        audio_duration_seconds=None,
        created_at=None,
    )
def _make_conversation(*, conversation_id: str) -> MagicMock:
    """Build a Conversation double with the given id and ``last_message_at=None``."""
    return MagicMock(id=conversation_id, last_message_at=None)
def _make_user(*, user_id: str = "user-1", language: str = "zh") -> SimpleNamespace:
    """Return a plain user object carrying only ``id`` and ``language_preference``."""
    fields = {"id": user_id, "language_preference": language}
    return SimpleNamespace(**fields)
def _ids_for(conversation_id: str) -> HumanAiTurnIds:
    """Derive deterministic human/assistant message ids from ``conversation_id``.

    Produces ``<conversation_id>-human`` and ``<conversation_id>-ai``.
    """
    human_id, assistant_id = (
        f"{conversation_id}-{role}" for role in ("human", "ai")
    )
    return HumanAiTurnIds(
        human_message_id=human_id,
        assistant_message_id=assistant_id,
    )
def _patch_common(monkeypatch: pytest.MonkeyPatch) -> tuple[list[dict], MagicMock]:
    """Stub the WS manager, the history store and inter-segment sleeps.

    Returns ``(sent_messages, fake_store)``:

    * ``sent_messages`` collects every payload passed to
      ``manager.send_message``, in send order;
    * ``fake_store`` is the ``MagicMock`` that ``ConversationHistoryStore(...)``
      now returns, so each test can configure
      ``record_human_ai_turn_with_segment`` itself.
    """
    captured: list[dict] = []

    async def _record(_conv_id: str, message: dict) -> None:
        captured.append(message)

    # Inter-segment sleeps need no real waiting in tests; keep the await semantics.
    async def _skip_sleep(_seconds: float) -> None:  # pragma: no cover - trivial
        return None

    manager = ws_pipeline.manager
    monkeypatch.setattr(manager, "send_message", _record)
    # Start every test from an empty connection registry.
    monkeypatch.setattr(manager, "active_connections", {})

    store_stub = MagicMock()
    store_stub.record_human_ai_turn_with_segment = AsyncMock()
    store_stub.attach_ai_tts_for_turn = AsyncMock(return_value=None)
    # The db argument is ignored: every construction yields the same stub.
    monkeypatch.setattr(
        ws_pipeline, "ConversationHistoryStore", lambda _db: store_stub
    )

    # NOTE(review): ``ws_pipeline.asyncio`` is presumably the stdlib module itself,
    # so this replaces ``asyncio.sleep`` process-wide until monkeypatch undoes it.
    monkeypatch.setattr(ws_pipeline.asyncio, "sleep", _skip_sleep)
    return captured, store_stub
@pytest.mark.asyncio
async def test_tts_cancel_mid_flight_still_emits_all_agent_response_segments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Key regression: a TTS cancel after segment 0 must not drop later segments.

    The fake synthesiser bumps the cancel epoch while handling chunk 0. The
    AGENT_RESPONSE messages for chunks 1 and 2 must still be sent, and no
    further TTS synthesis may be attempted.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conv_id = "conv-cancel-mid"
    ws_pipeline.manager.active_connections[conv_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conv_id)

    expected_texts = ["第一段", "第二段", "第三段"]
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=list(expected_texts),
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    synthesized_chunks: list[int] = []

    async def _fake_send_tts_audio(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        synthesized_chunks.append(chunk_index)
        if chunk_index != 0:
            return None
        # Client presses "cancel" right after chunk 0 finishes, so chunk 1
        # enters the loop under a newer cancel epoch.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return f"https://cos/{_conv_id}/seg-0.mp3"

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _fake_send_tts_audio)

    await ws_pipeline.process_user_message(
        conversation_id=conv_id,
        user_message="说说你小时候",
        conversation=_make_conversation(conversation_id=conv_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    payloads = [
        msg["data"]
        for msg in sent_messages
        if msg["type"] == MessageType.AGENT_RESPONSE
    ]
    assert [p["text"] for p in payloads] == expected_texts, (
        "TTS 取消后剩余段的 AGENT_RESPONSE 必须仍然下发"
    )
    assert [p["index"] for p in payloads] == [0, 1, 2]
    assert all(p["total"] == 3 for p in payloads)
    # Chunk 0 was synthesised; after the cancel, chunks 1/2 must not trigger TTS.
    assert synthesized_chunks == [0]
@pytest.mark.asyncio
async def test_tts_cancel_before_any_segment_still_emits_agent_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Edge case: TTS is cancelled while segment 0 is being synthesised.

    The single segment's AGENT_RESPONSE must still be delivered with the
    correct text, index and total.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conv_id = "conv-cancel-pre"
    ws_pipeline.manager.active_connections[conv_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conv_id)

    single_segment_turn = ChatTurnResult(
        messages=["唯一段"],
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(return_value=single_segment_turn),
    )
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    async def _fake_tts_then_cancel(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,  # noqa: ARG001
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        # The cancel lands during synthesis and no audio URL is produced.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _fake_tts_then_cancel)

    await ws_pipeline.process_user_message(
        conversation_id=conv_id,
        user_message="",
        conversation=_make_conversation(conversation_id=conv_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    responses = [
        msg for msg in sent_messages if msg["type"] == MessageType.AGENT_RESPONSE
    ]
    assert len(responses) == 1
    data = responses[0]["data"]
    assert data["text"] == "唯一段"
    assert data["index"] == 0
    assert data["total"] == 1
@pytest.mark.asyncio
async def test_empty_responses_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A turn that yields no reply segments must end with exactly one ERROR.

    Without that terminal ERROR the FE would stay stuck in its "正在回复…"
    (replying) state. With no segments, no AGENT_RESPONSE may be emitted.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conversation_id = "conv-empty"
    ws_pipeline.manager.active_connections[conversation_id] = object()
    # The store reports no recorded turn ids for this turn.
    fake_store.record_human_ai_turn_with_segment.return_value = None
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=[],
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )
    db = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="x",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )
    error_messages = [m for m in sent_messages if m["type"] == MessageType.ERROR]
    assert len(error_messages) == 1
    agent_responses = [
        m for m in sent_messages if m["type"] == MessageType.AGENT_RESPONSE
    ]
    assert agent_responses == []
@pytest.mark.asyncio
async def test_process_turn_exception_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An LLM/orchestration failure must end the turn with exactly one ERROR.

    Without that terminal ERROR the FE would stay stuck in its "正在回复…"
    (replying) state. Since ``process_turn`` raises before producing any
    segment, no AGENT_RESPONSE may be emitted either.
    """
    sent_messages, _ = _patch_common(monkeypatch)
    conversation_id = "conv-boom"
    ws_pipeline.manager.active_connections[conversation_id] = object()

    async def _boom(*_args, **_kwargs):
        raise RuntimeError("upstream blew up")

    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", _boom)
    db = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="y",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )
    error_messages = [m for m in sent_messages if m["type"] == MessageType.ERROR]
    assert len(error_messages) == 1
    agent_responses = [
        m for m in sent_messages if m["type"] == MessageType.AGENT_RESPONSE
    ]
    assert agent_responses == []
@pytest.mark.asyncio
async def test_tts_disabled_emits_all_segments_without_tts_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With TTS off for this turn, no synthesis runs but every segment is sent.

    The global ``enable_tts`` switch is forced on so that ``tts_this_turn=False``
    is the only thing that can suppress synthesis. Otherwise the test would pass
    vacuously whenever the environment ships with TTS disabled.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conversation_id = "conv-text-only"
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id)
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=["A", "B"],
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )
    # Pin the global switch on; only the per-turn flag should disable TTS here.
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)
    tts_calls: list[int] = []

    async def _should_not_be_called(*_args, **_kwargs):
        tts_calls.append(1)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _should_not_be_called)
    db = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="hi",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )
    agent_responses = [
        m for m in sent_messages if m["type"] == MessageType.AGENT_RESPONSE
    ]
    assert [m["data"]["text"] for m in agent_responses] == ["A", "B"]
    assert [m["data"]["index"] for m in agent_responses] == [0, 1]
    assert all(m["data"]["total"] == 2 for m in agent_responses)
    assert tts_calls == []