Files
life-echo/api/tests/test_pipeline_tts_cancel_emits_all_segments.py
Kevin ccdc4e4277 feat(i18n): persist language preference and thread through chat, memoir, TTS
- Add users.language_preference (Alembic 0018, default zh); capture at signup/SMS
  only; expose on auth and profile APIs
- Lite English prompts for chat and memoir; localized stage labels and agent
  names (Life Echo / 岁月知己)
- Tencent TTS: language-aware synthesis, ModelType=1 for 501004, English chunking
- WebSocket pipeline: emit all AGENT_RESPONSE segments when TTS cancels; INFO logs
  for tts_this_turn and TTS decisions; on-demand TTS logging
- Expo: device language on auth, i18n tiers/agent name, [SPLIT] streaming UX fixes
- Tests for migration, prompts, pipeline, router tts_this_turn, reply segments

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 16:16:49 +08:00

332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""WS pipeline 回归TTS 取消不能丢段AGENT_RESPONSE 必须为每段下发。
- 历史 bug客户端在多段回复中途发送 ``tts_cancel`` 时,
``process_user_message`` 在 TTS 分支 ``break``,导致剩余段的 ``agent_response``
被静默丢弃FE 失去后续文本气泡,并可能停留在 "正在回复…" 状态。
- 期望:取消仅影响后续 TTS 合成AGENT_RESPONSE 必须为每段完整下发。
- 同时校验:``responses`` 为空时下发 ERROR异常路径下发 ERROR。
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.features.conversation.chat_turn import ChatTurnDecision, ChatTurnResult
from app.features.conversation.history_store import HumanAiTurnIds
from app.features.conversation.ws import pipeline as ws_pipeline
from app.features.conversation.ws.message_types import MessageType
class _FakeResult:
def __init__(self) -> None:
self.rowcount = 0
class _FakeDb:
    """Smallest AsyncSession stand-in that is enough to drive process_user_message."""

    def __init__(self) -> None:
        # execute() resolves to a _FakeResult; commit()/rollback() resolve to None.
        self.execute = AsyncMock(return_value=_FakeResult())
        for method_name in ("commit", "rollback"):
            setattr(self, method_name, AsyncMock(return_value=None))
def _make_segment(*, segment_id: str = "seg-1") -> MagicMock:
"""构造与 pipeline 字段访问对齐的 Segment 替身(不真正落库)。"""
seg = MagicMock()
seg.id = segment_id
seg.audio_url = None
seg.audio_duration_seconds = None
seg.created_at = None
return seg
def _make_conversation(*, conversation_id: str) -> MagicMock:
conv = MagicMock()
conv.id = conversation_id
conv.last_message_at = None
return conv
def _make_user(*, user_id: str = "user-1", language: str = "zh") -> SimpleNamespace:
return SimpleNamespace(id=user_id, language_preference=language)
def _ids_for(conversation_id: str) -> HumanAiTurnIds:
    """Return a HumanAiTurnIds pair whose ids are derived from *conversation_id*.

    The human id is ``<conversation_id>-human`` and the assistant id is
    ``<conversation_id>-ai``.
    """
    human_id = f"{conversation_id}-human"
    assistant_id = f"{conversation_id}-ai"
    return HumanAiTurnIds(
        human_message_id=human_id,
        assistant_message_id=assistant_id,
    )
