Files
life-echo/api/tests/test_memoir_phase1_ingest_idempotency.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

251 lines
7.6 KiB
Python

"""Memoir queue scheduling and phase1 memory ingest idempotency."""
from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.conversation.models import Conversation, Segment
from app.features.user.models import User
from app.features.conversation.ws import pipeline as ws_pipeline
from app.features.memory.ingest_service import MemoryIngestService
from app.tasks import memoir_tasks
from sqlalchemy.exc import IntegrityError
@pytest.mark.asyncio
async def test_process_user_message_queues_memoir_after_lineage() -> None:
    """Memoir ingest is scheduled only after the turn has been recorded.

    The patched ``record_human_ai_turn_with_segment`` writes ``lineage_json``
    onto the segment, and the patched ``_schedule_memoir_ingest_for_segment``
    logs its invocation; the test asserts each ran exactly once, in that order.
    """
    session = MagicMock(spec=AsyncSession)
    user = User(id="u1", language_preference="zh")
    conversation = Conversation(id="conv-1", user_id="u1")
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )
    # Ordered log of which patched collaborator ran, and when.
    events: list[str] = []

    async def _fake_record_turn(*_args, **_kwargs):
        events.append("record_turn")
        # Recording the turn is what attaches lineage to the segment.
        segment.lineage_json = {"primary_user_message_id": "hum-1"}
        return SimpleNamespace(
            human_message_id="hum-1",
            assistant_message_id="ai-1",
        )

    async def _fake_schedule(*_args, **_kwargs):
        events.append("queue_segment")

    ai_turn = SimpleNamespace(
        messages=["reply"],
        skip_tts=True,
        memory_retrieval_trace=None,
    )
    stage_state = SimpleNamespace(current_stage="childhood")

    with (
        patch.object(
            ws_pipeline.chat_turn_service,
            "process_turn",
            AsyncMock(return_value=ai_turn),
        ),
        patch.object(
            ws_pipeline.ConversationHistoryStore,
            "record_human_ai_turn_with_segment",
            AsyncMock(side_effect=_fake_record_turn),
        ),
        patch.object(
            ws_pipeline,
            "_schedule_memoir_ingest_for_segment",
            AsyncMock(side_effect=_fake_schedule),
        ),
        patch.object(ws_pipeline.manager, "active_connections", {}),
        patch.object(ws_pipeline.manager, "send_message", AsyncMock()),
        patch.object(
            ws_pipeline,
            "get_or_create_state",
            AsyncMock(return_value=stage_state),
        ),
        patch.object(ws_pipeline, "maybe_send_topic_chips_ws", AsyncMock()),
    ):
        await ws_pipeline.process_user_message(
            "conv-1",
            "hello",
            conversation,
            segment,
            session,
            user=user,
            user_message_timestamp=datetime.now(timezone.utc),
        )

    assert events == ["record_turn", "queue_segment"]
@pytest.mark.asyncio
async def test_process_user_message_queues_memoir_on_ai_failure() -> None:
    """Memoir ingest is still scheduled when the turn yields no AI messages.

    ``process_turn`` is patched to return a turn with an empty ``messages``
    list and the history store is patched to return ``None``; the test only
    requires that ``_schedule_memoir_ingest_for_segment`` was invoked.
    """
    session = MagicMock(spec=AsyncSession)
    user = User(id="u1", language_preference="zh")
    conversation = Conversation(id="conv-1", user_id="u1")
    segment = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )
    memoir_queued = False

    async def _fake_schedule(*_args, **_kwargs):
        nonlocal memoir_queued
        memoir_queued = True

    # An empty message list stands in for the AI producing no reply.
    empty_turn = SimpleNamespace(
        messages=[],
        skip_tts=True,
        memory_retrieval_trace=None,
    )
    open_connections = {"conv-1": MagicMock()}

    with (
        patch.object(
            ws_pipeline.chat_turn_service,
            "process_turn",
            AsyncMock(return_value=empty_turn),
        ),
        patch.object(
            ws_pipeline.ConversationHistoryStore,
            "record_human_ai_turn_with_segment",
            AsyncMock(return_value=None),
        ),
        patch.object(
            ws_pipeline,
            "_schedule_memoir_ingest_for_segment",
            AsyncMock(side_effect=_fake_schedule),
        ),
        patch.object(ws_pipeline.manager, "active_connections", open_connections),
        patch.object(ws_pipeline.manager, "send_message", AsyncMock()),
    ):
        await ws_pipeline.process_user_message(
            "conv-1",
            "hello",
            conversation,
            segment,
            session,
            user=user,
            user_message_timestamp=datetime.now(timezone.utc),
        )

    assert memoir_queued is True
