Files
life-echo/api/tests/test_pipeline_tts_cancel_emits_all_segments.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""WS pipeline 回归: TTS 取消不能丢段, AGENT_RESPONSE 必须为每段下发。

- 历史 bug: 客户端在多段回复中途发送 ``tts_cancel`` 时,
  ``process_user_message`` 在 TTS 分支 ``break``, 导致剩余段的 ``agent_response``
  被静默丢弃; FE 失去后续文本气泡, 并可能停留在 "正在回复…" 状态。
- 期望: 取消仅影响后续 TTS 合成, AGENT_RESPONSE 必须为每段完整下发。
- 同时校验: ``responses`` 为空时下发 ERROR; 异常路径下发 ERROR。
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.features.conversation.chat_turn import ChatTurnDecision, ChatTurnResult
from app.features.conversation.history_store import HumanAiTurnIds
from app.features.conversation.ws import pipeline as ws_pipeline
from app.features.conversation.ws.message_types import MessageType
class _FakeResult:
def __init__(self) -> None:
self.rowcount = 0
class _FakeDb:
    """Smallest ``AsyncSession`` stand-in able to drive ``process_user_message``.

    ``execute`` resolves to a single shared ``_FakeResult``; ``commit`` and
    ``rollback`` resolve to ``None``. All three are ``AsyncMock`` instances, so
    tests can inspect how they were awaited.
    """

    def __init__(self) -> None:
        self.execute = AsyncMock(return_value=_FakeResult())
        for method_name in ("commit", "rollback"):
            setattr(self, method_name, AsyncMock(return_value=None))
def _make_segment(*, segment_id: str = "seg-1") -> MagicMock:
"""构造与 pipeline 字段访问对齐的 Segment 替身(不真正落库)。"""
seg = MagicMock()
seg.id = segment_id
seg.audio_url = None
seg.audio_duration_seconds = None
seg.created_at = None
return seg
def _make_conversation(*, conversation_id: str) -> MagicMock:
conv = MagicMock()
conv.id = conversation_id
conv.last_message_at = None
return conv
def _make_user(*, user_id: str = "user-1", language: str = "zh") -> SimpleNamespace:
return SimpleNamespace(id=user_id, language_preference=language)
def _ids_for(conversation_id: str) -> HumanAiTurnIds:
    """Build deterministic human/assistant message ids derived from *conversation_id*."""
    human_id, assistant_id = (
        f"{conversation_id}-{suffix}" for suffix in ("human", "ai")
    )
    return HumanAiTurnIds(
        human_message_id=human_id,
        assistant_message_id=assistant_id,
    )
def _patch_common(monkeypatch: pytest.MonkeyPatch) -> tuple[list[dict], MagicMock]:
    """Mock the persistence layer and the inter-segment sleep shared by every test.

    Returns:
        ``(sent_messages, fake_store)``: every dict passed to
        ``manager.send_message`` in call order, and the ``MagicMock`` returned by
        the patched ``ConversationHistoryStore(db)`` factory.
    """
    sent_messages: list[dict] = []

    async def _capture_send(_conv_id: str, message: dict) -> None:
        sent_messages.append(message)

    monkeypatch.setattr(ws_pipeline.manager, "send_message", _capture_send)
    # Fresh connection registry per test; monkeypatch restores the original on teardown.
    monkeypatch.setattr(ws_pipeline.manager, "active_connections", {})

    fake_store = MagicMock()
    fake_store.record_human_ai_turn_with_segment = AsyncMock()
    fake_store.attach_ai_tts_for_turn = AsyncMock(return_value=None)
    monkeypatch.setattr(
        ws_pipeline, "ConversationHistoryStore", lambda _db: fake_store
    )

    # Skip the real inter-segment delay, but still yield to the event loop once so
    # the stub keeps the scheduling behavior of a real ``await asyncio.sleep(...)``
    # (the previous stub returned immediately and never yielded).
    # NOTE(review): ``ws_pipeline.asyncio`` is presumably the stdlib ``asyncio``
    # module itself, in which case this patch is process-wide for the test's duration.
    real_sleep = ws_pipeline.asyncio.sleep

    async def _no_sleep(_seconds: float, result: object = None) -> object:
        await real_sleep(0)
        return result

    monkeypatch.setattr(ws_pipeline.asyncio, "sleep", _no_sleep)
    return sent_messages, fake_store