def _patch_common(monkeypatch: pytest.MonkeyPatch) -> tuple[list[dict], MagicMock]:
    """Mock the persistence layer and the inter-segment sleep in one place.

    Patches applied on ``ws_pipeline``:
      * ``manager.send_message`` records every outgoing message instead of sending it.
      * ``manager.active_connections`` is replaced by an empty dict.
      * ``ConversationHistoryStore`` becomes a factory returning one shared fake store.
      * ``asyncio.sleep`` becomes an awaitable no-op.

    Returns:
        ``(sent_messages, fake_store)``: the list that captures messages passed to
        ``manager.send_message``, and the fake history store.
    """
    captured: list[dict] = []

    async def _record_outgoing(_conv_id: str, message: dict) -> None:
        captured.append(message)

    # No real waiting between segments in tests; the call stays awaitable.
    async def _skip_sleep(_seconds: float) -> None:  # pragma: no cover - trivial
        return None

    store_double = MagicMock(
        record_human_ai_turn=AsyncMock(),
        attach_ai_tts_audio_urls=AsyncMock(return_value=None),
    )

    def _store_factory(_db):
        return store_double

    for target, attribute, replacement in (
        (ws_pipeline.manager, "send_message", _record_outgoing),
        (ws_pipeline.manager, "active_connections", {}),
        (ws_pipeline, "ConversationHistoryStore", _store_factory),
        (ws_pipeline.asyncio, "sleep", _skip_sleep),
    ):
        monkeypatch.setattr(target, attribute, replacement)

    return captured, store_double
@pytest.mark.asyncio
async def test_tts_cancel_mid_flight_still_emits_all_agent_response_segments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Key regression: a cancel right after segment 0's TTS must not drop segments 1 and 2.

    The fake ``_send_tts_audio`` bumps the TTS cancel epoch while handling chunk 0,
    imitating a client cancel between segments. All three AGENT_RESPONSE messages
    must still be captured, and TTS must only have been requested for chunk 0.
    """
    conversation_id = "conv-cancel-mid"
    expected_texts = ["第一段", "第二段", "第三段"]

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id)

    three_segment_turn = ChatTurnResult(
        # Pass a copy so the expectation list is independent of the pipeline.
        messages=list(expected_texts),
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    process_turn_stub = AsyncMock(return_value=three_segment_turn)
    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", process_turn_stub)
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    synthesized_chunks: list[int] = []

    async def _tts_cancelled_after_first_chunk(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        synthesized_chunks.append(chunk_index)
        if chunk_index != 0:
            return None
        # The client cancels once chunk 0 is synthesized, so the loop reaches
        # chunk 1 under a newer epoch.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return f"https://cos/{_conv_id}/seg-0.mp3"

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _tts_cancelled_after_first_chunk)

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="说说你小时候",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    payloads = [
        message["data"]
        for message in sent_messages
        if message["type"] == MessageType.AGENT_RESPONSE
    ]
    assert [
        payload["text"] for payload in payloads
    ] == expected_texts, "TTS 取消后剩余段的 AGENT_RESPONSE 必须仍然下发"
    assert [payload["index"] for payload in payloads] == [0, 1, 2]
    assert all(payload["total"] == 3 for payload in payloads)
    # Chunk 0 was synthesized; after the cancel no TTS may start for chunks 1 and 2.
    assert synthesized_chunks == [0]
@pytest.mark.asyncio
async def test_tts_cancel_before_any_segment_still_emits_agent_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Edge case: TTS is cancelled as segment 0 starts, yet its AGENT_RESPONSE must be sent.

    The fake ``_send_tts_audio`` bumps the cancel epoch and returns ``None`` for the
    only segment; exactly one AGENT_RESPONSE with index 0 of 1 is expected.
    """
    conversation_id = "conv-cancel-pre"

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id)

    single_segment_turn = ChatTurnResult(
        messages=["唯一段"],
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    process_turn_stub = AsyncMock(return_value=single_segment_turn)
    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", process_turn_stub)
    monkeypatch.setattr(ws_pipeline.settings, "enable_tts", True)

    async def _cancel_during_tts(
        _conv_id: str,
        _text: str,
        *,
        chunk_index: int,  # noqa: ARG001
        chunk_total: int,  # noqa: ARG001
        assistant_message_id: str | None,  # noqa: ARG001
        tts_epoch_start: int,  # noqa: ARG001
        manual: bool = False,  # noqa: ARG001
        language: str = "zh",  # noqa: ARG001
    ) -> str | None:
        # Bump the cancel epoch and produce no audio URL.
        ws_pipeline.bump_tts_cancel_epoch(_conv_id)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _cancel_during_tts)

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=True,
    )

    agent_responses = [
        message
        for message in sent_messages
        if message["type"] == MessageType.AGENT_RESPONSE
    ]
    assert len(agent_responses) == 1
    payload = agent_responses[0]["data"]
    assert payload["text"] == "唯一段"
    assert payload["index"] == 0
    assert payload["total"] == 1