@pytest.mark.asyncio
async def test_ingest_batch_idempotent_by_segment_id(monkeypatch) -> None:
    """Ingesting the same segment twice creates a single source.

    The repository helpers used by ``ingest_service`` are replaced with
    in-memory fakes keyed by segment id. Both ingest calls must return the
    same source id, and ``create_source`` must have been called only once.
    """
    sources_by_segment: dict[str, SimpleNamespace] = {}
    sources_created = 0

    class FakeSession:
        """Session stand-in whose commit/flush do nothing."""

        async def commit(self) -> None:
            return None

        async def flush(self) -> None:
            return None

    async def fake_get(db, *, user_id: str, segment_id: str):
        # Lookup is by segment id only; ``None`` means "not ingested yet".
        return sources_by_segment.get(segment_id)

    async def fake_create_source(session, **kwargs):
        nonlocal sources_created
        sources_created += 1
        segment_key = kwargs["segment_id"]
        created = SimpleNamespace(id=f"src-{sources_created}", segment_id=segment_key)
        sources_by_segment[segment_key] = created
        return created

    async def fake_create_chunk(*_args, **kwargs):
        return SimpleNamespace(id=f"ch-{kwargs.get('chunk_index')}")

    class FakeEmbeddingService:
        """Embedding service stand-in that always reports success."""

        def __init__(self, *_args, **_kwargs) -> None:
            pass

        async def embed_source(self, user_id: str, source_id: str) -> dict:
            return {"status": "success", "vectors_written": 1}

    # (dotted target, replacement) pairs, applied in order.
    replacements = (
        (
            "app.features.memory.ingest_service.get_transcript_source_by_segment_id",
            fake_get,
        ),
        (
            "app.features.memory.ingest_service.create_source",
            fake_create_source,
        ),
        (
            "app.features.memory.ingest_service.create_chunk",
            fake_create_chunk,
        ),
        (
            "app.features.memory.ingest_service.MemoryEmbeddingService",
            FakeEmbeddingService,
        ),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)
    monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False)

    scheduler = MagicMock(schedule_many=MagicMock(return_value=[]))
    service = MemoryIngestService(
        FakeSession(),  # type: ignore[arg-type]
        embedding_provider=None,
        enrichment_scheduler=scheduler,
    )
    batch = [
        ("c1", "hello", {"primary_user_message_id": "m1"}, "seg-1"),
    ]

    first_ids = await service.ingest_transcripts_batch("u1", batch)
    second_ids = await service.ingest_transcripts_batch("u1", batch)

    assert first_ids == ["src-1"]
    assert second_ids == ["src-1"]
    assert sources_created == 1
def test_phase1_memory_ingest_batch_sync_reraises_on_failure(monkeypatch) -> None:
    """A ``RuntimeError`` from the batch ingest propagates out of the sync wrapper."""

    async def _always_fail(*_args, **_kwargs):
        raise RuntimeError("ingest unavailable")

    monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _always_fail)
    session = MagicMock()
    batch = [("c1", "hello", None, "seg-1")]

    with pytest.raises(RuntimeError, match="ingest unavailable"):
        memoir_tasks._phase1_memory_ingest_batch_sync(
            session,
            "u1",
            batch,
            memoir_correlation_id="corr-1",
        )
def test_phase1_memory_ingest_batch_sync_resolves_integrity_race(monkeypatch) -> None:
    """An ``IntegrityError`` during ingest resolves to the existing source id.

    The batch ingest is patched to raise ``IntegrityError`` and the sync
    lookup is patched to return a pre-existing source; the wrapper must
    return that source's id rather than propagate the error.
    """

    async def _raise_integrity_error(*_args, **_kwargs):
        raise IntegrityError("insert", {}, Exception("unique"))

    already_stored = SimpleNamespace(id="src-existing")

    def _find_existing(_db, *, user_id: str, segment_id: str):
        # The fallback lookup must target the same user and segment.
        assert user_id == "u1"
        assert segment_id == "seg-1"
        return already_stored

    monkeypatch.setattr(
        memoir_tasks,
        "_memory_ingest_transcripts_batch",
        _raise_integrity_error,
    )
    monkeypatch.setattr(
        memoir_tasks,
        "get_transcript_source_by_segment_id_sync",
        _find_existing,
    )

    session = MagicMock()
    batch = [("c1", "hello", None, "seg-1")]
    resolved_ids = memoir_tasks._phase1_memory_ingest_batch_sync(
        session,
        "u1",
        batch,
        memoir_correlation_id="corr-1",
    )

    assert resolved_ids == ["src-existing"]