Files
life-echo/api/tests/test_ws_pipeline_transactional.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

238 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""WS pipeline / history_store 原子持久化边界memoir 调度顺序。"""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.auth import models as _auth_models # noqa: F401
from app.features.conversation import models as _conv_models # noqa: F401
from app.features.memory import models as _memory_models # noqa: F401
from app.features.memoir import models as _memoir_models # noqa: F401
from app.features.payment import models as _payment_models # noqa: F401
from app.features.story import models as _story_models # noqa: F401
from app.features.user import models as _user_models # noqa: F401
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.models import ConversationMessage, Segment
from app.features.conversation.ws import persist
@asynccontextmanager
async def _capture_transactional(db):
    """Fake ``transactional`` helper: hand back ``db`` and commit it once on clean exit.

    The session object is yielded unchanged. ``commit()`` is awaited only when
    the ``async with`` body finishes without raising. If the body raises, the
    exception propagates and no commit happens. This fake never calls
    ``rollback()``.
    """
    session = db
    yield session
    # Only reached when the managed block completed normally.
    await session.commit()
@pytest.mark.asyncio
async def test_persist_message_tts_url_segment_commits_once() -> None:
    """Persisting one TTS URL onto an AI message records it and commits exactly once."""
    session = MagicMock(spec=AsyncSession)
    session.commit = AsyncMock()
    message = ConversationMessage(
        id="msg-1",
        conversation_id="conv-1",
        role="ai",
        content="hi",
        tts_audio_urls=[],
    )

    # Swap the real transaction helper for the commit-on-exit fake.
    with patch("app.features.conversation.ws.persist.transactional", _capture_transactional):
        await persist.persist_message_tts_url_segment(
            session, message, 0, "https://cos/0.mp3"
        )

    assert message.tts_audio_urls == ["https://cos/0.mp3"]
    session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_persist_voice_segment_row_commits_segment_and_activity() -> None:
    """Persisting a voice segment adds it, stamps conversation activity, and commits once."""
    from app.features.conversation.models import Conversation

    session = MagicMock(spec=AsyncSession)
    session.add = MagicMock()
    session.commit = AsyncMock()

    conversation = Conversation(id="conv-1", user_id="u1")
    voice_segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )

    # Swap the real transaction helper for the commit-on-exit fake.
    with patch("app.features.conversation.ws.persist.transactional", _capture_transactional):
        await persist.persist_voice_segment_row(session, voice_segment, conversation)

    session.add.assert_called_once_with(voice_segment)
    assert conversation.last_message_at is not None
    session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_record_human_ai_turn_with_segment_single_commit() -> None:
    """Recording a human/AI turn bound to a segment commits once and back-fills the segment.

    The store's ``transactional`` is replaced by a fake that counts commits.
    The conversation touch, the Redis sync and the message insert are
    patched out, so the test exercises only the store's own sequencing:
    one flush, one commit, and the segment's reply and lineage fields.
    """
    session = MagicMock(spec=AsyncSession)
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.flush = AsyncMock()

    # One entry is appended per successful commit made through the fake.
    commits = []

    @asynccontextmanager
    async def counting_transactional(sess):
        yield sess
        await sess.commit()
        commits.append(None)

    store = ConversationHistoryStore(session)
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hi",
        processed=False,
    )

    with (
        patch("app.features.conversation.history_store.transactional", counting_transactional),
        patch.object(store, "_touch_conversation", AsyncMock()),
        patch.object(store, "_sync_redis_best_effort", AsyncMock()),
        patch("app.features.conversation.history_store.repo.add_conversation_message"),
    ):
        turn_ids = await store.record_human_ai_turn_with_segment(
            "conv-1",
            "hello",
            ["reply"],
            segment,
            user_message_timestamp=datetime.now(timezone.utc),
            is_from_voice=True,
            voice_session_id="vs-1",
            audio_duration_seconds=3,
            agent_response="reply",
        )

    assert turn_ids is not None
    assert len(commits) == 1
    session.flush.assert_awaited_once()

    human_id = turn_ids.human_message_id
    lineage = segment.lineage_json
    assert segment.agent_response == "reply"
    assert segment.user_message_id == human_id
    assert lineage is not None
    assert lineage["primary_user_message_id"] == human_id
    assert lineage["turns"][0]["user_message_id"] == human_id
@pytest.mark.asyncio
async def test_record_human_ai_turn_with_segment_postgres_flush_order() -> None:
    """Regression: Postgres FK on segments.user_message_id requires message INSERT first.

    Skipped unless ``settings.database_url`` points at PostgreSQL. The test
    seeds a user, a conversation and a segment that is not yet linked to a
    message. It then records a human/AI turn through
    ``ConversationHistoryStore`` in a second session and checks that the
    segment now references the new human message.

    Every row this test creates is deleted in a ``finally`` block, children
    before parents, so repeated runs do not leave data behind in the database.
    """
    from sqlalchemy import delete

    from app.core.config import settings

    if not settings.database_url.startswith("postgresql"):
        pytest.skip("requires PostgreSQL")

    from app.core.db import AsyncSessionLocal, transactional
    from app.features.conversation.models import Conversation
    from app.features.user.models import User

    uid = str(uuid.uuid4())
    cid = str(uuid.uuid4())
    sid = str(uuid.uuid4())
    now = datetime.now(timezone.utc)

    try:
        async with AsyncSessionLocal() as db:
            db.add(
                User(
                    id=uid,
                    phone=f"138{uuid.uuid4().int % 100_000_000:08d}",
                    password_hash="x",
                    nickname="t",
                    subscription_type="free",
                    created_at=now,
                )
            )
            conv = Conversation(id=cid, user_id=uid, last_message_at=now)
            db.add(conv)
            segment = Segment(
                id=sid,
                conversation_id=cid,
                user_input_text="hi",
                processed=False,
                created_at=now,
            )
            # transactional() is expected to commit everything pending in the
            # session: the user, the conversation and the segment.
            async with transactional(db):
                db.add(segment)
            await db.refresh(segment)

        # Use a fresh session so the segment is loaded from the database.
        async with AsyncSessionLocal() as db:
            segment = await db.get(Segment, sid)
            assert segment is not None, "seeded segment was not persisted"
            store = ConversationHistoryStore(db)
            turn_ids = await store.record_human_ai_turn_with_segment(
                cid,
                "hello",
                ["reply"],
                segment,
                user_message_timestamp=now,
                is_from_voice=False,
                voice_session_id=None,
                audio_duration_seconds=None,
                agent_response="reply",
            )
            assert turn_ids is not None
            assert segment.user_message_id == turn_ids.human_message_id
    finally:
        # Delete children before parents so FK constraints hold. A segment
        # references a message, messages reference the conversation, and the
        # conversation references the user. Deleting missing rows is a no-op,
        # so this is safe even if seeding failed partway.
        async with AsyncSessionLocal() as db:
            async with transactional(db):
                await db.execute(delete(Segment).where(Segment.id == sid))
                await db.execute(
                    delete(ConversationMessage).where(
                        ConversationMessage.conversation_id == cid
                    )
                )
                await db.execute(delete(Conversation).where(Conversation.id == cid))
                await db.execute(delete(User).where(User.id == uid))
@pytest.mark.asyncio
async def test_attach_ai_tts_for_turn_single_commit() -> None:
    """Attaching TTS URLs to a turn updates the segment and the AI message in one commit.

    The repo call that writes the latest AI message's URLs is replaced by a
    fake. The fake sets the URLs on a local ``ConversationMessage`` and
    returns it. The store's ``transactional`` is replaced by a fake that
    counts commits.
    """
    # MagicMock(spec=AsyncSession) exposes AsyncSession's async methods
    # (including ``commit``) as AsyncMock, so they can be awaited directly.
    session = MagicMock(spec=AsyncSession)

    # One entry is appended per successful commit made through the fake.
    commits = []

    @asynccontextmanager
    async def counting_transactional(sess):
        yield sess
        await sess.commit()
        commits.append(None)

    store = ConversationHistoryStore(session)
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hi",
        processed=False,
    )
    ai_row = ConversationMessage(
        id="ai-1",
        conversation_id="conv-1",
        role="ai",
        content="reply",
    )

    async def fake_set_latest_tts(*_args, **_kwargs):
        ai_row.tts_audio_urls = ["https://cos/a.mp3"]
        return ai_row

    with (
        patch("app.features.conversation.history_store.transactional", counting_transactional),
        patch.object(store, "_sync_redis_best_effort", AsyncMock()),
        patch(
            "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls",
            fake_set_latest_tts,
        ),
    ):
        await store.attach_ai_tts_for_turn(
            "conv-1",
            tts_audio_urls=["https://cos/a.mp3"],
            segment=segment,
        )

    assert len(commits) == 1
    assert segment.tts_audio_urls == ["https://cos/a.mp3"]
    assert ai_row.tts_audio_urls == ["https://cos/a.mp3"]