# Files
# operating-room-monitor-server/tests/test_surgery_repository.py
#
# 128 lines
# 4.1 KiB
# Python
# Raw Normal View History

from __future__ import annotations

from collections.abc import AsyncIterator
from datetime import datetime, timezone

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

import app.db.models  # noqa: F401
from app.db.base import Base
from app.db.models import SurgeryResultDetailRow
from app.repositories.surgery_results import SurgeryResultRepository
from app.schemas import SurgeryConsumptionStored
@pytest.fixture
async def db_session() -> AsyncIterator[AsyncSession]:
    """Yield an ``AsyncSession`` backed by a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the session
    is handed to the test. Afterwards the session is closed and the engine is
    disposed; disposal is guaranteed even if schema creation or closing the
    session raises.
    """
    # NOTE(review): this relies on the async fixture being picked up by the
    # asyncio pytest plugin through plain ``@pytest.fixture`` -- confirm the
    # project's ``asyncio_mode`` setting allows that.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        # The session's async context manager closes it when the block exits.
        async with factory() as session:
            yield session
    finally:
        await engine.dispose()
@pytest.mark.asyncio
async def test_save_empty_then_load(db_session: AsyncSession) -> None:
    """A final result saved with no details loads back as ``[]``."""
    surgery_id = "123456"
    repository = SurgeryResultRepository()

    # The write and the read each run in their own transaction.
    async with db_session.begin():
        await repository.save_final_result(db_session, surgery_id=surgery_id, details=[])

    async with db_session.begin():
        fetched = await repository.load_final_details(db_session, surgery_id)

    assert fetched == []
@pytest.mark.asyncio
async def test_save_roundtrip(db_session: AsyncSession) -> None:
    """Two saved details load back in order and their ``source`` is persisted."""
    surgery_id = "654321"
    recorded_at = datetime(2026, 4, 21, 10, 0, tzinfo=timezone.utc)
    repository = SurgeryResultRepository()

    # The same item is recorded twice, once per (doctor_id, source) pair.
    consumptions = [
        SurgeryConsumptionStored(
            item_id="纱布",
            item_name="纱布",
            qty=1,
            doctor_id=doctor_id,
            timestamp=recorded_at,
            source=source,
        )
        for doctor_id, source in (("D1", "vision"), ("voice", "voice"))
    ]

    async with db_session.begin():
        await repository.save_final_result(
            db_session,
            surgery_id=surgery_id,
            details=consumptions,
        )

    async with db_session.begin():
        fetched = await repository.load_final_details(db_session, surgery_id)

    assert fetched is not None
    assert len(fetched) == 2
    first, second = fetched[0], fetched[1]
    assert first.qty == 1
    assert first.item_id == "纱布"
    assert second.qty == 1

    # Read the ORM rows directly to check the stored ``source`` column,
    # ordered by primary key.
    statement = (
        select(SurgeryResultDetailRow)
        .where(SurgeryResultDetailRow.surgery_id == surgery_id)
        .order_by(SurgeryResultDetailRow.id)
    )
    async with db_session.begin():
        result = await db_session.execute(statement)
        rows = result.scalars().all()

    assert rows[0].source == "vision"
    assert rows[1].source == "voice"
@pytest.mark.asyncio
async def test_missing_surgery_returns_none(db_session: AsyncSession) -> None:
    """Loading details for a surgery id that was never saved yields ``None``."""
    repository = SurgeryResultRepository()

    async with db_session.begin():
        result = await repository.load_final_details(db_session, "000000")

    assert result is None
@pytest.mark.asyncio
async def test_save_overwrites_previous_final_result(db_session: AsyncSession) -> None:
    """A second save for the same surgery replaces the earlier details."""
    surgery_id = "888888"
    repository = SurgeryResultRepository()

    # (timestamp, qty, doctor_id, source) for each successive save.
    revisions = (
        (datetime(2026, 4, 21, 9, 0, tzinfo=timezone.utc), 1, "D1", "vision"),
        (datetime(2026, 4, 21, 10, 0, tzinfo=timezone.utc), 2, "D2", "voice"),
    )
    for saved_at, quantity, doctor_id, source in revisions:
        # Each save runs in its own transaction.
        async with db_session.begin():
            await repository.save_final_result(
                db_session,
                surgery_id=surgery_id,
                details=[
                    SurgeryConsumptionStored(
                        item_id="",
                        item_name="",
                        qty=quantity,
                        doctor_id=doctor_id,
                        timestamp=saved_at,
                        source=source,
                    ),
                ],
            )

    async with db_session.begin():
        fetched = await repository.load_final_details(db_session, surgery_id)

    # Only the detail from the second save should remain.
    assert fetched is not None
    assert len(fetched) == 1
    remaining = fetched[0]
    assert remaining.item_id == ""
    assert remaining.qty == 2