@pytest.mark.asyncio
async def test_empty_responses_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An ERROR must be sent when there is no reply.

    Without it the FE would stay stuck on the "正在回复…" (replying) state.
    ``process_turn`` is stubbed to return a result with an empty ``messages`` list,
    and exactly one ERROR message is expected among the captured messages.
    """
    sent_messages, fake_store = _patch_common(monkeypatch)
    conversation_id = "conv-empty"
    ws_pipeline.manager.active_connections[conversation_id] = object()
    # The fake store's record_human_ai_turn resolves to None in this scenario.
    fake_store.record_human_ai_turn.return_value = None

    monkeypatch.setattr(
        ws_pipeline.chat_turn_service,
        "process_turn",
        AsyncMock(
            return_value=ChatTurnResult(
                messages=[],
                skip_tts=False,
                decision=ChatTurnDecision(),
            )
        ),
    )

    db = _FakeDb()
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="x",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )

    error_messages = [m for m in sent_messages if m["type"] == MessageType.ERROR]
    assert len(error_messages) == 1
@pytest.mark.asyncio
async def test_process_turn_exception_emits_terminal_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An ERROR must be sent when the LLM/orchestration step raises.

    Without it the FE would stay stuck on the "正在回复…" (replying) state.
    ``process_turn`` is replaced by a coroutine that raises ``RuntimeError``, and
    exactly one ERROR message is expected among the captured messages.
    """
    sent_messages, _ = _patch_common(monkeypatch)
    conversation_id = "conv-boom"
    ws_pipeline.manager.active_connections[conversation_id] = object()

    async def _boom(*_args, **_kwargs):
        raise RuntimeError("upstream blew up")

    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", _boom)

    db = _FakeDb()
    # The call itself must return normally; the failure is expected to surface
    # only as an ERROR message.
    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="y",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=db,
        user=_make_user(),
        tts_this_turn=False,
    )

    error_messages = [m for m in sent_messages if m["type"] == MessageType.ERROR]
    assert len(error_messages) == 1
@pytest.mark.asyncio
async def test_tts_disabled_emits_all_segments_without_tts_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With TTS off for this turn no synthesis is called, yet every AGENT_RESPONSE is sent.

    ``process_user_message`` is invoked with ``tts_this_turn=False``; the fake
    ``_send_tts_audio`` records any invocation so the test can assert there were none.
    """
    conversation_id = "conv-text-only"

    sent_messages, fake_store = _patch_common(monkeypatch)
    ws_pipeline.manager.active_connections[conversation_id] = object()
    fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id)

    two_segment_turn = ChatTurnResult(
        messages=["A", "B"],
        skip_tts=False,
        decision=ChatTurnDecision(),
    )
    process_turn_stub = AsyncMock(return_value=two_segment_turn)
    monkeypatch.setattr(ws_pipeline.chat_turn_service, "process_turn", process_turn_stub)

    unexpected_tts_calls: list[int] = []

    async def _record_unexpected_tts(*_args, **_kwargs):
        # Any call here is a failure; it is recorded and checked at the end.
        unexpected_tts_calls.append(1)
        return None

    monkeypatch.setattr(ws_pipeline, "_send_tts_audio", _record_unexpected_tts)

    await ws_pipeline.process_user_message(
        conversation_id=conversation_id,
        user_message="hi",
        conversation=_make_conversation(conversation_id=conversation_id),
        segment=_make_segment(),
        db=_FakeDb(),
        user=_make_user(),
        tts_this_turn=False,
    )

    emitted_texts = [
        message["data"]["text"]
        for message in sent_messages
        if message["type"] == MessageType.AGENT_RESPONSE
    ]
    assert emitted_texts == ["A", "B"]
    assert unexpected_tts_calls == []