from __future__ import annotations from datetime import datetime, timezone import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine import app.db.models # noqa: F401 from app.db.base import Base from app.db.models import SurgeryResultDetailRow from app.domain.consumption import SurgeryConsumptionStored from app.repositories.surgery_results import SurgeryResultRepository @pytest.fixture async def db_session() -> AsyncSession: engine = create_async_engine("sqlite+aiosqlite:///:memory:") async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) session = factory() yield session await session.close() await engine.dispose() @pytest.mark.asyncio async def test_save_empty_then_load(db_session: AsyncSession) -> None: repo = SurgeryResultRepository() async with db_session.begin(): await repo.save_final_result(db_session, surgery_id="123456", details=[]) async with db_session.begin(): loaded = await repo.load_final_details(db_session, "123456") assert loaded == [] @pytest.mark.asyncio async def test_save_roundtrip(db_session: AsyncSession) -> None: repo = SurgeryResultRepository() ts = datetime(2026, 4, 21, 10, 0, tzinfo=timezone.utc) details = [ SurgeryConsumptionStored( item_id="纱布", item_name="纱布", qty=1, doctor_id="D1", timestamp=ts, source="vision", ), SurgeryConsumptionStored( item_id="纱布", item_name="纱布", qty=1, doctor_id="voice", timestamp=ts, source="voice", ), ] async with db_session.begin(): await repo.save_final_result(db_session, surgery_id="654321", details=details) async with db_session.begin(): loaded = await repo.load_final_details(db_session, "654321") assert loaded is not None assert len(loaded) == 2 assert loaded[0].qty == 1 and loaded[0].item_id == "纱布" assert loaded[1].qty == 1 async with db_session.begin(): res = await db_session.execute( select(SurgeryResultDetailRow) .where(SurgeryResultDetailRow.surgery_id == "654321") .order_by(SurgeryResultDetailRow.id) ) orm_rows = res.scalars().all() assert orm_rows[0].source == "vision" assert orm_rows[1].source == "voice" @pytest.mark.asyncio async def test_missing_surgery_returns_none(db_session: AsyncSession) -> None: repo = SurgeryResultRepository() async with db_session.begin(): missing = await repo.load_final_details(db_session, "000000") assert missing is None @pytest.mark.asyncio async def test_save_overwrites_previous_final_result(db_session: AsyncSession) -> None: repo = SurgeryResultRepository() ts1 = datetime(2026, 4, 21, 9, 0, tzinfo=timezone.utc) ts2 = datetime(2026, 4, 21, 10, 0, tzinfo=timezone.utc) async with db_session.begin(): await repo.save_final_result( db_session, surgery_id="888888", details=[ SurgeryConsumptionStored( item_id="旧", item_name="旧", qty=1, doctor_id="D1", timestamp=ts1, source="vision", ), ], ) async with db_session.begin(): await repo.save_final_result( db_session, surgery_id="888888", details=[ SurgeryConsumptionStored( item_id="新", item_name="新", qty=2, doctor_id="D2", timestamp=ts2, source="voice", ), ], ) async with db_session.begin(): loaded = await repo.load_final_details(db_session, "888888") assert loaded is not None assert len(loaded) == 1 assert loaded[0].item_id == "新" assert loaded[0].qty == 2