Mock COS storage for auth router DI and skip Postgres integration tests when the database is unreachable. Co-authored-by: Cursor <cursoragent@cursor.com>
246 lines
7.6 KiB
Python
246 lines
7.6 KiB
Python
"""WS pipeline / history_store 原子持久化边界;memoir 调度顺序。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timezone
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
import uuid
|
||
|
||
import pytest
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.features.auth import models as _auth_models # noqa: F401
|
||
from app.features.conversation import models as _conv_models # noqa: F401
|
||
from app.features.memory import models as _memory_models # noqa: F401
|
||
from app.features.memoir import models as _memoir_models # noqa: F401
|
||
from app.features.payment import models as _payment_models # noqa: F401
|
||
from app.features.story import models as _story_models # noqa: F401
|
||
from app.features.user import models as _user_models # noqa: F401
|
||
from app.features.conversation.history_store import ConversationHistoryStore
|
||
from app.features.conversation.models import ConversationMessage, Segment
|
||
from app.features.conversation.ws import persist
|
||
|
||
|
||
@asynccontextmanager
async def _capture_transactional(db):
    """Test double for ``transactional``: hand back ``db``, commit it on clean exit.

    The session object passed in is yielded unchanged. ``commit`` is awaited only
    after the ``async with`` body finishes normally; if the body raises, the
    exception surfaces at the ``yield`` and no commit (or rollback) is attempted.
    """
    session = db
    yield session
    # Reached only when the caller's block completed without raising.
    await session.commit()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_persist_message_tts_url_segment_commits_once() -> None:
    """Storing the first TTS URL on an AI message lands in exactly one commit.

    ``persist.transactional`` is swapped for ``_capture_transactional`` so the
    mocked session's ``commit`` is the only thing that can record the write.
    """
    session = MagicMock(spec=AsyncSession)
    session.commit = AsyncMock()

    ai_message = ConversationMessage(
        id="msg-1",
        conversation_id="conv-1",
        role="ai",
        content="hi",
        tts_audio_urls=[],
    )
    first_url = "https://cos/0.mp3"

    with patch(
        "app.features.conversation.ws.persist.transactional",
        _capture_transactional,
    ):
        # Segment index 0 of an initially empty URL list.
        await persist.persist_message_tts_url_segment(session, ai_message, 0, first_url)

    assert ai_message.tts_audio_urls == [first_url]
    session.commit.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_persist_voice_segment_row_commits_segment_and_activity() -> None:
    """A voice segment row is added, conversation activity is stamped, one commit.

    "Activity" here means ``conversation.last_message_at`` becoming non-None.
    """
    from app.features.conversation.models import Conversation

    session = MagicMock(spec=AsyncSession)
    session.commit = AsyncMock()
    session.add = MagicMock()

    conversation = Conversation(id="conv-1", user_id="u1")
    voice_segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )

    with patch(
        "app.features.conversation.ws.persist.transactional",
        _capture_transactional,
    ):
        await persist.persist_voice_segment_row(session, voice_segment, conversation)

    session.add.assert_called_once_with(voice_segment)
    assert conversation.last_message_at is not None
    session.commit.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_record_human_ai_turn_with_segment_single_commit() -> None:
    """Recording a voice human/AI turn together with its segment commits once.

    The store's ``transactional`` is replaced by a commit-counting double, and
    ``_touch_conversation``, ``_sync_redis_best_effort`` and
    ``repo.add_conversation_message`` are patched out. The test then checks the
    commit count, that exactly one flush happened, and the fields the store
    wrote onto ``segment`` (agent response, user message link, lineage).
    """
    session = MagicMock(spec=AsyncSession)
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.flush = AsyncMock()
    commits = []  # one entry per transactional block that finished cleanly

    @asynccontextmanager
    async def counting_transactional(inner):
        yield inner
        await inner.commit()
        commits.append(inner)

    store = ConversationHistoryStore(session)
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hi",
        processed=False,
    )

    with (
        patch(
            "app.features.conversation.history_store.transactional",
            counting_transactional,
        ),
        patch.object(store, "_touch_conversation", AsyncMock()),
        patch.object(store, "_sync_redis_best_effort", AsyncMock()),
        patch("app.features.conversation.history_store.repo.add_conversation_message"),
    ):
        turn_ids = await store.record_human_ai_turn_with_segment(
            "conv-1",
            "hello",
            ["reply"],
            segment,
            user_message_timestamp=datetime.now(timezone.utc),
            is_from_voice=True,
            voice_session_id="vs-1",
            audio_duration_seconds=3,
            agent_response="reply",
        )

    assert turn_ids is not None
    assert len(commits) == 1
    session.flush.assert_awaited_once()
    assert segment.agent_response == "reply"

    # The segment (and its lineage record) must point at the new human message.
    human_id = turn_ids.human_message_id
    assert segment.user_message_id == human_id
    lineage = segment.lineage_json
    assert lineage is not None
    assert lineage["primary_user_message_id"] == human_id
    assert lineage["turns"][0]["user_message_id"] == human_id
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_record_human_ai_turn_with_segment_postgres_flush_order() -> None:
    """Regression: Postgres FK on segments.user_message_id requires message INSERT first.

    Integration test against the configured database. It is skipped unless
    ``settings.database_url`` starts with ``postgresql`` and a ``SELECT 1`` probe
    succeeds. A user, a conversation and an unprocessed segment are seeded in one
    session; a fresh session then reloads the segment and records a text
    (non-voice) turn against it, which must persist without tripping the FK.
    """
    from app.core.config import settings

    if not settings.database_url.startswith("postgresql"):
        pytest.skip("requires PostgreSQL")

    from sqlalchemy import text

    from app.core.db import AsyncSessionLocal, transactional
    from app.features.conversation.models import Conversation
    from app.features.user.models import User

    # Deliberately broad: any exception from the probe means the database is not
    # usable in this environment, so the test skips rather than fails.
    try:
        async with AsyncSessionLocal() as probe_session:
            await probe_session.execute(text("SELECT 1"))
    except Exception:
        pytest.skip("PostgreSQL not reachable")

    user_id = str(uuid.uuid4())
    conversation_id = str(uuid.uuid4())
    segment_id = str(uuid.uuid4())
    seeded_at = datetime.now(timezone.utc)

    # NOTE(review): nothing in this test deletes the seeded rows, so every run
    # leaves a user/conversation/segment (plus recorded messages) behind.
    async with AsyncSessionLocal() as seed_session:
        seed_session.add(
            User(
                id=user_id,
                phone=f"138{uuid.uuid4().int % 100_000_000:08d}",
                password_hash="x",
                nickname="t",
                subscription_type="free",
                created_at=seeded_at,
            )
        )
        seed_session.add(
            Conversation(id=conversation_id, user_id=user_id, last_message_at=seeded_at)
        )
        seeded_segment = Segment(
            id=segment_id,
            conversation_id=conversation_id,
            user_input_text="hi",
            processed=False,
            created_at=seeded_at,
        )
        # Assumes ``transactional`` commits on clean exit (as the doubles in this
        # file do); the user and conversation added above ride along in it.
        async with transactional(seed_session):
            seed_session.add(seeded_segment)
        await seed_session.refresh(seeded_segment)

    async with AsyncSessionLocal() as session:
        segment = await session.get(Segment, segment_id)
        turn_ids = await ConversationHistoryStore(session).record_human_ai_turn_with_segment(
            conversation_id,
            "hello",
            ["reply"],
            segment,
            user_message_timestamp=seeded_at,
            is_from_voice=False,
            voice_session_id=None,
            audio_duration_seconds=None,
            agent_response="reply",
        )

    assert turn_ids is not None
    assert segment.user_message_id == turn_ids.human_message_id
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_attach_ai_tts_for_turn_single_commit() -> None:
    """Attaching TTS audio to a turn updates segment and AI message in one commit.

    ``transactional`` is replaced by a commit-counting double, Redis sync is
    stubbed, and ``repo.set_latest_ai_message_tts_audio_urls`` is replaced by a
    fake that stamps the URL list onto a local AI message row and returns it.
    """
    # Unlike the other tests no explicit AsyncMock is assigned to ``commit``;
    # the spec'd mock must make it awaitable for the double below.
    session = MagicMock(spec=AsyncSession)
    commits = []  # one entry per transactional block that finished cleanly

    @asynccontextmanager
    async def counting_transactional(inner):
        yield inner
        await inner.commit()
        commits.append(inner)

    store = ConversationHistoryStore(session)
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hi",
        processed=False,
    )
    ai_row = ConversationMessage(
        id="ai-1",
        conversation_id="conv-1",
        role="ai",
        content="reply",
    )
    tts_url = "https://cos/a.mp3"

    async def fake_set_latest_tts(*_args, **_kwargs):
        ai_row.tts_audio_urls = [tts_url]
        return ai_row

    with (
        patch(
            "app.features.conversation.history_store.transactional",
            counting_transactional,
        ),
        patch.object(store, "_sync_redis_best_effort", AsyncMock()),
        patch(
            "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls",
            fake_set_latest_tts,
        ),
    ):
        await store.attach_ai_tts_for_turn(
            "conv-1",
            tts_audio_urls=[tts_url],
            segment=segment,
        )

    assert len(commits) == 1
    assert segment.tts_audio_urls == [tts_url]
    # Only true if the patched repo writer was actually invoked.
    assert ai_row.tts_audio_urls == [tts_url]
|