配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
251 lines
7.6 KiB
Python
251 lines
7.6 KiB
Python
"""Memoir queue scheduling and phase1 memory ingest idempotency."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.features.conversation.models import Conversation, Segment
|
|
from app.features.user.models import User
|
|
from app.features.conversation.ws import pipeline as ws_pipeline
|
|
from app.features.memory.ingest_service import MemoryIngestService
|
|
from app.tasks import memoir_tasks
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_process_user_message_queues_memoir_after_lineage() -> None:
    """Memoir ingest scheduling must happen after the turn has been recorded.

    The history-store stub writes ``lineage_json`` onto the segment when it
    runs, and both stubs log their invocation; the test passes only if the
    log reads ``record_turn`` followed by ``queue_segment``.
    """
    events: list[str] = []

    session = MagicMock(spec=AsyncSession)
    speaker = User(id="u1", language_preference="zh")
    convo = Conversation(id="conv-1", user_id="u1")
    seg = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )

    async def fake_record_turn(*_args, **_kwargs):
        # Recording the turn is what attaches lineage to the segment.
        events.append("record_turn")
        seg.lineage_json = {"primary_user_message_id": "hum-1"}
        return SimpleNamespace(human_message_id="hum-1", assistant_message_id="ai-1")

    async def fake_schedule(*_args, **_kwargs):
        events.append("queue_segment")

    ai_turn = SimpleNamespace(
        messages=["reply"],
        skip_tts=True,
        memory_retrieval_trace=None,
    )

    # Build the stand-ins up front so the ``with`` header stays readable.
    process_turn_mock = AsyncMock(return_value=ai_turn)
    record_mock = AsyncMock(side_effect=fake_record_turn)
    schedule_mock = AsyncMock(side_effect=fake_schedule)
    state_mock = AsyncMock(return_value=SimpleNamespace(current_stage="childhood"))

    with (
        patch.object(ws_pipeline.chat_turn_service, "process_turn", process_turn_mock),
        patch.object(
            ws_pipeline.ConversationHistoryStore,
            "record_human_ai_turn_with_segment",
            record_mock,
        ),
        patch.object(ws_pipeline, "_schedule_memoir_ingest_for_segment", schedule_mock),
        patch.object(ws_pipeline.manager, "active_connections", {}),
        patch.object(ws_pipeline.manager, "send_message", AsyncMock()),
        patch.object(ws_pipeline, "get_or_create_state", state_mock),
        patch.object(ws_pipeline, "maybe_send_topic_chips_ws", AsyncMock()),
    ):
        await ws_pipeline.process_user_message(
            "conv-1",
            "hello",
            convo,
            seg,
            session,
            user=speaker,
            user_message_timestamp=datetime.now(timezone.utc),
        )

    assert events == ["record_turn", "queue_segment"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_process_user_message_queues_memoir_on_ai_failure() -> None:
    """Memoir ingest is still scheduled when the AI turn yields no messages.

    ``process_turn`` is stubbed to return a turn whose ``messages`` list is
    empty and the history store is stubbed to return ``None``; the test only
    checks that the scheduling hook was invoked.
    """
    outcome = SimpleNamespace(queued=False)

    async def fake_schedule(*_args, **_kwargs):
        outcome.queued = True

    session = MagicMock(spec=AsyncSession)
    speaker = User(id="u1", language_preference="zh")
    convo = Conversation(id="conv-1", user_id="u1")
    seg = Segment(
        id="seg-1",
        conversation_id="conv-1",
        user_input_text="hello",
        processed=False,
    )

    empty_turn = SimpleNamespace(messages=[], skip_tts=True, memory_retrieval_trace=None)

    process_turn_mock = AsyncMock(return_value=empty_turn)
    record_mock = AsyncMock(return_value=None)
    schedule_mock = AsyncMock(side_effect=fake_schedule)

    with (
        patch.object(ws_pipeline.chat_turn_service, "process_turn", process_turn_mock),
        patch.object(
            ws_pipeline.ConversationHistoryStore,
            "record_human_ai_turn_with_segment",
            record_mock,
        ),
        patch.object(ws_pipeline, "_schedule_memoir_ingest_for_segment", schedule_mock),
        # Unlike the happy-path test, a connection is registered for "conv-1".
        patch.object(ws_pipeline.manager, "active_connections", {"conv-1": MagicMock()}),
        patch.object(ws_pipeline.manager, "send_message", AsyncMock()),
    ):
        await ws_pipeline.process_user_message(
            "conv-1",
            "hello",
            convo,
            seg,
            session,
            user=speaker,
            user_message_timestamp=datetime.now(timezone.utc),
        )

    assert outcome.queued is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ingest_batch_idempotent_by_segment_id(monkeypatch) -> None:
    """Ingesting the same segment twice creates one source and returns its id both times.

    The lookup, source/chunk creation and embedding service used by
    ``app.features.memory.ingest_service`` are replaced with in-memory fakes.
    The fake lookup reads the same mapping the fake ``create_source`` writes,
    so the second batch can find the source produced by the first.
    """
    sources_by_segment: dict[str, SimpleNamespace] = {}
    created: list[SimpleNamespace] = []

    class FakeSession:
        """Session stand-in whose commit/flush are no-ops."""

        async def commit(self) -> None:
            return None

        async def flush(self) -> None:
            return None

    async def fake_get(db, *, user_id: str, segment_id: str):
        return sources_by_segment.get(segment_id)

    async def fake_create_source(session, **kwargs):
        segment_key = kwargs["segment_id"]
        # Ids are numbered by creation order: src-1, src-2, ...
        record = SimpleNamespace(id=f"src-{len(created) + 1}", segment_id=segment_key)
        created.append(record)
        sources_by_segment[segment_key] = record
        return record

    async def fake_create_chunk(*_args, **kwargs):
        return SimpleNamespace(id=f"ch-{kwargs.get('chunk_index')}")

    class FakeEmbeddingService:
        """Embedding stand-in that always reports one vector written."""

        def __init__(self, *_args, **_kwargs) -> None:
            return None

        async def embed_source(self, user_id: str, source_id: str) -> dict:
            return {"status": "success", "vectors_written": 1}

    replacements = (
        ("app.features.memory.ingest_service.get_transcript_source_by_segment_id", fake_get),
        ("app.features.memory.ingest_service.create_source", fake_create_source),
        ("app.features.memory.ingest_service.create_chunk", fake_create_chunk),
        ("app.features.memory.ingest_service.MemoryEmbeddingService", FakeEmbeddingService),
    )
    for target, stand_in in replacements:
        monkeypatch.setattr(target, stand_in)
    monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False)

    scheduler = MagicMock(schedule_many=MagicMock(return_value=[]))
    service = MemoryIngestService(
        FakeSession(),  # type: ignore[arg-type]
        embedding_provider=None,
        enrichment_scheduler=scheduler,
    )
    batch = [("c1", "hello", {"primary_user_message_id": "m1"}, "seg-1")]

    first_ids = await service.ingest_transcripts_batch("u1", batch)
    second_ids = await service.ingest_transcripts_batch("u1", batch)

    assert first_ids == ["src-1"]
    assert second_ids == ["src-1"]
    assert len(created) == 1
|
|
|
|
|
|
def test_phase1_memory_ingest_batch_sync_reraises_on_failure(monkeypatch) -> None:
    """A non-integrity error from the async batch ingest propagates to the caller.

    ``_memory_ingest_transcripts_batch`` is replaced with a coroutine that
    raises ``RuntimeError``; the sync wrapper is expected to surface that
    same error rather than swallow it.
    """

    async def _always_fail(*_args, **_kwargs):
        raise RuntimeError("ingest unavailable")

    monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _always_fail)

    session = MagicMock()
    batch = [("c1", "hello", None, "seg-1")]

    with pytest.raises(RuntimeError, match="ingest unavailable"):
        memoir_tasks._phase1_memory_ingest_batch_sync(
            session, "u1", batch, memoir_correlation_id="corr-1"
        )
|
|
|
|
|
|
def test_phase1_memory_ingest_batch_sync_resolves_integrity_race(monkeypatch) -> None:
    """An ``IntegrityError`` during ingest resolves to the already-stored source id.

    The async batch ingest is stubbed to raise ``IntegrityError`` and the
    sync lookup is stubbed to return an existing source; the wrapper must
    return that source's id for the affected segment.
    """
    winner = SimpleNamespace(id="src-existing")

    async def _lose_race(*_args, **_kwargs):
        raise IntegrityError("insert", {}, Exception("unique"))

    def _find_existing(_db, *, user_id: str, segment_id: str):
        # The fallback lookup must be keyed by the same user and segment.
        assert user_id == "u1"
        assert segment_id == "seg-1"
        return winner

    monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _lose_race)
    monkeypatch.setattr(
        memoir_tasks,
        "get_transcript_source_by_segment_id_sync",
        _find_existing,
    )

    session = MagicMock()
    batch = [("c1", "hello", None, "seg-1")]

    resolved = memoir_tasks._phase1_memory_ingest_batch_sync(
        session, "u1", batch, memoir_correlation_id="corr-1"
    )

    assert resolved == ["src-existing"]
|