@pytest.mark.asyncio
async def test_tts_cancel_mid_flight_still_emits_all_agent_response_segments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Key regression: a cancel after segment 0's TTS must not drop later segments.

    The client cancels once segment 0 has been synthesized. AGENT_RESPONSE for
    segments 1 and 2 must still be delivered, while their TTS is not requested.
    """
    conversation_id = "conv-cancel-mid"
    segment_texts = ("第一段", "第二段", "第三段")

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id)

    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=list(segment_texts),
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    synthesized_indices: list[int] = []

    async def _fake_send_tts_audio(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        synthesized_indices.append(chunk_index)
        if chunk_index != 0:
            return None
        # Segment 0 finished synthesizing and the client presses cancel; bumping the
        # epoch here means segment 1 enters the loop under a new epoch.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return f"https://cos/{_conv_id}/seg-0.mp3"

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _fake_send_tts_audio)

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="说说你小时候",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    payloads = [
        msg["data"]
        for msg in sent_messages
        if msg["type"] == MessageType.AGENT_RESPONSE
    ]
    assert [p["text"] for p in payloads] == list(segment_texts), "TTS 取消后剩余段的 AGENT_RESPONSE 必须仍然下发"
    assert [p["index"] for p in payloads] == list(range(len(segment_texts)))
    assert all(p["total"] == len(segment_texts) for p in payloads)
    # Only segment 0 was synthesized; after the cancel, segments 1 and 2 get no TTS.
    assert synthesized_indices == [0]
@pytest.mark.asyncio
async def test_tts_cancel_before_any_segment_still_emits_agent_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Edge case: TTS is cancelled as soon as segment 0 starts; its AGENT_RESPONSE still goes out."""
    conversation_id = "conv-cancel-pre"

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id)

    single_segment = ChatTurnResult(
        messages=["唯一段"],
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(return_value=single_segment),
    )
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    async def _cancel_during_tts(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,  # noqa: ARG001
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        # The cancel lands while segment 0 is being synthesized; no audio URL results.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _cancel_during_tts)

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    payloads = [
        msg["data"]
        for msg in sent_messages
        if msg["type"] == MessageType.AGENT_RESPONSE
    ]
    assert len(payloads) == 1
    (only_payload,) = payloads
    assert only_payload["text"] == "唯一段"
    assert only_payload["index"] == 0
    assert only_payload["total"] == 1
@pytest.mark.asyncio
async def test_empty_responses_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An empty reply must emit exactly one ERROR, otherwise the FE stays stuck on "replying…"."""
    conversation_id = "conv-empty"

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    # Persistence yields no turn ids for this turn.
    fake_store.record_human_ai_turn_with_segment.return_value = None

    empty_result = ChatTurnResult(
        messages=[],
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(return_value=empty_result),
    )

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="x",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=False,
    )

    error_count = sum(1 for msg in sent_messages if msg["type"] == MessageType.ERROR)
    assert error_count == 1
@pytest.mark.asyncio
async def test_process_turn_exception_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An LLM/orchestration failure must emit one ERROR, otherwise the FE stays stuck on "replying…"."""
    conversation_id = "conv-boom"

    sent_messages, _unused_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()

    async def _raise_upstream_failure(*_args, **_kwargs):
        raise RuntimeError("upstream blew up")

    monkeypatch.setattr(
        ws_pipeline.chat_turn_service, "process_turn", _raise_upstream_failure
    )

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="y",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=False,
    )

    error_count = sum(1 for msg in sent_messages if msg["type"] == MessageType.ERROR)
    assert error_count == 1
@pytest.mark.asyncio
async def test_tts_disabled_emits_all_segments_without_tts_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With TTS off for this turn, nothing is synthesized but every AGENT_RESPONSE is sent.

    ``settings.enable_tts`` is pinned to ``True`` so that the per-turn
    ``tts_this_turn=False`` flag is the only thing gating synthesis. Without the
    pin, the "no TTS calls" assertion would pass vacuously whenever the ambient
    configuration disables TTS globally.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conversation_id = "conv-text-only"
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id)
    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=["A", "B"],
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )
    # Globally enabled, so only the per-turn flag below can suppress synthesis.
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)
    tts_spy = AsyncMock(return_value=None)
    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", tts_spy)

    db = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="hi",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )

    agent_responses = [
        m for m in sent_messages if m["type"] == MessageType.AGENT_RESPONSE
    ]
    assert [m["data"]["text"] for m in agent_responses] == ["A", "B"]
    assert [m["data"]["index"] for m in agent_responses] == [0, 1]
    assert all(m["data"]["total"] == 2 for m in agent_responses)
    tts_spy.assert_not_called()