diff --git a/.env.example b/.env.example index e96e24f..c1cdfb5 100644 --- a/.env.example +++ b/.env.example @@ -11,6 +11,16 @@ POSTGRES_PORT=35432 # Optional: full async SQLAlchemy URL (overrides POSTGRES_* when set and matches defaults logic — see Settings). # DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:35432/operation_room +# 开发/首次部署:启动时执行 Base.metadata.create_all 确保表存在(默认 true)。 +# 生产请置 false,并通过 `alembic upgrade head` 应用迁移(见 alembic/ 目录)。 +# AUTO_CREATE_SCHEMA=true + +# --- Uvicorn / API server --- +# SERVER_HOST=0.0.0.0 +# SERVER_PORT=38080 +# 生产必须 false;本地开发调试可改 true(等价于旧版 reload=True)。 +# SERVER_RELOAD=false + # --- YOLO 视觉推理(内部调用,无独立 HTTP)--- # 耗材分类权重默认 app/resources/consumable_classifier.pt;手部检测为空时退化为全帧分类。 # CONSUMABLE_CLASSIFIER_WEIGHTS=/absolute/path/to/consumable_classifier.pt @@ -22,6 +32,7 @@ CONSUMABLE_CLASSIFIER_TOPK=5 # 时间窗(秒):窗内多次推理取众数后再走自动记账 / 待确认。 # CONSUMABLE_VISION_WINDOW_SEC=15 # 可选:Excel「商品名称」「产品编码」表;空则物品 id 用名称。 +# 开始手术请求体 candidate_consumables 缺省或 [] 时,优先用本表全部商品名参与推理;未配置则用分类模型全部类名。 # CONSUMABLE_CATALOG_XLSX_PATH=/path/to/视频中的商品信息表.xlsx # HAND_DETECTION_WEIGHTS=/absolute/path/to/hand_detect.pt # HAND_DETECTION_IMGSZ=640 diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..c03dfd2 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,40 @@ +[alembic] +script_location = alembic +prepend_sys_path = . +timezone = UTC +version_path_separator = os +output_encoding = utf-8 + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..b3d247e --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,66 @@ +"""Alembic environment. + +生产请用 `alembic upgrade head`;开发/测试可让 ``Settings.auto_create_schema`` 调用 +``init_db_schema()``。本文件读取 `app.config.settings`,把 asyncpg URL 转为同步 +`psycopg` URL 供 Alembic 使用(仅迁移期间)。 +""" + +from __future__ import annotations + +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +from app.config import settings +import app.db.models # noqa: F401 - register ORM tables on Base.metadata +from app.db.base import Base + + +alembic_config = context.config + +if alembic_config.config_file_name is not None: + fileConfig(alembic_config.config_file_name) + +target_metadata = Base.metadata + + +def _sync_database_url() -> str: + """把 asyncpg URL 转为同步 psycopg2 URL,避免 Alembic 强依赖 async 驱动。""" + url = settings.sqlalchemy_database_url + return url.replace("postgresql+asyncpg://", "postgresql+psycopg://", 1) + + +def run_migrations_offline() -> None: + context.configure( + url=_sync_database_url(), + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + config_section = alembic_config.get_section(alembic_config.config_ini_section) or {} + config_section["sqlalchemy.url"] = _sync_database_url() + connectable = engine_from_config( + config_section, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..b1f8b89 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,27 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/0001_initial.py b/alembic/versions/0001_initial.py new file mode 100644 index 0000000..6abaa49 --- /dev/null +++ b/alembic/versions/0001_initial.py @@ -0,0 +1,119 @@ +"""initial schema: surgery_final_results / surgery_result_details / voice_confirmation_audits + +Revision ID: 0001_initial +Revises: +Create Date: 2026-04-23 + +对应 `app.db.models` 中的三张表,与 `init_db_schema()` 的 ``Base.metadata.create_all`` +等价,作为生产环境 `alembic upgrade head` 的初始版本。 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "0001_initial" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "surgery_final_results", + sa.Column("surgery_id", sa.String(length=6), primary_key=True), + sa.Column( + "completed_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + + op.create_table( + "surgery_result_details", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column( + "surgery_id", + sa.String(length=6), + sa.ForeignKey( + "surgery_final_results.surgery_id", ondelete="CASCADE" + ), + nullable=False, + index=True, + ), + sa.Column("item_id", sa.String(length=256), nullable=False), + sa.Column("item_name", sa.String(length=256), nullable=False), + sa.Column("quantity", sa.Integer(), nullable=False), + sa.Column("doctor_id", sa.String(length=128), nullable=False), + sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False), + sa.Column( + "source", sa.String(length=32), nullable=False, server_default="vision" + ), + ) + op.create_index( + op.f("ix_surgery_result_details_surgery_id"), + "surgery_result_details", + ["surgery_id"], + unique=False, + ) + + op.create_table( + "voice_confirmation_audits", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("surgery_id", sa.String(length=6), nullable=False, index=True), + sa.Column( + "confirmation_id", sa.String(length=128), nullable=False, index=True + ), + sa.Column("status", sa.String(length=32), nullable=False), + sa.Column("audio_object_key", sa.String(length=512), nullable=True), + sa.Column("audio_content_type", sa.String(length=128), nullable=True), + sa.Column("audio_size_bytes", sa.Integer(), nullable=True), + sa.Column("audio_sha256", sa.String(length=64), nullable=True), + sa.Column("asr_text", sa.String(length=2048), nullable=True), + sa.Column("resolved_label", sa.String(length=256), nullable=True), + sa.Column("options_snapshot_json", sa.Text(), nullable=True), + sa.Column("error_message", sa.String(length=1024), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + op.create_index( + op.f("ix_voice_confirmation_audits_surgery_id"), + "voice_confirmation_audits", + ["surgery_id"], + unique=False, + ) + op.create_index( + op.f("ix_voice_confirmation_audits_confirmation_id"), + "voice_confirmation_audits", + ["confirmation_id"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index( + op.f("ix_voice_confirmation_audits_confirmation_id"), + table_name="voice_confirmation_audits", + ) + op.drop_index( + op.f("ix_voice_confirmation_audits_surgery_id"), + table_name="voice_confirmation_audits", + ) + op.drop_table("voice_confirmation_audits") + + op.drop_index( + op.f("ix_surgery_result_details_surgery_id"), + table_name="surgery_result_details", + ) + op.drop_table("surgery_result_details") + + op.drop_table("surgery_final_results") diff --git a/app/api.py b/app/api.py index 1bf7fb5..2d4ccbd 100644 --- a/app/api.py +++ b/app/api.py @@ -235,7 +235,7 @@ async def end_surgery( responses={ status.HTTP_503_SERVICE_UNAVAILABLE: { "description": ( - "结果尚不可查询:未同时满足「已开录」且「算法已产生可返回的实时计算结果」。" + "结果尚不可查询:无至少一条消耗明细,或手术未开始、未开录成功、尚无可查归档等。" ), "model": SurgeryClientErrorResponse, }, @@ -245,7 +245,7 @@ async def end_surgery( description=( "根据手术 6 位号查询该台手术的耗材消耗明细(多行)及按物品汇总。" "手术进行中返回当前内存已记账结果;结束后返回数据库持久化结果。" - "若手术从未开始或尚无可查的最终归档,返回 503。" + "若无至少一条消耗明细(含已归档但明细为空)、手术从未开始或尚无可查归档,返回 503。" "使用 GET:只读、幂等。\n\n" "响应体 `details` 与 `summary` 的字段定义见模式 SurgeryConsumptionDetail / SurgeryConsumptionSummary;" "若服务端启用耗材 TSV 文本日志,文件明细列为 tab 分隔的 " @@ -267,13 +267,14 @@ async def get_surgery_result( ) -> SurgeryResultResponse: logger.info("Query surgery result: surgery_id={}", surgery_id) details = await pipeline.get_consumption_details_for_client(surgery_id) - if details is None: + if not details: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail={ "code": "RESULT_NOT_READY", "message": ( - "当前无该手术的可查询结果:手术未开始、未成功开录或尚未产生可返回的数据。" + "当前无该手术的可查询结果:手术未开始、未成功开录、尚无至少一条消耗明细," + "或尚无可返回的数据。" ), "surgery_id": surgery_id, }, diff --git a/app/config.py b/app/config.py index 158d5cc..1e47942 100644 --- a/app/config.py +++ b/app/config.py @@ -6,6 +6,156 @@ from typing import Any, Literal from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict + +class _SettingsGroup: + """按主题分组的 Settings 视图;属性访问代理回主 Settings 实例。 + + 主 Settings 保留所有原始平坦字段作为事实来源;此类仅提供 ``settings.video.xxx`` + 之类的分组读写入口,减少跨文件的耦合面,同时保持向后兼容。 + """ + + _FIELDS: tuple[str, ...] = () + + def __init__(self, root: "Settings") -> None: + object.__setattr__(self, "_root", root) + + def __getattr__(self, name: str) -> Any: + if name not in self._FIELDS: + raise AttributeError( + f"{type(self).__name__} has no field '{name}'; " + f"available: {self._FIELDS}" + ) + return getattr(self._root, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in self._FIELDS: + setattr(self._root, name, value) + else: + object.__setattr__(self, name, value) + + +class _VideoGroup(_SettingsGroup): + _FIELDS = ( + "video_default_backend", + "video_camera_backend_overrides_json", + "video_rtsp_url_template", + "video_rtsp_urls_json", + "video_rtsp_urls_json_file", + "video_open_timeout_sec", + "video_read_failure_reconnect_threshold", + "video_reconnect_backoff_seconds", + "video_inference_interval_sec", + "video_inference_confidence_threshold", + "video_auto_confirm_confidence", + "video_voice_confirm_min_confidence", + "video_voice_confirm_doctor_id", + "video_detail_cooldown_sec", + "video_jpeg_quality", + "video_result_doctor_id", + "video_log_inference_results", + "consumable_classifier_weights", + "consumable_classifier_imgsz", + "consumable_classifier_device", + "consumable_classifier_topk", + "consumable_min_cls_confidence", + "consumable_catalog_xlsx_path", + "consumable_vision_window_sec", + "hand_detection_weights", + "hand_detection_imgsz", + "hand_detection_conf", + "hand_detection_pad_ratio", + "hand_detection_min_crop_px", + "hand_detection_device", + "surgery_recording_max_attempts", + "surgery_recording_retry_delay_seconds", + "archive_persist_retry_interval_seconds", + "archive_persist_max_retries", + "archive_persist_backoff_cap_seconds", + "archive_persist_durable_fallback_dir", + "archive_persist_durable_fallback_enabled", + "consumption_tsv_log_enabled", + "consumption_tsv_log_path", + "consumption_log_markdown_terminal", + "consumption_log_timezone", + ) + + +class _VoiceGroup(_SettingsGroup): + _FIELDS = ( + "voice_confirmation_enabled", + "voice_upload_max_bytes", + "voice_confirm_max_failed_parse_rounds", + "voice_file_log_enabled", + "voice_file_log_path", + ) + + +class _HikvisionGroup(_SettingsGroup): + _FIELDS = ( + "hikvision_lib_dir", + "hikvision_sdk_enabled", + "hikvision_device_ip", + "hikvision_device_port", + "hikvision_user", + "hikvision_password", + "hikvision_channel", + "hikvision_preview_rtsp_template", + "hikvision_camera_rtsp_urls_json", + "hikvision_sdk_fallback_to_rtsp", + ) + + +class _MinioGroup(_SettingsGroup): + _FIELDS = ( + "minio_endpoint", + "minio_access_key", + "minio_secret_key", + "minio_bucket", + "minio_secure", + "minio_region", + ) + + +class _BaiduGroup(_SettingsGroup): + _FIELDS = ( + "baidu_speech_app_id", + "baidu_speech_api_key", + "baidu_speech_secret_key", + "baidu_speech_connection_timeout_ms", + "baidu_speech_socket_timeout_ms", + "baidu_speech_asr_dev_pid", + ) + + +class _DemoGroup(_SettingsGroup): + _FIELDS = ( + "demo_cors_enabled", + "demo_cors_origins", + "demo_orchestrator_enabled", + "demo_orchestrator_rtsp_port", + "demo_orchestrator_rtsp_json_host", + ) + + +class _DatabaseGroup(_SettingsGroup): + _FIELDS = ( + "database_url", + "postgres_user", + "postgres_password", + "postgres_db", + "postgres_host", + "postgres_port", + "auto_create_schema", + ) + + +class _ServerGroup(_SettingsGroup): + _FIELDS = ( + "server_host", + "server_port", + "server_reload", + ) + _PACKAGE_DIR = Path(__file__).resolve().parent # 仓库根目录(含 .env)。用绝对路径读 .env,避免从子目录/IDE 启动时 cwd 不同导致联调项未生效。 _REPO_ROOT = _PACKAGE_DIR.parent @@ -31,6 +181,17 @@ class Settings(BaseSettings): postgres_db: str = "operation_room" postgres_host: str = "localhost" postgres_port: int = 35432 + #: 为 true 时,lifespan 启动会调用 Base.metadata.create_all 确保表存在(开发/测试用)。 + #: 生产请置 false,并通过 ``alembic upgrade head`` 进行版本化迁移。 + auto_create_schema: bool = True + + # --- Uvicorn / API server --- + #: `uvicorn.run` 绑定的地址;默认监听所有接口(开发/容器联调常用)。 + server_host: str = "0.0.0.0" + #: HTTP 端口;生产请按部署策略显式设置。 + server_port: int = Field(default=38080, ge=1, le=65535) + #: 是否启用 `--reload`(仅本地开发;生产必须 false)。 + server_reload: bool = False consumable_classifier_weights: str | None = None consumable_classifier_imgsz: int = 224 #: Explicit Ultralytics device (e.g. cpu, mps, cuda:0). Empty -> macOS prefers MPS; Linux prefers CUDA if available. @@ -85,14 +246,20 @@ class Settings(BaseSettings): voice_confirmation_enabled: bool = True #: 语音确认记帐时的 doctor_id。 video_voice_confirm_doctor_id: str = "voice" - #: (已弃用)服务端本机录音秒数;当前闭环由客户端采集语音,此项仅保留兼容旧配置。 - voice_record_seconds: float = Field(default=5.0, ge=1.0, le=30.0) - #: (已弃用)服务端 ffmpeg 音频输入;当前闭环不依赖服务端录音。 - voice_ffmpeg_input: str = "" - #: 手术结束后归档写库失败时,后台重试落库的间隔(秒)。 + #: 手术结束后归档写库失败时,后台重试落库的间隔(秒),用作指数退避的基数。 archive_persist_retry_interval_seconds: float = Field( default=30.0, ge=5.0, le=3600.0 ) + #: 单条归档允许的最大连续重试次数。达到上限后保持 durable fallback,直到进程重启或手动介入。 + archive_persist_max_retries: int = Field(default=12, ge=1, le=10000) + #: 指数退避上限(秒),防止间隔被放大到不切实际的值。 + archive_persist_backoff_cap_seconds: float = Field( + default=900.0, ge=5.0, le=86400.0 + ) + #: 归档 durable fallback 的磁盘目录;启动/重试时会扫描其中 `*.json` 尝试恢复。 + archive_persist_durable_fallback_dir: str = "logs/pending_archive" + #: 为 true 时,首次写库失败后立即把归档写到 durable fallback 目录,避免进程重启丢数据。 + archive_persist_durable_fallback_enabled: bool = True #: 同一物品重复记一条消耗的最短间隔(秒)。 video_detail_cooldown_sec: float = Field(default=15.0, ge=0.0, le=3600.0) #: 送模型 JPEG 质量。 @@ -280,5 +447,37 @@ class Settings(BaseSettings): """仓库内示例映射路径(供文档与联调引用)。""" return _default_camera_rtsp_urls_sample_path() + @property + def video(self) -> _VideoGroup: + return _VideoGroup(self) + + @property + def voice(self) -> _VoiceGroup: + return _VoiceGroup(self) + + @property + def hikvision(self) -> _HikvisionGroup: + return _HikvisionGroup(self) + + @property + def minio(self) -> _MinioGroup: + return _MinioGroup(self) + + @property + def baidu(self) -> _BaiduGroup: + return _BaiduGroup(self) + + @property + def demo(self) -> _DemoGroup: + return _DemoGroup(self) + + @property + def database(self) -> _DatabaseGroup: + return _DatabaseGroup(self) + + @property + def server(self) -> _ServerGroup: + return _ServerGroup(self) + settings = Settings() diff --git a/app/dependencies.py b/app/dependencies.py index 322a2a0..ac942f0 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -1,65 +1,132 @@ -from loguru import logger +"""组合根:显式以 Settings 构造所有服务,挂到 app.state.container。 -from app.config import settings +避免「import 即实例化」的副作用;lifespan 内 build + shutdown,测试时可注入自定义容器。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from fastapi import Request +from loguru import logger +from sqlalchemy.ext.asyncio import async_sessionmaker + +from app.config import Settings +from app.config import settings as _default_settings +from app.database import AsyncSessionLocal from app.repositories.surgery_results import SurgeryResultRepository from app.repositories.voice_audits import VoiceAuditRepository from app.services.baidu_speech import BaiduSpeechService from app.services.consumable_vision_algorithm import ConsumableVisionAlgorithmService from app.services.minio_audio_storage import MinioAudioStorageService from app.services.surgery_pipeline import SurgeryPipeline -from app.services.voice_resolution import VoiceConfirmationService from app.services.video.hikvision_runtime import HikvisionRuntime from app.services.video.session_manager import CameraSessionManager +from app.services.voice_resolution import VoiceConfirmationService -consumable_vision_algorithm_service = ConsumableVisionAlgorithmService() -hikvision_runtime = HikvisionRuntime.try_load(settings.hikvision_lib_dir) -if settings.hikvision_sdk_enabled and hikvision_runtime is None: - logger.warning( - "HIKVISION_SDK_ENABLED=true but no HCNetSDK library loaded " - "(check HIKVISION_LIB_DIR / mount /opt/hikvision/lib)" +@dataclass +class AppContainer: + """显式容器:构造时即装配完所有服务,lifespan 掌控生命周期。""" + + settings: Settings + consumable_vision_algorithm_service: ConsumableVisionAlgorithmService + hikvision_runtime: HikvisionRuntime | None + surgery_result_repository: SurgeryResultRepository + voice_audit_repository: VoiceAuditRepository + baidu_speech_service: BaiduSpeechService + minio_audio_storage_service: MinioAudioStorageService + camera_session_manager: CameraSessionManager + voice_confirmation_service: VoiceConfirmationService + surgery_pipeline: SurgeryPipeline + + async def start(self) -> None: + await self.camera_session_manager.start_archive_retry_loop() + + async def shutdown(self) -> None: + await self.camera_session_manager.shutdown() + + +def build_container( + app_settings: Settings | None = None, + *, + session_factory: async_sessionmaker | None = None, +) -> AppContainer: + """基于 Settings 显式装配所有服务;不做任何 import-time 副作用。""" + s = app_settings or _default_settings + sf: async_sessionmaker = session_factory or AsyncSessionLocal + vision = ConsumableVisionAlgorithmService(app_settings=s) + hik_runtime = HikvisionRuntime.try_load(s.hikvision_lib_dir) + if s.hikvision_sdk_enabled and hik_runtime is None: + logger.warning( + "HIKVISION_SDK_ENABLED=true but no HCNetSDK library loaded " + "(check HIKVISION_LIB_DIR / mount /opt/hikvision/lib)" + ) + surgery_repo = SurgeryResultRepository() + voice_audit_repo = VoiceAuditRepository() + baidu = BaiduSpeechService(app_settings=s) + minio = MinioAudioStorageService(s) + camera_mgr = CameraSessionManager( + settings=s, + vision_algorithm=vision, + hikvision_runtime=hik_runtime, + result_repository=surgery_repo, + session_factory=sf, + ) + voice = VoiceConfirmationService( + settings=s, + sessions=camera_mgr, + baidu=baidu, + minio=minio, + audits=voice_audit_repo, + session_factory=sf, + ) + pipeline = SurgeryPipeline( + camera_mgr, + result_repository=surgery_repo, + voice_confirmation=voice, + session_factory=sf, + ) + return AppContainer( + settings=s, + consumable_vision_algorithm_service=vision, + hikvision_runtime=hik_runtime, + surgery_result_repository=surgery_repo, + voice_audit_repository=voice_audit_repo, + baidu_speech_service=baidu, + minio_audio_storage_service=minio, + camera_session_manager=camera_mgr, + voice_confirmation_service=voice, + surgery_pipeline=pipeline, ) -surgery_result_repository = SurgeryResultRepository() -voice_audit_repository = VoiceAuditRepository() -baidu_speech_service = BaiduSpeechService() -minio_audio_storage_service = MinioAudioStorageService(settings) -camera_session_manager = CameraSessionManager( - settings=settings, - vision_algorithm=consumable_vision_algorithm_service, - hikvision_runtime=hikvision_runtime, - result_repository=surgery_result_repository, -) -voice_confirmation_service = VoiceConfirmationService( - settings=settings, - sessions=camera_session_manager, - baidu=baidu_speech_service, - minio=minio_audio_storage_service, - audits=voice_audit_repository, -) -surgery_pipeline = SurgeryPipeline( - camera_session_manager, - result_repository=surgery_result_repository, - voice_confirmation=voice_confirmation_service, -) +def get_container(request: Request) -> AppContainer: + container: AppContainer | None = getattr(request.app.state, "container", None) + if container is None: + raise RuntimeError( + "AppContainer is not initialized; lifespan should set app.state.container" + ) + return container -def get_consumable_vision_algorithm_service() -> ConsumableVisionAlgorithmService: - return consumable_vision_algorithm_service +def get_consumable_vision_algorithm_service( + request: Request, +) -> ConsumableVisionAlgorithmService: + return get_container(request).consumable_vision_algorithm_service -def get_surgery_pipeline() -> SurgeryPipeline: - return surgery_pipeline +def get_surgery_pipeline(request: Request) -> SurgeryPipeline: + return get_container(request).surgery_pipeline -def get_camera_session_manager() -> CameraSessionManager: - return camera_session_manager +def get_camera_session_manager(request: Request) -> CameraSessionManager: + return get_container(request).camera_session_manager -def get_surgery_result_repository() -> SurgeryResultRepository: - return surgery_result_repository +def get_surgery_result_repository(request: Request) -> SurgeryResultRepository: + return get_container(request).surgery_result_repository -def get_voice_confirmation_service() -> VoiceConfirmationService: - return voice_confirmation_service +def get_voice_confirmation_service(request: Request) -> VoiceConfirmationService: + return get_container(request).voice_confirmation_service diff --git a/app/domain/__init__.py b/app/domain/__init__.py new file mode 100644 index 0000000..12fe447 --- /dev/null +++ b/app/domain/__init__.py @@ -0,0 +1 @@ +"""领域对象:与 HTTP / 持久化层无关的业务数据结构。""" diff --git a/app/domain/consumption.py b/app/domain/consumption.py new file mode 100644 index 0000000..ab707ca --- /dev/null +++ b/app/domain/consumption.py @@ -0,0 +1,26 @@ +"""手术消耗明细的领域对象(服务端内部流转 / 持久化使用)。""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class SurgeryConsumptionStored: + """内存 / 数据库持久化用的明细行(含 source,仅服务端内部使用,不随 HTTP 返回)。 + + HTTP 层面的表示为 ``app.schemas.SurgeryConsumptionDetail``; + 转换由上层(pipeline)在返回给客户端前完成,以隔离领域与 API。 + + ``pending_confirmation_id``:非空表示该行仍为医生待确认占位(``item_name`` 一般为「待确认」); + 确认后由注册表替换为最终耗材行并清空此字段。 + """ + + item_id: str + item_name: str + qty: int + doctor_id: str + timestamp: datetime + source: str = "vision" + pending_confirmation_id: str | None = None diff --git a/app/repositories/surgery_results.py b/app/repositories/surgery_results.py index d4ab07a..c84b910 100644 --- a/app/repositories/surgery_results.py +++ b/app/repositories/surgery_results.py @@ -6,7 +6,7 @@ from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from app.db.models import SurgeryFinalResult, SurgeryResultDetailRow -from app.schemas import SurgeryConsumptionDetail, SurgeryConsumptionStored +from app.domain.consumption import SurgeryConsumptionStored class SurgeryResultRepository: @@ -47,7 +47,8 @@ class SurgeryResultRepository: async def load_final_details( self, session: AsyncSession, surgery_id: str - ) -> list[SurgeryConsumptionDetail] | None: + ) -> list[SurgeryConsumptionStored] | None: + """返回领域对象列表(含 source);HTTP 层的转换由 pipeline 负责。""" res = await session.execute( select(SurgeryFinalResult).where(SurgeryFinalResult.surgery_id == surgery_id) ) @@ -60,13 +61,21 @@ class SurgeryResultRepository: .order_by(SurgeryResultDetailRow.id) ) rows = q.scalars().all() - return [ - SurgeryConsumptionDetail( - item_id=r.item_id, - item_name=r.item_name, - qty=r.quantity, - doctor_id=r.doctor_id, - timestamp=r.recorded_at, + out: list[SurgeryConsumptionStored] = [] + for r in rows: + pend: str | None = None + iid = r.item_id + if iid.startswith("pending:"): + pend = iid.removeprefix("pending:") + out.append( + SurgeryConsumptionStored( + item_id=r.item_id, + item_name=r.item_name, + qty=r.quantity, + doctor_id=r.doctor_id, + timestamp=r.recorded_at, + source=r.source, + pending_confirmation_id=pend, + ) ) - for r in rows - ] + return out diff --git a/app/schemas.py b/app/schemas.py index 18c17e5..cd79a49 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass from datetime import datetime from pydantic import BaseModel, ConfigDict, Field @@ -35,9 +34,10 @@ class SurgeryStartRequest(BaseModel): candidate_consumables: list[str] = Field( default_factory=list, description=( - "本次手术可能使用到的耗材清单。" - "服务端仅对该清单内的耗材做自动记账与待确认追问;" - "若为空则不会写入任何消耗(仅拉流推理)。" + "本次手术可能使用到的耗材子集(可选)。" + "非空时仅对该清单内名称做自动记账与待确认追问。" + "缺省或空数组时,使用服务端配置的耗材目录 Excel 全部商品名;" + "未配置目录则使用分类模型全部类名。" ), ) @@ -101,32 +101,14 @@ class SurgeryConsumptionDetail(BaseModel): ), ) item_name: str = Field(description="物品名称(分类或确认后的展示名)。") - qty: int = Field(ge=0, description="本条记录对应的消耗数量。") + qty: int = Field( + ge=0, + description="本条记录对应的消耗数量;当前一次识别或一次人工确认仅追加一条明细,因此固定为 1。", + ) doctor_id: str = Field(description="医生 ID。") timestamp: datetime = Field(description="记录时间(ISO 8601,date-time)。") -@dataclass -class SurgeryConsumptionStored: - """内存 / 数据库持久化用的明细行(含 source,仅服务端内部使用,不随 HTTP 返回)。""" - - item_id: str - item_name: str - qty: int - doctor_id: str - timestamp: datetime - source: str = "vision" - - def as_response(self) -> SurgeryConsumptionDetail: - return SurgeryConsumptionDetail( - item_id=self.item_id, - item_name=self.item_name, - qty=self.qty, - doctor_id=self.doctor_id, - timestamp=self.timestamp, - ) - - class SurgeryConsumptionSummary(BaseModel): """按物品汇总:该手术下该物品消耗数量合计(item_id、item_name、total_quantity)。""" @@ -222,7 +204,7 @@ class SurgeryResultResponse(BaseModel): { "item_id": "HC001", "item_name": "纱布", - "qty": 2, + "qty": 1, "doctor_id": "D1001", "timestamp": "2026-04-21T10:30:00+08:00", }, @@ -242,7 +224,7 @@ class SurgeryResultResponse(BaseModel): }, ], "summary": [ - {"item_id": "HC001", "item_name": "纱布", "total_quantity": 3}, + {"item_id": "HC001", "item_name": "纱布", "total_quantity": 2}, {"item_id": "HC002", "item_name": "缝线", "total_quantity": 1}, ], } diff --git a/app/services/baidu_speech.py b/app/services/baidu_speech.py index daf2b92..ecf8ff0 100644 --- a/app/services/baidu_speech.py +++ b/app/services/baidu_speech.py @@ -5,7 +5,7 @@ from typing import Any from aip import AipSpeech -from app.config import settings +from app.config import Settings, settings as _default_settings class BaiduSpeechNotConfiguredError(RuntimeError): @@ -15,13 +15,14 @@ class BaiduSpeechNotConfiguredError(RuntimeError): class BaiduSpeechService: """百度短语音识别(asr)与在线语音合成(synthesis),基于 `baidu-aip` 的 `AipSpeech`。""" - def __init__(self) -> None: + def __init__(self, app_settings: Settings | None = None) -> None: + self._s = app_settings or _default_settings self._client: AipSpeech | None = None self._lock = Lock() @property def configured(self) -> bool: - return settings.baidu_speech_configured + return self._s.baidu_speech_configured def _client_or_raise(self) -> AipSpeech: if not self.configured: @@ -32,16 +33,16 @@ class BaiduSpeechService: with self._lock: if self._client is None: client = AipSpeech( - settings.baidu_speech_app_id, - settings.baidu_speech_api_key, - settings.baidu_speech_secret_key, + self._s.baidu_speech_app_id, + self._s.baidu_speech_api_key, + self._s.baidu_speech_secret_key, ) - if settings.baidu_speech_connection_timeout_ms is not None: + if self._s.baidu_speech_connection_timeout_ms is not None: client.setConnectionTimeoutInMillis( - settings.baidu_speech_connection_timeout_ms + self._s.baidu_speech_connection_timeout_ms ) - if settings.baidu_speech_socket_timeout_ms is not None: - client.setSocketTimeoutInMillis(settings.baidu_speech_socket_timeout_ms) + if self._s.baidu_speech_socket_timeout_ms is not None: + client.setSocketTimeoutInMillis(self._s.baidu_speech_socket_timeout_ms) self._client = client return self._client @@ -57,7 +58,7 @@ class BaiduSpeechService: 固定使用普通话模型(`dev_pid` 来自配置),避免未传参时误用服务端默认导致偏英语等结果。 """ merged: dict[str, Any] = dict(options or {}) - merged["dev_pid"] = int(settings.baidu_speech_asr_dev_pid) + merged["dev_pid"] = int(self._s.baidu_speech_asr_dev_pid) return self._client_or_raise().asr(speech, format, rate, merged) def synthesis( diff --git a/app/services/consumable_vision_algorithm.py b/app/services/consumable_vision_algorithm.py index 52997d8..9a3f827 100644 --- a/app/services/consumable_vision_algorithm.py +++ b/app/services/consumable_vision_algorithm.py @@ -20,7 +20,11 @@ from ultralytics import YOLO from app.config import Settings, settings -os.environ["YOLO_CONFIG_DIR"] = "/tmp" + +def _ensure_yolo_config_dir() -> None: + """Ultralytics 需要可写 YOLO_CONFIG_DIR;仅在未设置时给一个安全默认,不覆盖用户配置。""" + if not os.environ.get("YOLO_CONFIG_DIR"): + os.environ["YOLO_CONFIG_DIR"] = "/tmp" def resolve_inference_device(explicit: str) -> str | None: @@ -184,40 +188,62 @@ def pad_box( return nx1, ny1, nx2, ny2 +def _probs_data_to_numpy1d(raw) -> np.ndarray: + """分类 logits/probs 向量 → 1D float64 NumPy 数组。 + + PyTorch 张量若在 ``cuda``、``mps`` 等设备上,**必须先** ``.cpu()`` 再转 NumPy: + NumPy 只支持 CPU(主机)内存,没有 CUDA/MPS 后端;``np.asarray(cuda_tensor)`` / + ``tensor.numpy()``(设备上)都会失败。``.cpu()`` 会做一次设备→主机的拷贝(已是 CPU + 时开销很小),因此 CUDA 与 MPS 共用同一路径即可。 + """ + if raw is None: + return np.zeros((0,), dtype=np.float64) + x = raw + if hasattr(x, "detach"): + x = x.detach() + if hasattr(x, "cpu"): + x = x.cpu() + if hasattr(x, "numpy"): + # torch.Tensor / ultralytics BaseTensor 等 + x = x.numpy() + return np.asarray(x, dtype=np.float64).reshape(-1) + + def cls_top3_from_result( cls: YOLO, r, name_to_code: dict[str, str] ) -> ClsTop3 | None: pr = r[0].probs - if pr is None or not hasattr(pr, "top5") or not pr.top5: + if pr is None: return None - t5i = list(pr.top5) - tc = pr.top5conf - if tc is None: + arr = _probs_data_to_numpy1d(pr.data) + if arr.size == 0: return None + order = np.argsort(-arr, kind="stable") + t5i = [int(order[i]) for i in range(min(5, int(order.size)))] - def _ci(i: int) -> float: - if i < 0 or i >= len(tc): + def _conf_for_idx(idx: int) -> float: + if idx < 0 or idx >= arr.size: return 0.0 try: - v = tc[i] + v = arr[idx] return float(v.item() if hasattr(v, "item") else v) except (IndexError, ValueError, TypeError): return 0.0 - t1i = int(pr.top1) - c1 = _ci(0) if t5i and int(t5i[0]) == t1i else float( - pr.top1conf.item() if hasattr(pr.top1conf, "item") else pr.top1conf - ) + t1i = int(t5i[0]) + c1 = _conf_for_idx(t1i) n1 = str(cls.names.get(t1i, "")).strip() n2 = n3 = "" c2 = c3 = 0.0 if len(t5i) > 1: - n2 = str(cls.names.get(int(t5i[1]), "")).strip() - c2 = _ci(1) + i2 = int(t5i[1]) + n2 = str(cls.names.get(i2, "")).strip() + c2 = _conf_for_idx(i2) if len(t5i) > 2: - n3 = str(cls.names.get(int(t5i[2]), "")).strip() - c3 = _ci(2) + i3 = int(t5i[2]) + n3 = str(cls.names.get(i3, "")).strip() + c3 = _conf_for_idx(i3) def _pid(label: str) -> str: lb = (label or "").strip() @@ -283,12 +309,50 @@ class ConsumableVisionAlgorithmService: """手部检测(可选)+ 耗材分类;供 CameraSessionManager 在视频线程中调用。""" def __init__(self, app_settings: Settings | None = None) -> None: + _ensure_yolo_config_dir() self._s = app_settings or settings self._det: YOLO | None = None self._cls: YOLO | None = None self._det_lock = Lock() self._cls_lock = Lock() + def effective_candidate_consumables(self, requested: list[str]) -> list[str]: + """请求体中的耗材子集;未提供(缺省或仅空白)时用目录 Excel 全部商品名,无目录则用分类模型全部类名。""" + out: list[str] = [] + seen: set[str] = set() + for c in requested: + n = _norm_product_name((c or "").strip()) + if not n or n in seen: + continue + seen.add(n) + out.append(n) + if out: + return out + + xlsx_raw = (self._s.consumable_catalog_xlsx_path or "").strip() + if xlsx_raw: + path = Path(xlsx_raw).expanduser() + if path.is_file(): + try: + full = load_name_to_product_code(path) + except Exception as exc: + logger.warning("读取耗材目录 Excel 失败,回退到模型类名: {}", exc) + else: + if full: + return sorted(full.keys()) + logger.warning("耗材目录 Excel 无有效行,回退到模型类名") + else: + logger.warning( + "耗材目录 Excel 路径已配置但文件不存在: {},回退到模型类名", + path, + ) + + cls_model = self._get_cls() + labels = sorted( + {str(v).strip() for v in cls_model.names.values() if str(v).strip()} + ) + return labels + def build_name_mapping( self, candidate_consumables: list[str] ) -> dict[str, str]: diff --git a/app/services/consumption_tsv_log.py b/app/services/consumption_tsv_log.py index 6546abf..327f511 100644 --- a/app/services/consumption_tsv_log.py +++ b/app/services/consumption_tsv_log.py @@ -19,8 +19,12 @@ from app.config import settings from app.services.consumable_vision_algorithm import ClsTop3, _norm_product_name from app.terminal_markdown import print_markdown_stderr -# 制表符分隔;时间范围用 U+2013 连接;本窗消耗数量恒为 1 -HEADER = "item_id\titem_name\tqty\tdoctor_id\ttimestamp\n" +# 制表符分隔;时间范围用 U+2013 连接;本窗消耗数量恒为 1。 +# top2/top3 为模型原始排序(未按手术候选重排);item_id 仅写产品编码,无编码时留空。 +HEADER = ( + "item_id\titem_name\tqty\tdoctor_id\ttimestamp\t" + "top2_name\ttop2_conf\ttop3_name\ttop3_conf\n" +) SUMMARY_HEADER = "item_id\titem_name\tqty\n" _RANGE_SEP = "\u2013" # en dash,与样例 `00:00:00.000–00:00:45.000` 一致 @@ -86,21 +90,43 @@ def _encode_cell(value: str) -> str: return s +def resolve_consumption_ids( + t1_name: str, + t1_pid: str, + name_to_code: dict[str, str], +) -> tuple[str, str]: + """TSV 第一列 item_id 与内存汇总键。 + + - ``tsv_item_id``:仅产品编码(或模型侧 t1_pid);与展示名相同则视为无独立编码,留空。 + - ``totals_key``:汇总用稳定键;无编码时用归一化名称,避免多行空 id 碰撞。 + """ + n = (t1_name or "").strip() + norm = _norm_product_name(n) + code = (name_to_code.get(norm) or name_to_code.get(n) or "").strip() + p = (t1_pid or "").strip() + catalog = (code or p).strip() + if catalog and catalog != n: + return catalog, catalog + if catalog == n and catalog: + return "", norm + return "", norm if norm else (n or "unknown") + + def resolve_consumption_item_id( t1_name: str, t1_pid: str, name_to_code: dict[str, str], ) -> str: - """业务物品 id:`name_to_code` 的键为归一化名称,须与分类输出一同参与查找。""" - n = (t1_name or "").strip() - norm = _norm_product_name(n) - code = (name_to_code.get(norm) or name_to_code.get(n) or "").strip() - if code: - return code - p = (t1_pid or "").strip() - if p: - return p - return n + """兼容旧调用:有编码则返回编码,否则返回汇总键(归一化名或 unknown)。""" + tsv_id, totals_key = resolve_consumption_ids(t1_name, t1_pid, name_to_code) + return tsv_id or totals_key + + +def _fmt_top_conf(v: float) -> str: + try: + return f"{float(v):.4f}" + except (TypeError, ValueError): + return "0.0000" def build_tsv_line( @@ -112,15 +138,23 @@ def build_tsv_line( wall_start_epoch: float, wall_end_epoch: float, ) -> str: - id1 = resolve_consumption_item_id(best.t1_name, best.t1_pid, name_to_code) + tsv_id, _tot_key = resolve_consumption_ids( + best.t1_name, best.t1_pid, name_to_code + ) name1 = (best.t1_name or "").strip() ts = format_consumption_timestamp(camera_id, wall_start_epoch, wall_end_epoch) + n2 = (best.t2_name or "").strip() + n3 = (best.t3_name or "").strip() row = [ - _encode_cell(id1), + _encode_cell(tsv_id), _encode_cell(name1), "1", _encode_cell(doctor_id), _encode_cell(ts), + _encode_cell(n2), + _fmt_top_conf(best.t2_conf), + _encode_cell(n3), + _fmt_top_conf(best.t3_conf), ] return "\t".join(row) + "\n" @@ -185,24 +219,185 @@ def build_consumption_markdown( wall_end_epoch: float, ) -> str: """终端用:与落盘列一致;本窗 qty 恒为 1。""" - id1 = resolve_consumption_item_id(best.t1_name, best.t1_pid, name_to_code) + tsv_id, _ = resolve_consumption_ids(best.t1_name, best.t1_pid, name_to_code) n1 = (best.t1_name or "").strip() + n2 = (best.t2_name or "").strip() + n3 = (best.t3_name or "").strip() ts = format_consumption_timestamp_readable(camera_id, wall_start_epoch, wall_end_epoch) return "\n".join( [ - "| item_id | item_name | qty | doctor_id | timestamp |", - "| :--- | :--- | ---: | :--- | :--- |", - "| {} | {} | 1 | {} | {} |".format( - _md_cell(id1), + "| item_id | item_name | qty | doctor_id | timestamp | top2 | top3 |", + "| :--- | :--- | ---: | :--- | :--- | :--- | :--- |", + "| {} | {} | 1 | {} | {} | {} | {} |".format( + _md_cell(tsv_id), _md_cell(n1), _md_cell(doctor_id), _md_cell(ts), + _md_cell( + f"{n2} ({_fmt_top_conf(best.t2_conf)})" if n2 else "—", + ), + _md_cell( + f"{n3} ({_fmt_top_conf(best.t3_conf)})" if n3 else "—", + ), ), "", ] ) +PENDING_CONSUMPTION_ITEM_NAME = "待确认" + + +def _build_pending_tsv_line( + *, + confirmation_id: str, + model_snap: ClsTop3, + doctor_id: str, + camera_id: str, + wall_start_epoch: float, + wall_end_epoch: float, +) -> str: + pid = f"pending:{confirmation_id}" + ts = format_consumption_timestamp(camera_id, wall_start_epoch, wall_end_epoch) + n2 = (model_snap.t2_name or "").strip() + n3 = (model_snap.t3_name or "").strip() + row = [ + _encode_cell(pid), + _encode_cell(PENDING_CONSUMPTION_ITEM_NAME), + "1", + _encode_cell(doctor_id), + _encode_cell(ts), + _encode_cell(n2), + _fmt_top_conf(model_snap.t2_conf), + _encode_cell(n3), + _fmt_top_conf(model_snap.t3_conf), + ] + return "\t".join(row) + "\n" + + +def build_pending_consumption_markdown( + *, + confirmation_id: str, + model_snap: ClsTop3, + doctor_id: str, + camera_id: str, + wall_start_epoch: float, + wall_end_epoch: float, +) -> str: + pid = f"pending:{confirmation_id}" + n2 = (model_snap.t2_name or "").strip() + n3 = (model_snap.t3_name or "").strip() + ts = format_consumption_timestamp_readable(camera_id, wall_start_epoch, wall_end_epoch) + return "\n".join( + [ + "| item_id | item_name | qty | doctor_id | timestamp | top2 | top3 |", + "| :--- | :--- | ---: | :--- | :--- | :--- | :--- |", + "| {} | {} | 1 | {} | {} | {} | {} |".format( + _md_cell(pid), + _md_cell(PENDING_CONSUMPTION_ITEM_NAME), + _md_cell(doctor_id), + _md_cell(ts), + _md_cell( + f"{n2} ({_fmt_top_conf(model_snap.t2_conf)})" if n2 else "—", + ), + _md_cell( + f"{n3} ({_fmt_top_conf(model_snap.t3_conf)})" if n3 else "—", + ), + ), + "", + ] + ) + + +def append_consumption_pending_window( + *, + surgery_id: str, + confirmation_id: str, + model_snap: ClsTop3, + doctor_id: str, + camera_id: str, + wall_start_epoch: float, + wall_end_epoch: float, + tsv_enabled: bool | None = None, + markdown_terminal: bool | None = None, +) -> None: + """需医生确认的时间窗:落盘/终端记「待确认」,top2/3 仍保留模型提示;不更新消耗汇总。""" + en_tsv = settings.consumption_tsv_log_enabled if tsv_enabled is None else tsv_enabled + en_md = ( + settings.consumption_log_markdown_terminal + if markdown_terminal is None + else markdown_terminal + ) + if not en_tsv and not en_md: + return + line = _build_pending_tsv_line( + confirmation_id=confirmation_id, + model_snap=model_snap, + doctor_id=doctor_id, + camera_id=camera_id, + wall_start_epoch=wall_start_epoch, + wall_end_epoch=wall_end_epoch, + ) + if en_tsv: + append_consumption_tsv_line(surgery_id, line) + if en_md: + print_markdown_stderr( + build_pending_consumption_markdown( + confirmation_id=confirmation_id, + model_snap=model_snap, + doctor_id=doctor_id, + camera_id=camera_id, + wall_start_epoch=wall_start_epoch, + wall_end_epoch=wall_end_epoch, + ), + ) + + +def append_consumption_voice_resolution_line( + *, + surgery_id: str, + name_to_code: dict[str, str], + chosen_label: str, + doctor_id: str, + wall_epoch: float, + tsv_enabled: bool | None = None, +) -> None: + """语音确认后追加一行最终耗材(医生 ID 多为 voice);top2/3 空列。 + + 待确认流程下,时间窗仅记「待确认」,此处写入医生选定后的正式记录。 + """ + en = settings.consumption_tsv_log_enabled if tsv_enabled is None else tsv_enabled + if not en: + return + lb = (chosen_label or "").strip() + if not lb: + return + norm = _norm_product_name(lb) + p = ( + name_to_code.get(norm) or name_to_code.get(lb) or "" + ).strip() + snap = ClsTop3( + t1_name=lb, + t1_conf=1.0, + t2_name="", + t2_conf=0.0, + t3_name="", + t3_conf=0.0, + t1_pid=p, + t2_pid="", + t3_pid="", + ) + line = build_tsv_line( + name_to_code=name_to_code, + best=snap, + doctor_id=doctor_id, + camera_id="voice", + wall_start_epoch=wall_epoch, + wall_end_epoch=wall_epoch, + ) + append_consumption_tsv_line(surgery_id, line) + + def append_consumption_log_summary( surgery_id: str, totals: dict[str, tuple[str, int]], @@ -244,6 +439,107 @@ def print_consumption_summary_markdown( print_markdown_stderr("\n".join(lines)) +class ConsumptionTsvWriter: + """注入式 consumption 日志写入器,取代模块全局 ``settings`` 读取。 + + 行为与模块级函数完全一致;保留模块级函数以维持旧调用点的兼容期。 + """ + + def __init__(self, app_settings) -> None: + self._s = app_settings + + def init_file(self, surgery_id: str) -> None: + if not self._s.consumption_tsv_log_enabled: + return + path = resolved_consumption_log_path(surgery_id) + path.parent.mkdir(parents=True, exist_ok=True) + with _lock: + with path.open("w", encoding="utf-8") as f: + f.write(HEADER) + + def append_window( + self, + *, + surgery_id: str, + name_to_code: dict[str, str], + best: ClsTop3, + doctor_id: str, + camera_id: str, + wall_start_epoch: float, + wall_end_epoch: float, + running_totals: dict[str, tuple[str, int]] | None = None, + ) -> None: + if not self._s.consumption_tsv_log_enabled and not self._s.consumption_log_markdown_terminal: + return + _tsv_id, totals_key = resolve_consumption_ids( + best.t1_name, best.t1_pid, name_to_code + ) + iname = (best.t1_name or "").strip() + if running_totals is not None: + if totals_key not in running_totals: + running_totals[totals_key] = (iname, 0) + prev_name, q = running_totals[totals_key] + running_totals[totals_key] = (prev_name, q + 1) + if self._s.consumption_tsv_log_enabled: + line = build_tsv_line( + name_to_code=name_to_code, + best=best, + doctor_id=doctor_id, + camera_id=camera_id, + wall_start_epoch=wall_start_epoch, + wall_end_epoch=wall_end_epoch, + ) + append_consumption_tsv_line(surgery_id, line) + if self._s.consumption_log_markdown_terminal: + print_markdown_stderr( + build_consumption_markdown( + name_to_code=name_to_code, + best=best, + doctor_id=doctor_id, + camera_id=camera_id, + wall_start_epoch=wall_start_epoch, + wall_end_epoch=wall_end_epoch, + ), + ) + + def append_summary( + self, + surgery_id: str, + totals: dict[str, tuple[str, int]], + ) -> None: + if not self._s.consumption_tsv_log_enabled or not totals: + return + path = resolved_consumption_log_path(surgery_id) + if not path.is_file(): + return + body = "".join( + ["\n", SUMMARY_HEADER] + + [ + "\t".join([_encode_cell(iid), _encode_cell(name), str(qty)]) + "\n" + for iid, (name, qty) in sorted(totals.items(), key=lambda x: x[0]) + ] + ) + with _lock: + with path.open("a", encoding="utf-8") as f: + f.write(body) + + def print_summary_markdown(self, totals: dict[str, tuple[str, int]]) -> None: + if not self._s.consumption_log_markdown_terminal or not totals: + return + lines = [ + "## 消耗汇总", + "", + "| item_id | item_name | qty |", + "| :--- | :--- | ---: |", + ] + for iid, (name, qty) in sorted(totals.items(), key=lambda x: x[0]): + lines.append( + "| {} | {} | {} |".format(_md_cell(iid), _md_cell(name), qty) + ) + lines.append("") + print_markdown_stderr("\n".join(lines)) + + def append_consumption_window( *, surgery_id: str, @@ -257,13 +553,15 @@ def append_consumption_window( ) -> None: if not settings.consumption_tsv_log_enabled and not settings.consumption_log_markdown_terminal: return - iid = resolve_consumption_item_id(best.t1_name, best.t1_pid, name_to_code) + _tsv_id, totals_key = resolve_consumption_ids( + best.t1_name, best.t1_pid, name_to_code + ) iname = (best.t1_name or "").strip() if running_totals is not None: - if iid not in running_totals: - running_totals[iid] = (iname, 0) - prev_name, q = running_totals[iid] - running_totals[iid] = (prev_name, q + 1) + if totals_key not in running_totals: + running_totals[totals_key] = (iname, 0) + prev_name, q = running_totals[totals_key] + running_totals[totals_key] = (prev_name, q + 1) if settings.consumption_tsv_log_enabled: line = build_tsv_line( name_to_code=name_to_code, diff --git a/app/services/pending_confirmation_port.py b/app/services/pending_confirmation_port.py new file mode 100644 index 0000000..fcc21e7 --- /dev/null +++ b/app/services/pending_confirmation_port.py @@ -0,0 +1,54 @@ +"""语音确认服务访问会话状态的端口协议。 + +把 `VoiceConfirmationService` 对 `CameraSessionManager` 的强依赖解耦为 +`PendingConfirmationStore` 协议;便于单元测试用 fake,并为后续拆分会话管理器 +(`SurgerySessionRegistry` 等)保留切换点。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from app.services.video.session_manager import PendingConsumableConfirmation + + +@runtime_checkable +class PendingConfirmationStore(Protocol): + """语音确认链路需要的最小会话接口。""" + + def get_pending_confirmation_by_id( + self, + surgery_id: str, + confirmation_id: str, + ) -> "PendingConsumableConfirmation | None": + ... + + def get_surgery_candidate_consumables(self, surgery_id: str) -> list[str]: + ... + + async def record_voice_parse_failure( + self, + surgery_id: str, + confirmation_id: str, + ) -> tuple[int, int]: + ... + + async def resolve_pending_confirmation( + self, + surgery_id: str, + confirmation_id: str, + *, + chosen_label: str | None, + rejected: bool, + ) -> None: + ... + + def record_voice_trace( + self, + surgery_id: str, + *, + asr_text: str | None, + error: str | None, + ) -> None: + ... diff --git a/app/services/surgery_pipeline.py b/app/services/surgery_pipeline.py index f45ebf7..d299827 100644 --- a/app/services/surgery_pipeline.py +++ b/app/services/surgery_pipeline.py @@ -4,7 +4,10 @@ from __future__ import annotations import base64 +from sqlalchemy.ext.asyncio import async_sessionmaker + from app.database import AsyncSessionLocal +from app.domain.consumption import SurgeryConsumptionStored from app.repositories.surgery_results import SurgeryResultRepository from app.schemas import ( PendingConfirmationOption, @@ -18,6 +21,20 @@ from app.services.voice_resolution import VoiceConfirmationService, VoiceResolve from app.surgery_errors import SurgeryPipelineError +def _stored_to_response(rows: list[SurgeryConsumptionStored]) -> list[SurgeryConsumptionDetail]: + """领域对象 → HTTP DTO 的单向转换,仅在返回给客户端的边界调用。""" + return [ + SurgeryConsumptionDetail( + item_id=r.item_id, + item_name=r.item_name, + qty=r.qty, + doctor_id=r.doctor_id, + timestamp=r.timestamp, + ) + for r in rows + ] + + class SurgeryPipeline: """协调开录、停录与算法产出。路由仅在子系统确认后返回 HTTP 200。""" @@ -27,10 +44,12 @@ class SurgeryPipeline: *, result_repository: SurgeryResultRepository, voice_confirmation: VoiceConfirmationService, + session_factory: async_sessionmaker | None = None, ) -> None: self._sessions = sessions self._repo = result_repository self._voice = voice_confirmation + self._session_factory: async_sessionmaker = session_factory or AsyncSessionLocal async def start_recording( self, @@ -72,13 +91,16 @@ class SurgeryPipeline: """进行中:返回内存明细;已结束:返回数据库最终结果;持久化失败时回退内存归档。""" live = self._sessions.live_consumption_if_active(surgery_id) if live is not None: - return live - async with AsyncSessionLocal() as session: + return _stored_to_response(live) + async with self._session_factory() as session: async with session.begin(): persisted = await self._repo.load_final_details(session, surgery_id) if persisted is not None: - return persisted - return self._sessions.archived_consumption_fallback(surgery_id) + return _stored_to_response(persisted) + archived = self._sessions.archived_consumption_fallback(surgery_id) + if archived is not None: + return _stored_to_response(archived) + return None async def get_pending_confirmation_for_client( self, surgery_id: str diff --git a/app/services/video/archive_persister.py b/app/services/video/archive_persister.py new file mode 100644 index 0000000..af5f3fe --- /dev/null +++ b/app/services/video/archive_persister.py @@ -0,0 +1,330 @@ +"""手术归档持久化:写库失败后的内存归档 + 指数退避重试 + durable fallback。 + +设计目标: +- ``CameraSessionManager`` 停录后把「待落库明细」交给本模块,不再自行持有重试状态。 +- 首次写库失败时: + 1. 将归档放入内存 ``_archive`` 以便下次重试。 + 2. 若开启 durable fallback,同步写一个 JSON 文件到磁盘,进程重启后可从中恢复。 +- 后台循环以指数退避 + 最大重试次数的方式尝试把内存中的归档写库成功。达到上限仍失败时记 + 告警并保留 durable 文件,等待人工介入。 +""" + +from __future__ import annotations + +import asyncio +import json +import os +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING + +from loguru import logger +from sqlalchemy.ext.asyncio import async_sessionmaker + +from app.config import Settings +from app.domain.consumption import SurgeryConsumptionStored + +if TYPE_CHECKING: + from app.repositories.surgery_results import SurgeryResultRepository + + +@dataclass +class _ArchiveEntry: + """内存归档条目,记录尝试次数以驱动指数退避。""" + + details: list[SurgeryConsumptionStored] + attempts: int = 0 + next_attempt_monotonic: float = 0.0 + durable_path: Path | None = None + + +def _serialize_details(details: list[SurgeryConsumptionStored]) -> list[dict]: + return [ + { + "item_id": d.item_id, + "item_name": d.item_name, + "qty": d.qty, + "doctor_id": d.doctor_id, + "timestamp": d.timestamp.isoformat(), + "source": d.source, + "pending_confirmation_id": d.pending_confirmation_id, + } + for d in details + ] + + +def _deserialize_details(rows: list[dict]) -> list[SurgeryConsumptionStored]: + out: list[SurgeryConsumptionStored] = [] + for r in rows: + ts_raw = r["timestamp"] + try: + ts = datetime.fromisoformat(ts_raw) + except ValueError: + ts = datetime.now(timezone.utc) + iid = str(r["item_id"]) + pend = r.get("pending_confirmation_id") + if pend is None and iid.startswith("pending:"): + pend = iid.removeprefix("pending:") + out.append( + SurgeryConsumptionStored( + item_id=iid, + item_name=str(r["item_name"]), + qty=int(r["qty"]), + doctor_id=str(r["doctor_id"]), + timestamp=ts, + source=str(r.get("source", "vision")), + pending_confirmation_id=pend, + ) + ) + return out + + +class ArchivePersister: + """把手术结束明细写入 DB;失败时进入退避重试 + 可选 durable fallback。""" + + def __init__( + self, + *, + settings: Settings, + repository: "SurgeryResultRepository | None", + session_factory: async_sessionmaker, + ) -> None: + self._s = settings + self._repo = repository + self._session_factory = session_factory + self._archive: dict[str, _ArchiveEntry] = {} + self._lock = asyncio.Lock() + self._retry_task: asyncio.Task[None] | None = None + self._retry_stop = asyncio.Event() + + @property + def repository(self) -> "SurgeryResultRepository | None": + return self._repo + + @property + def has_pending(self) -> bool: + return bool(self._archive) + + def archived_details( + self, surgery_id: str + ) -> list[SurgeryConsumptionStored] | None: + """供 API 回退查询:读取内存归档,不访问 DB。""" + entry = self._archive.get(surgery_id) + if entry is None: + return None + return list(entry.details) + + async def take_archived_details( + self, surgery_id: str + ) -> list[SurgeryConsumptionStored] | None: + """弹出归档(用于同一手术号重新开始前的强制落库 / 移交)。""" + async with self._lock: + entry = self._archive.pop(surgery_id, None) + if entry is None: + return None + return list(entry.details) + + async def restore(self, surgery_id: str, details: list[SurgeryConsumptionStored]) -> None: + """把此前弹出的归档重新放回(比如「强制落库」再次失败时回退)。""" + async with self._lock: + self._archive[surgery_id] = _ArchiveEntry(details=list(details)) + + async def persist_or_archive( + self, + surgery_id: str, + details: list[SurgeryConsumptionStored], + ) -> bool: + """尝试立即写库;失败则放入内存归档,并按配置写入 durable fallback。""" + if await self._write_to_db(surgery_id, details): + return True + entry = _ArchiveEntry(details=list(details)) + if self._s.archive_persist_durable_fallback_enabled: + entry.durable_path = self._write_durable(surgery_id, details) + async with self._lock: + self._archive[surgery_id] = entry + logger.error( + "Surgery {} final result kept in memory archive (durable={}); " + "background retry will attempt persist", + surgery_id, + bool(entry.durable_path), + ) + return False + + async def try_persist_archive(self, surgery_id: str) -> bool: + """尝试把一条内存归档写入数据库;成功则清理内存及 durable 文件。""" + async with self._lock: + entry = self._archive.get(surgery_id) + if entry is None: + return True + if self._repo is None: + return False + ok = await self._write_to_db(surgery_id, entry.details) + if not ok: + entry.attempts += 1 + return False + async with self._lock: + removed = self._archive.pop(surgery_id, None) + if removed is not None and removed.durable_path is not None: + self._safe_remove(removed.durable_path) + logger.info("Archive persisted after retry surgery_id={}", surgery_id) + return True + + async def start_retry_loop(self) -> None: + if self._retry_task is not None and not self._retry_task.done(): + return + self._retry_stop.clear() + self._retry_task = asyncio.create_task( + self._retry_loop(), + name="archive_persist_retry", + ) + + async def shutdown(self) -> None: + self._retry_stop.set() + if self._retry_task is not None: + self._retry_task.cancel() + try: + await self._retry_task + except asyncio.CancelledError: + pass + except Exception as exc: + logger.debug("archive retry shutdown: {}", exc) + self._retry_task = None + + async def recover_from_durable_fallback(self) -> int: + """进程启动时调用:从 durable 目录把未写库的归档读回内存。""" + if not self._s.archive_persist_durable_fallback_enabled: + return 0 + directory = Path(self._s.archive_persist_durable_fallback_dir) + if not directory.exists(): + return 0 + loaded = 0 + for path in sorted(directory.glob("*.json")): + try: + raw = json.loads(path.read_text(encoding="utf-8")) + surgery_id = str(raw["surgery_id"]) + details = _deserialize_details(list(raw.get("details") or [])) + except Exception as exc: + logger.warning("Skip unreadable durable archive {}: {}", path, exc) + continue + async with self._lock: + if surgery_id in self._archive: + continue + self._archive[surgery_id] = _ArchiveEntry( + details=details, + durable_path=path, + ) + loaded += 1 + if loaded: + logger.warning( + "Recovered {} durable archive(s) from {}; retry loop will attempt persist", + loaded, + directory, + ) + return loaded + + async def _write_to_db( + self, + surgery_id: str, + details: list[SurgeryConsumptionStored], + ) -> bool: + if self._repo is None: + return True + try: + async with self._session_factory() as session: + async with session.begin(): + await self._repo.save_final_result( + session, + surgery_id=surgery_id, + details=list(details), + ) + except Exception as exc: + logger.warning( + "Persist surgery {} failed (will archive/retry): {}", surgery_id, exc + ) + return False + return True + + def _write_durable( + self, + surgery_id: str, + details: list[SurgeryConsumptionStored], + ) -> Path | None: + directory = Path(self._s.archive_persist_durable_fallback_dir) + try: + directory.mkdir(parents=True, exist_ok=True) + except Exception as exc: + logger.warning("mkdir durable archive dir {} failed: {}", directory, exc) + return None + path = directory / f"{surgery_id}.json" + payload = { + "surgery_id": surgery_id, + "saved_at": datetime.now(timezone.utc).isoformat(), + "details": _serialize_details(details), + } + try: + tmp = path.with_suffix(".json.tmp") + tmp.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp, path) + return path + except Exception as exc: + logger.warning("write durable archive {} failed: {}", path, exc) + return None + + def _safe_remove(self, path: Path) -> None: + try: + path.unlink(missing_ok=True) + except Exception as exc: + logger.debug("remove durable archive {} failed: {}", path, exc) + + def _next_backoff_seconds(self, attempts: int) -> float: + base = float(self._s.archive_persist_retry_interval_seconds) + cap = float(self._s.archive_persist_backoff_cap_seconds) + # 指数退避:base * 2^(attempts-1),首个间隔即 base。 + exp = max(0, attempts - 1) + return min(cap, base * (2**exp)) + + async def _retry_loop(self) -> None: + base = float(self._s.archive_persist_retry_interval_seconds) + max_attempts = int(self._s.archive_persist_max_retries) + while not self._retry_stop.is_set(): + try: + await asyncio.wait_for(self._retry_stop.wait(), timeout=base) + break + except TimeoutError: + pass + + loop = asyncio.get_running_loop() + now = loop.time() + # 快照当前归档条目;后续尝试可能改变 _archive 内部状态。 + async with self._lock: + entries = [(sid, ent) for sid, ent in self._archive.items()] + + for surgery_id, entry in entries: + if self._retry_stop.is_set(): + break + if entry.attempts >= max_attempts: + # 达到上限,放弃自动重试,等待进程重启或人工介入。 + continue + if entry.next_attempt_monotonic > now: + continue + ok = await self.try_persist_archive(surgery_id) + if not ok: + # 失败:更新退避时间 + async with self._lock: + current = self._archive.get(surgery_id) + if current is not None: + current.next_attempt_monotonic = now + self._next_backoff_seconds( + current.attempts + ) + if current.attempts >= max_attempts: + logger.error( + "Archive persist exhausted retries surgery_id={} " + "attempts={}; durable={} kept for manual recovery", + surgery_id, + current.attempts, + bool(current.durable_path), + ) diff --git a/app/services/video/classification_handler.py b/app/services/video/classification_handler.py new file mode 100644 index 0000000..0d507ba --- /dev/null +++ b/app/services/video/classification_handler.py @@ -0,0 +1,211 @@ +"""视觉分类结果处理:把 ``PredictionResult`` 转成自动记账 or 待人工确认。 + +从 ``CameraSessionManager`` 抽出,保持原先行为: +- 置信度低于 ``video_voice_confirm_min_confidence`` → 丢弃。 +- 会话状态中候选清单为空 → 丢弃(开录时通常会由空请求解析为全量目录/模型类名)。 +- 置信度 ≥ ``video_auto_confirm_confidence`` 且 Top1 在候选内 → 自动追加 vision 明细,并写消耗 TSV(记具体耗材)。 +- 置信度 ≥ 自动阈值但 Top1 不在候选内 → 视 voice_confirmation_enabled 入 pending。 +- 中等置信度 → 入 pending(若有可展示候选项)。 + +需医生确认时:消耗 TSV / 内存明细记「待确认」(不写模型 top1 商品名);语音确认后再落最终耗材并更新汇总。 +""" + +from __future__ import annotations + +from loguru import logger + +from app.config import Settings +from app.services.consumable_vision_algorithm import ( + PredictionCandidate, + PredictionResult, +) +from app.services.consumption_tsv_log import ( + append_consumption_pending_window, + append_consumption_window, + resolve_consumption_item_id, +) +from app.services.video.inference_aggregator import WindowInferenceReady +from app.services.video.session_registry import ( + SurgerySessionRegistry, + SurgerySessionState, +) + + +def rank_topk_for_candidates( + topk: list[PredictionCandidate], + ordered_candidates: list[str], + *, + limit: int = 5, +) -> list[PredictionCandidate]: + if not topk: + return [] + stripped_order = [c.strip() for c in ordered_candidates if c.strip()] + if not stripped_order: + return topk[:limit] + order_index = {name: i for i, name in enumerate(stripped_order)} + picked = [c for c in topk if c.label.strip() in order_index] + picked.sort(key=lambda c: order_index[c.label.strip()]) + return picked[:limit] + + +class VisionClassificationHandler: + """把分类结果转化为 registry 上的状态变更(追加明细 / 入队待确认)。""" + + def __init__( + self, + *, + settings: Settings, + registry: SurgerySessionRegistry, + ) -> None: + self._s = settings + self._registry = registry + + def _append_vision_consumption_window_if_ready( + self, + state: SurgerySessionState, + ready: WindowInferenceReady | None, + surgery_id: str, + camera_id: str, + ) -> None: + if ( + ready is None + or not surgery_id + or not camera_id + or ( + not self._s.consumption_tsv_log_enabled + and not self._s.consumption_log_markdown_terminal + ) + ): + return + append_consumption_window( + surgery_id=surgery_id, + name_to_code=state.name_to_code, + best=ready.best, + doctor_id=self._s.video_result_doctor_id, + camera_id=camera_id, + wall_start_epoch=ready.wall_lo, + wall_end_epoch=ready.wall_hi, + running_totals=state.consumption_log_totals, + ) + + async def handle( + self, + *, + state: SurgerySessionState, + cls_res: PredictionResult, + ready: WindowInferenceReady | None = None, + surgery_id: str = "", + camera_id: str = "", + ) -> None: + conf = cls_res.confidence + label = (cls_res.label or "").strip() + item_id = resolve_consumption_item_id(label, "", state.name_to_code) + voice_floor = self._s.video_voice_confirm_min_confidence + if conf < voice_floor: + return + + cand_order = [c.strip() for c in state.candidate_consumables if c.strip()] + if not cand_order: + return + + cand_set = set(cand_order) + ranked = rank_topk_for_candidates(cls_res.topk, cand_order) + auto_th = self._s.video_auto_confirm_confidence + + def in_allowed(name: str) -> bool: + return name in cand_set + + if conf >= auto_th and in_allowed(label): + self._append_vision_consumption_window_if_ready( + state, ready, surgery_id, camera_id + ) + await self._registry.append_confirmed_detail( + state=state, + item_id=item_id or "unknown", + item_name=label or "unknown", + doctor_id=self._s.video_result_doctor_id, + source="vision", + ) + return + + if conf >= auto_th and not in_allowed(label): + if ranked and self._s.voice_confirmation_enabled: + await self._enqueue( + state, + ranked, + label, + conf, + ready=ready, + surgery_id=surgery_id, + camera_id=camera_id, + ) + return + + if not self._s.voice_confirmation_enabled: + return + + if ranked: + await self._enqueue( + state, + ranked, + label, + conf, + ready=ready, + surgery_id=surgery_id, + camera_id=camera_id, + ) + elif in_allowed(label): + await self._enqueue( + state, + [PredictionCandidate(label=label, confidence=conf)], + label, + conf, + ready=ready, + surgery_id=surgery_id, + camera_id=camera_id, + ) + + async def _enqueue( + self, + state: SurgerySessionState, + ranked: list[PredictionCandidate], + top_key: str, + top_confidence: float, + *, + ready: WindowInferenceReady | None = None, + surgery_id: str = "", + camera_id: str = "", + ) -> None: + cid = await self._registry.enqueue_pending_confirmation( + state, + ranked, + top_key=top_key, + top_confidence=top_confidence, + ) + if cid is None: + return + logger.info( + "Enqueued pending consumable confirmation id={} top_key={}", + cid, + top_key, + ) + if ready is not None and surgery_id and camera_id and ( + self._s.consumption_tsv_log_enabled + or self._s.consumption_log_markdown_terminal + ): + append_consumption_pending_window( + surgery_id=surgery_id, + confirmation_id=cid, + model_snap=ready.best, + doctor_id=self._s.video_result_doctor_id, + camera_id=camera_id, + wall_start_epoch=ready.wall_lo, + wall_end_epoch=ready.wall_hi, + tsv_enabled=self._s.consumption_tsv_log_enabled, + markdown_terminal=self._s.consumption_log_markdown_terminal, + ) + await self._registry.append_pending_consumption_detail( + state=state, + confirmation_id=cid, + doctor_id=self._s.video_result_doctor_id, + ) diff --git a/app/services/video/inference_aggregator.py b/app/services/video/inference_aggregator.py new file mode 100644 index 0000000..be1326c --- /dev/null +++ b/app/services/video/inference_aggregator.py @@ -0,0 +1,93 @@ +"""时间窗聚合:按 ``consumable_vision_window_sec`` 桶内众数投票,产出 ``WindowInferenceReady``。 + +从 ``CameraSessionManager._camera_worker`` 的时间窗计票逻辑独立出来,便于单测。 +消耗 TSV / 终端 Markdown 在 ``VisionClassificationHandler`` 中按「自动确认 / 待确认」分支写入, +避免待确认事件在日志中先记成具体耗材名。 +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass + +from app.config import Settings +from app.services.consumable_vision_algorithm import ( + ClsTop3, + PredictionResult, + cls_top3_to_prediction_result, + window_bucket_to_best_snap, +) +from app.services.video.session_registry import ( + CameraStreamInferState, + SurgerySessionState, +) + + +@dataclass(frozen=True) +class WindowInferenceReady: + """单个已完成时间窗:原始 top3 快照 + 分类结果 + 墙钟区间(与 monotonic 窗对齐)。""" + + best: ClsTop3 + prediction: PredictionResult + wall_lo: float + wall_hi: float + + +class WindowInferenceAggregator: + """负责把单路相机的推理快照按时间窗分桶,并产出「桶内最佳」结果。 + + 本类无状态:状态保存在 ``SurgerySessionState.camera_infer`` 中, + 便于与原逻辑保持一致;调用方在持有 ``state.lock`` 时调用下面的方法。 + """ + + def __init__(self, *, settings: Settings) -> None: + self._s = settings + + def ingest_snapshot_and_collect_ready( + self, + *, + surgery_id: str, + camera_id: str, + snap: ClsTop3, + state: SurgerySessionState, + ) -> list[WindowInferenceReady]: + """摄入一条推理快照,返回本次因桶满而产出的窗口列表。 + + 调用方必须已持有 ``state.lock``。 + """ + _ = surgery_id + _ = camera_id + wsec = self._s.consumable_vision_window_sec + ready: list[WindowInferenceReady] = [] + cis = state.camera_infer.setdefault(camera_id, CameraStreamInferState()) + if cis.stream_t0 is None: + cis.stream_t0 = time.monotonic() + cis.stream_wall_start = time.time() + t_rel = time.monotonic() - cis.stream_t0 + cis.votes.append((t_rel, snap.t1_name, snap)) + current_b = int(t_rel // wsec) + while cis.next_bucket < current_b: + b = cis.next_bucket + cis.next_bucket += 1 + lo, hi = b * wsec, (b + 1) * wsec + bucket_pts = [(p, sn) for (t, p, sn) in cis.votes if lo <= t < hi] + cis.votes = [ + (t, p, sn) for (t, p, sn) in cis.votes if not (lo <= t < hi) + ] + if not bucket_pts: + continue + best = window_bucket_to_best_snap(bucket_pts) + if best is None or cis.stream_wall_start is None: + continue + wall_lo = cis.stream_wall_start + lo + wall_hi = cis.stream_wall_start + hi + pred = cls_top3_to_prediction_result(best) + ready.append( + WindowInferenceReady( + best=best, + prediction=pred, + wall_lo=wall_lo, + wall_hi=wall_hi, + ) + ) + return ready diff --git a/app/services/video/session_manager.py b/app/services/video/session_manager.py index 50ca0d3..573dd69 100644 --- a/app/services/video/session_manager.py +++ b/app/services/video/session_manager.py @@ -2,120 +2,61 @@ from __future__ import annotations import asyncio import time -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Literal from loguru import logger from app.config import Settings +from sqlalchemy.ext.asyncio import async_sessionmaker + from app.database import AsyncSessionLocal +from app.domain.consumption import SurgeryConsumptionStored from app.repositories.surgery_results import SurgeryResultRepository -from app.schemas import SurgeryConsumptionDetail, SurgeryConsumptionStored from app.services.consumable_vision_algorithm import ( - ClsTop3, ConsumableVisionAlgorithmService, - PredictionCandidate, PredictionResult, - _norm_product_name, - cls_top3_to_prediction_result, - window_bucket_to_best_snap, ) +from app.services.video.archive_persister import ArchivePersister from app.services.video.backend_resolver import BackendResolver +from app.services.video.classification_handler import VisionClassificationHandler from app.services.video.hikvision_runtime import HikvisionInitRefCount, HikvisionRuntime -from app.services.video.rtsp_capture import RtspCapture +from app.services.video.inference_aggregator import WindowInferenceAggregator +from app.services.video.session_registry import ( + CameraStreamInferState, + PendingConsumableConfirmation, + RunningSurgery, + SurgerySessionRegistry, + SurgerySessionState, +) +from app.services.video.stream_worker import CameraStreamWorker, redact_rtsp_url from app.services.video.types import VideoBackendKind from app.services.consumption_tsv_log import ( append_consumption_log_summary, - append_consumption_window, init_consumption_log_file, print_consumption_summary_markdown, ) from app.services.voice_file_log import init_voice_log_file -from app.services.voice_confirm import build_prompt_text from app.surgery_errors import SurgeryPipelineError -@dataclass -class PendingConsumableConfirmation: - """待客户端确认的一条低置信度识别(不阻塞后续帧推理)。""" - - id: str - status: Literal["pending", "confirmed", "rejected"] - options: list[tuple[str, float]] - prompt_text: str - created_at: datetime - model_top1_label: str - model_top1_confidence: float - #: 本轮待确认在解析失败时累计次数(首败 + 重试),供 API 计算 retry_remaining。 - voice_parse_failures: int = 0 - - -@dataclass -class CameraStreamInferState: - """单路视频上的时间窗投票(与离线算法一致)。""" - - votes: list[tuple[float, str, ClsTop3]] = field(default_factory=list) - stream_t0: float | None = None - #: 与 `stream_t0` 同一次初始化时的 `time.time()`,与 monotonic 流逝秒相加得到墙钟时间戳 - stream_wall_start: float | None = None - next_bucket: int = 0 - - -@dataclass -class SurgerySessionState: - candidate_consumables: list[str] - #: 分类类名(归一化) -> 业务物品 id(Excel 产品编码或名称)。 - name_to_code: dict[str, str] = field(default_factory=dict) - camera_infer: dict[str, CameraStreamInferState] = field(default_factory=dict) - details: list[SurgeryConsumptionStored] = field(default_factory=list) - lock: asyncio.Lock = field(default_factory=asyncio.Lock) - ready: asyncio.Event = field(default_factory=asyncio.Event) - last_detail_monotonic: dict[str, float] = field(default_factory=dict) - #: 仅含 status=pending 的确认任务 id,FIFO。 - pending_fifo: list[str] = field(default_factory=list) - pending_by_id: dict[str, PendingConsumableConfirmation] = field(default_factory=dict) - last_pending_prompt_snippet: str | None = None - #: 最近一次语音确认 ASR 文本(成功识别时写入)。 - last_asr_text: str | None = None - #: 最近一次语音确认错误说明(ASR/解析失败等)。 - last_voice_error: str | None = None - #: 视觉时间窗落盘用量累计,供停录时写汇总(item_id -> 首次名称, 次数)。 - consumption_log_totals: dict[str, tuple[str, int]] = field(default_factory=dict) - - -@dataclass -class RunningSurgery: - stop_event: asyncio.Event - state: SurgerySessionState - tasks: list[asyncio.Task[None]] - - -@dataclass -class ArchivedSurgery: - details: list[SurgeryConsumptionStored] - - -def _rank_topk_for_candidates( - topk: list[PredictionCandidate], - ordered_candidates: list[str], - *, - limit: int = 5, -) -> list[PredictionCandidate]: - if not topk: - return [] - stripped_order = [c.strip() for c in ordered_candidates if c.strip()] - if not stripped_order: - return topk[:limit] - order_index = {name: i for i, name in enumerate(stripped_order)} - picked = [c for c in topk if c.label.strip() in order_index] - picked.sort(key=lambda c: order_index[c.label.strip()]) - return picked[:limit] +__all__ = [ + "CameraSessionManager", + "CameraStreamInferState", + "PendingConsumableConfirmation", + "RunningSurgery", + "SurgerySessionState", +] class CameraSessionManager: - """Per-surgery camera streams, RTSP + optional Hikvision SDK login, inference, client-side human confirm.""" + """Per-surgery camera orchestration. + + 本类负责: + 1. 开始/停止手术:创建 `SurgerySessionState`、拉起相机 worker、停录时收尾。 + 2. 把「语音确认所需的内存态」委托给 ``SurgerySessionRegistry``(实现 `PendingConfirmationStore`)。 + 3. 把「结果写库 + 失败重试 + durable fallback」委托给 ``ArchivePersister``。 + + 对外接口保持不变,上游(``SurgeryPipeline`` / ``VoiceConfirmationService``)无需感知拆分。 + """ def __init__( self, @@ -124,131 +65,92 @@ class CameraSessionManager: vision_algorithm: ConsumableVisionAlgorithmService, hikvision_runtime: HikvisionRuntime | None, result_repository: SurgeryResultRepository | None = None, + session_factory: async_sessionmaker | None = None, + registry: SurgerySessionRegistry | None = None, + archive_persister: ArchivePersister | None = None, ) -> None: self._s = settings self._vision = vision_algorithm self._hik = hikvision_runtime - self._repo = result_repository + self._session_factory: async_sessionmaker = session_factory or AsyncSessionLocal self._resolver = BackendResolver(settings, hikvision_runtime=hikvision_runtime) - self._active: dict[str, RunningSurgery] = {} - self._archive: dict[str, ArchivedSurgery] = {} - self._manager_lock = asyncio.Lock() - self._retry_task: asyncio.Task[None] | None = None - self._retry_stop = asyncio.Event() - - async def start_archive_retry_loop(self) -> None: - if self._retry_task is not None and not self._retry_task.done(): - return - self._retry_stop.clear() - self._retry_task = asyncio.create_task( - self._archive_persist_retry_loop(), - name="archive_persist_retry", + self._registry = registry or SurgerySessionRegistry(settings=settings) + self._archive = archive_persister or ArchivePersister( + settings=settings, + repository=result_repository, + session_factory=self._session_factory, + ) + self._aggregator = WindowInferenceAggregator(settings=settings) + self._classifier_handler = VisionClassificationHandler( + settings=settings, + registry=self._registry, ) + # ------------------------------------------------------------------ + # 生命周期 + # ------------------------------------------------------------------ + async def start_archive_retry_loop(self) -> None: + await self._archive.recover_from_durable_fallback() + await self._archive.start_retry_loop() + async def shutdown(self) -> None: - self._retry_stop.set() - if self._retry_task is not None: - self._retry_task.cancel() - try: - await self._retry_task - except asyncio.CancelledError: - pass - except Exception as exc: - logger.debug("retry task shutdown: {}", exc) - self._retry_task = None - async with self._manager_lock: - ids = list(self._active.keys()) + await self._archive.shutdown() + ids = self._registry.active_ids() for sid in ids: try: await self.stop_surgery(sid, require_active=False) except Exception as exc: logger.warning("shutdown stop_surgery {}: {}", sid, exc) - async def _archive_persist_retry_loop(self) -> None: - while not self._retry_stop.is_set(): - try: - await asyncio.wait_for( - self._retry_stop.wait(), - timeout=self._s.archive_persist_retry_interval_seconds, - ) - break - except TimeoutError: - pass - ids = list(self._archive.keys()) - for sid in ids: - if self._retry_stop.is_set(): - break - await self._try_persist_archive(sid) - - async def _try_persist_archive(self, surgery_id: str) -> bool: - if self._repo is None: - return False - async with self._manager_lock: - arch = self._archive.get(surgery_id) - if arch is None: - return True - try: - async with AsyncSessionLocal() as session: - async with session.begin(): - await self._repo.save_final_result( - session, - surgery_id=surgery_id, - details=list(arch.details), - ) - except Exception as exc: - logger.warning( - "Archive persist retry failed surgery_id={}: {}", - surgery_id, - exc, - ) - return False - async with self._manager_lock: - self._archive.pop(surgery_id, None) - logger.info("Archive persisted after retry surgery_id={}", surgery_id) - return True - + # ------------------------------------------------------------------ + # Surgery start / stop + # ------------------------------------------------------------------ async def start_surgery( self, surgery_id: str, camera_ids: list[str], candidate_consumables: list[str], ) -> None: - stale_archive: ArchivedSurgery | None = None - async with self._manager_lock: - if surgery_id in self._active: - raise SurgeryPipelineError( - "RECORDING_CANNOT_START", - "该手术已在录制中,请勿重复开始。", - ) - if surgery_id in self._archive: - logger.warning( - "surgery_id={} 仍有未落库归档,尝试写入数据库后再开始新会话", - surgery_id, - ) - stale_archive = self._archive.pop(surgery_id) - - if stale_archive is not None: - if self._repo is None: + if self._registry.has_active(surgery_id): + raise SurgeryPipelineError( + "RECORDING_CANNOT_START", + "该手术已在录制中,请勿重复开始。", + ) + stale = await self._archive.take_archived_details(surgery_id) + if stale is not None: + logger.warning( + "surgery_id={} 仍有未落库归档,尝试写入数据库后再开始新会话", + surgery_id, + ) + if self._archive.repository is None: logger.error( "surgery_id={} 有内存归档但未配置数据库仓库,无法持久化;" "开始新会话将丢弃该归档(仅开发/无库模式)", surgery_id, ) else: - ok = await self._persist_archived_details( - surgery_id, list(stale_archive.details) - ) + ok = await self._archive.persist_or_archive(surgery_id, stale) if not ok: - async with self._manager_lock: - self._archive[surgery_id] = stale_archive raise SurgeryPipelineError( "RECORDING_CANNOT_START", "该手术号存在尚未写入数据库的历史结果,请修复数据库或等待自动重试成功后再开始。", ) - name_to_code = self._vision.build_name_mapping(candidate_consumables) + resolved = self._vision.effective_candidate_consumables(candidate_consumables) + if not resolved: + raise SurgeryPipelineError( + "RECORDING_CANNOT_START", + "耗材候选为空:请在请求中传入 candidate_consumables,或配置耗材目录 Excel / 分类模型。", + ) + if not any(str(x).strip() for x in candidate_consumables): + logger.info( + "surgery {}: candidate_consumables 未提供,使用默认全量 {} 项", + surgery_id, + len(resolved), + ) + name_to_code = self._vision.build_name_mapping(resolved) state = SurgerySessionState( - candidate_consumables=list(candidate_consumables), + candidate_consumables=list(resolved), name_to_code=name_to_code, ) stop_event = asyncio.Event() @@ -273,8 +175,7 @@ class CameraSessionManager: run = RunningSurgery(stop_event=stop_event, state=state, tasks=tasks) init_consumption_log_file(surgery_id) init_voice_log_file(surgery_id, self._s) - async with self._manager_lock: - self._active[surgery_id] = run + await self._registry.register(surgery_id, run) try: await asyncio.wait_for( @@ -297,33 +198,8 @@ class CameraSessionManager: await self.stop_surgery(surgery_id, require_active=True) raise - async def _persist_archived_details( - self, - surgery_id: str, - details: list[SurgeryConsumptionStored], - ) -> bool: - if self._repo is None: - return True - try: - async with AsyncSessionLocal() as session: - async with session.begin(): - await self._repo.save_final_result( - session, - surgery_id=surgery_id, - details=details, - ) - except Exception as exc: - logger.exception( - "Persist archived surgery {} failed (will keep archive): {}", - surgery_id, - exc, - ) - return False - return True - async def stop_surgery(self, surgery_id: str, *, require_active: bool = True) -> None: - async with self._manager_lock: - run = self._active.pop(surgery_id, None) + run = await self._registry.unregister(surgery_id) if run is None: if require_active: raise SurgeryPipelineError( @@ -343,45 +219,20 @@ class CameraSessionManager: print_consumption_summary_markdown(totals) details = list(run.state.details) + await self._archive.persist_or_archive(surgery_id, details) - persisted = False - if self._repo is not None: - try: - async with AsyncSessionLocal() as session: - async with session.begin(): - await self._repo.save_final_result( - session, - surgery_id=surgery_id, - details=details, - ) - persisted = True - except Exception as exc: - logger.exception("Persist surgery {} failed: {}", surgery_id, exc) + # ------------------------------------------------------------------ + # PendingConfirmationStore 协议委托 + # ------------------------------------------------------------------ + def live_consumption_if_active( + self, surgery_id: str + ) -> list[SurgeryConsumptionStored] | None: + return self._registry.live_consumption_if_active(surgery_id) - async with self._manager_lock: - if not persisted: - self._archive[surgery_id] = ArchivedSurgery(details=details) - logger.error( - "Surgery {} final result kept in memory archive only; " - "background retry will attempt persist", - surgery_id, - ) - - def live_consumption_if_active(self, surgery_id: str) -> list[SurgeryConsumptionDetail] | None: - if surgery_id not in self._active: - return None - if not self._active[surgery_id].state.ready.is_set(): - return None - rows = list(self._active[surgery_id].state.details) - if not rows: - return None - return [r.as_response() for r in rows] - - def archived_consumption_fallback(self, surgery_id: str) -> list[SurgeryConsumptionDetail] | None: - arch = self._archive.get(surgery_id) - if arch is None: - return None - return [r.as_response() for r in arch.details] + def archived_consumption_fallback( + self, surgery_id: str + ) -> list[SurgeryConsumptionStored] | None: + return self._archive.archived_details(surgery_id) def record_voice_trace( self, @@ -390,57 +241,27 @@ class CameraSessionManager: asr_text: str | None, error: str | None, ) -> None: - if surgery_id not in self._active: - return - st = self._active[surgery_id].state - st.last_asr_text = asr_text - st.last_voice_error = error + self._registry.record_voice_trace(surgery_id, asr_text=asr_text, error=error) def get_pending_confirmation_by_id( self, surgery_id: str, confirmation_id: str, ) -> PendingConsumableConfirmation | None: - if surgery_id not in self._active: - return None - p = self._active[surgery_id].state.pending_by_id.get(confirmation_id) - if p is None or p.status != "pending": - return None - return p + return self._registry.get_pending_confirmation_by_id(surgery_id, confirmation_id) def get_surgery_candidate_consumables(self, surgery_id: str) -> list[str]: - """本台手术开始手术时传入的耗材候选清单(语音可任选其中一项,不限于模型 topk)。""" - if surgery_id not in self._active: - return [] - return list(self._active[surgery_id].state.candidate_consumables) + return self._registry.get_surgery_candidate_consumables(surgery_id) async def record_voice_parse_failure( self, surgery_id: str, confirmation_id: str ) -> tuple[int, int]: - """解析失败时累加计数,返回 (当前失败次数, 距上限还剩几次「重试机会」)。""" - if surgery_id not in self._active: - return 0, 0 - st = self._active[surgery_id].state - max_r = int(self._s.voice_confirm_max_failed_parse_rounds) - async with st.lock: - p = st.pending_by_id.get(confirmation_id) - if p is None or p.status != "pending": - return 0, 0 - p.voice_parse_failures += 1 - remaining = max(0, max_r - p.voice_parse_failures) - return p.voice_parse_failures, remaining + return await self._registry.record_voice_parse_failure(surgery_id, confirmation_id) def next_pending_confirmation( self, surgery_id: str ) -> PendingConsumableConfirmation | None: - if surgery_id not in self._active: - return None - st = self._active[surgery_id].state - for cid in st.pending_fifo: - p = st.pending_by_id.get(cid) - if p is not None and p.status == "pending": - return p - return None + return self._registry.next_pending_confirmation(surgery_id) async def resolve_pending_confirmation( self, @@ -450,107 +271,16 @@ class CameraSessionManager: chosen_label: str | None, rejected: bool, ) -> None: - if surgery_id not in self._active: - raise SurgeryPipelineError( - "CONFIRMATION_NOT_ACTIVE", - "该手术当前不在进行中,无法提交确认。", - ) - st = self._active[surgery_id].state - async with st.lock: - pending = st.pending_by_id.get(confirmation_id) - if pending is None: - raise SurgeryPipelineError( - "CONFIRMATION_NOT_FOUND", - "未找到该待确认项或已处理。", - ) - if pending.status != "pending": - raise SurgeryPipelineError( - "CONFIRMATION_ALREADY_RESOLVED", - "该待确认项已处理。", - ) - if rejected and chosen_label: - raise SurgeryPipelineError( - "CONFIRMATION_INVALID", - "拒绝确认时不应同时提供 chosen_label。", - ) - if not rejected and not chosen_label: - raise SurgeryPipelineError( - "CONFIRMATION_INVALID", - "请提供 chosen_label 或设置 rejected=true。", - ) - allowed_pending = {lbl.strip() for lbl, _ in pending.options if lbl.strip()} - allowed_surgery = {c.strip() for c in st.candidate_consumables if c.strip()} - if rejected: - pending.status = "rejected" - else: - label = chosen_label.strip() if chosen_label else "" - if label not in allowed_pending and label not in allowed_surgery: - raise SurgeryPipelineError( - "CONFIRMATION_INVALID", - f"所选耗材不在本台手术候选清单或本次追问选项中:{chosen_label!r}", - ) - pending.status = "confirmed" - norm = _norm_product_name(label) - item_id = st.name_to_code.get(norm, label) - self._append_confirmed_detail_locked( - state=st, - item_id=item_id, - item_name=label, - doctor_id=self._s.video_voice_confirm_doctor_id, - source="voice", - ) - try: - idx = st.pending_fifo.index(confirmation_id) - st.pending_fifo.pop(idx) - except ValueError: - pass - st.pending_by_id.pop(confirmation_id, None) - - def _append_confirmed_detail_locked( - self, - *, - state: SurgerySessionState, - item_id: str, - item_name: str, - doctor_id: str, - source: str, - ) -> None: - """在已持有 `state.lock` 时追加一条消耗明细。""" - now_m = time.monotonic() - cooldown = self._s.video_detail_cooldown_sec - prev = state.last_detail_monotonic.get(item_id) - if prev is not None and (now_m - prev) < cooldown: - return - state.last_detail_monotonic[item_id] = now_m - state.details.append( - SurgeryConsumptionStored( - item_id=item_id, - item_name=item_name, - qty=1, - doctor_id=doctor_id, - timestamp=datetime.now(timezone.utc), - source=source, - ) + await self._registry.resolve_pending_confirmation( + surgery_id, + confirmation_id, + chosen_label=chosen_label, + rejected=rejected, ) - async def _append_confirmed_detail( - self, - *, - state: SurgerySessionState, - item_id: str, - item_name: str, - doctor_id: str, - source: str, - ) -> None: - async with state.lock: - self._append_confirmed_detail_locked( - state=state, - item_id=item_id, - item_name=item_name, - doctor_id=doctor_id, - source=source, - ) - + # ------------------------------------------------------------------ + # Camera worker(拉流 + 推理节流 + 时间窗分桶 + 分类结果处理) + # ------------------------------------------------------------------ async def _camera_worker( self, *, @@ -561,70 +291,23 @@ class CameraSessionManager: state: SurgerySessionState, ) -> None: kind = self._resolver.backend_for_camera(camera_id) - cap: RtspCapture | None = None hik_user_id: int | None = None hik_init_retained = False - url: str | None = None - consecutive_failures = 0 - first_ready = True - try: url, hik_user_id, hik_init_retained = await self._resolve_rtsp_url( camera_id=camera_id, kind=kind, ) assert url is not None + last_infer = 0.0 - while not stop_event.is_set(): - if cap is None: - try: - cap = RtspCapture(url, open_timeout_sec=self._s.video_open_timeout_sec) - await asyncio.to_thread(cap.open) - consecutive_failures = 0 - if first_ready: - stream_ready.set() - first_ready = False - logger.info( - "RTSP stream opened camera={} surgery={}", - camera_id, - surgery_id, - ) - except Exception as exc: - logger.warning( - "RTSP open failed camera={} surgery={}: {}", - camera_id, - surgery_id, - exc, - ) - if cap is not None: - await asyncio.to_thread(cap.release) - cap = None - await asyncio.sleep(self._s.video_reconnect_backoff_seconds) - continue - ok, frame = await asyncio.to_thread(cap.read) - if not ok or frame is None: - consecutive_failures += 1 - if consecutive_failures >= self._s.video_read_failure_reconnect_threshold: - logger.warning( - "RTSP reconnect camera={} surgery={} after {} read failures", - camera_id, - surgery_id, - consecutive_failures, - ) - await asyncio.to_thread(cap.release) - cap = None - consecutive_failures = 0 - await asyncio.sleep(self._s.video_reconnect_backoff_seconds) - else: - await asyncio.sleep(0.05) - continue - - consecutive_failures = 0 + async def _frame_handler(frame: object) -> None: + nonlocal last_infer now = time.monotonic() if now - last_infer < self._s.video_inference_interval_sec: await asyncio.sleep(0.01) - continue + return last_infer = now try: snap = await asyncio.to_thread( @@ -639,10 +322,10 @@ class CameraSessionManager: surgery_id, exc, ) - continue + return if snap is None: - continue + return if self._s.video_log_inference_results: logger.info( @@ -657,59 +340,35 @@ class CameraSessionManager: snap.t3_conf, ) - wsec = self._s.consumable_vision_window_sec - pending_preds: list[PredictionResult] = [] async with state.lock: - cis = state.camera_infer.setdefault( - camera_id, CameraStreamInferState() - ) - if cis.stream_t0 is None: - cis.stream_t0 = time.monotonic() - cis.stream_wall_start = time.time() - t_rel = time.monotonic() - cis.stream_t0 - cis.votes.append((t_rel, snap.t1_name, snap)) - current_b = int(t_rel // wsec) - while cis.next_bucket < current_b: - b = cis.next_bucket - cis.next_bucket += 1 - lo, hi = b * wsec, (b + 1) * wsec - bucket_pts = [ - (p, sn) for (t, p, sn) in cis.votes if lo <= t < hi - ] - cis.votes = [ - (t, p, sn) - for (t, p, sn) in cis.votes - if not (lo <= t < hi) - ] - if not bucket_pts: - continue - best = window_bucket_to_best_snap(bucket_pts) - if best is not None and cis.stream_wall_start is not None: - if self._s.consumption_tsv_log_enabled or self._s.consumption_log_markdown_terminal: - wall_lo = cis.stream_wall_start + lo - wall_hi = cis.stream_wall_start + hi - append_consumption_window( - surgery_id=surgery_id, - name_to_code=state.name_to_code, - best=best, - doctor_id=self._s.video_result_doctor_id, - camera_id=camera_id, - wall_start_epoch=wall_lo, - wall_end_epoch=wall_hi, - running_totals=state.consumption_log_totals, - ) - pending_preds.append( - cls_top3_to_prediction_result(best) - ) - - for cls_res in pending_preds: - await self._handle_classification_result( + ready_windows = self._aggregator.ingest_snapshot_and_collect_ready( + surgery_id=surgery_id, + camera_id=camera_id, + snap=snap, state=state, - cls_res=cls_res, ) + + for win in ready_windows: + await self._classifier_handler.handle( + state=state, + cls_res=win.prediction, + ready=win, + surgery_id=surgery_id, + camera_id=camera_id, + ) + + worker = CameraStreamWorker( + settings=self._s, + surgery_id=surgery_id, + camera_id=camera_id, + url=url, + ) + await worker.run( + stream_ready=stream_ready, + stop_event=stop_event, + frame_handler=_frame_handler, + ) finally: - if cap is not None: - await asyncio.to_thread(cap.release) if hik_user_id is not None and self._hik is not None: await asyncio.to_thread(self._hik.logout, hik_user_id) if hik_init_retained and self._hik is not None: @@ -721,96 +380,11 @@ class CameraSessionManager: state: SurgerySessionState, cls_res: PredictionResult, ) -> None: - conf = cls_res.confidence - label = (cls_res.label or "").strip() - item_id = state.name_to_code.get(label, label) - voice_floor = self._s.video_voice_confirm_min_confidence - if conf < voice_floor: - return + """Deprecated test-shim:沿用旧签名,转发给 ``VisionClassificationHandler``。 - cand_order = [c.strip() for c in state.candidate_consumables if c.strip()] - if not cand_order: - return - - cand_set = set(cand_order) - ranked = _rank_topk_for_candidates(cls_res.topk, cand_order) - auto_th = self._s.video_auto_confirm_confidence - - def in_allowed(name: str) -> bool: - return name in cand_set - - if conf >= auto_th and in_allowed(label): - await self._append_confirmed_detail( - state=state, - item_id=item_id or label or "unknown", - item_name=label or "unknown", - doctor_id=self._s.video_result_doctor_id, - source="vision", - ) - return - - if conf >= auto_th and not in_allowed(label): - if ranked and self._s.voice_confirmation_enabled: - await self._maybe_enqueue_pending_confirmation( - state, ranked, top_key=label, top_confidence=conf - ) - return - - if not self._s.voice_confirmation_enabled: - return - - if ranked: - await self._maybe_enqueue_pending_confirmation( - state, ranked, top_key=label, top_confidence=conf - ) - elif in_allowed(label): - await self._maybe_enqueue_pending_confirmation( - state, - [PredictionCandidate(label=label, confidence=conf)], - top_key=label, - top_confidence=conf, - ) - - async def _maybe_enqueue_pending_confirmation( - self, - state: SurgerySessionState, - ranked: list[PredictionCandidate], - *, - top_key: str, - top_confidence: float, - ) -> None: - opts = [(c.label.strip(), float(c.confidence)) for c in ranked if c.label.strip()] - if not opts: - return - now_m = time.monotonic() - cooldown = self._s.video_detail_cooldown_sec - dedupe_key = f"pending_confirm:{top_key}:{opts[0][0]}" - async with state.lock: - prev = state.last_detail_monotonic.get(dedupe_key) - if prev is not None and (now_m - prev) < cooldown: - return - state.last_detail_monotonic[dedupe_key] = now_m - - confirm_id = str(uuid.uuid4()) - prompt = build_prompt_text(opts) - pending = PendingConsumableConfirmation( - id=confirm_id, - status="pending", - options=list(opts), - prompt_text=prompt, - created_at=datetime.now(timezone.utc), - model_top1_label=top_key, - model_top1_confidence=top_confidence, - ) - state.pending_by_id[confirm_id] = pending - state.pending_fifo.append(confirm_id) - state.last_pending_prompt_snippet = prompt[:200] - - logger.info( - "Enqueued pending consumable confirmation id={} top_key={}", - confirm_id, - top_key, - ) + 保留此方法是因为单元测试直接调用了它。新代码应使用 ``self._classifier_handler.handle``。 + """ + await self._classifier_handler.handle(state=state, cls_res=cls_res) async def _resolve_rtsp_url( self, diff --git a/app/services/video/session_registry.py b/app/services/video/session_registry.py new file mode 100644 index 0000000..b5adfa8 --- /dev/null +++ b/app/services/video/session_registry.py @@ -0,0 +1,434 @@ +"""内存态的手术会话注册表。 + +该模块只管「活跃会话的共享内存状态」:候选耗材、推理投票、明细、语音待确认项、近期 +语音 trace。不知道 RTSP、数据库或持久化细节,便于 `VoiceConfirmationService` 等组件通过 +`PendingConfirmationStore` 协议依赖。 +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal + +from app.config import Settings +from app.domain.consumption import SurgeryConsumptionStored +from app.services.consumable_vision_algorithm import ( + ClsTop3, + PredictionCandidate, + _norm_product_name, +) +from app.services.consumption_tsv_log import ( + append_consumption_voice_resolution_line, + resolve_consumption_ids, + resolve_consumption_item_id, +) +from app.services.voice_confirm import build_prompt_text +from app.surgery_errors import SurgeryPipelineError + + +@dataclass +class PendingConsumableConfirmation: + """待客户端确认的一条低置信度识别(不阻塞后续帧推理)。""" + + id: str + status: Literal["pending", "confirmed", "rejected"] + options: list[tuple[str, float]] + prompt_text: str + created_at: datetime + model_top1_label: str + model_top1_confidence: float + #: 本轮待确认在解析失败时累计次数(首败 + 重试),供 API 计算 retry_remaining。 + voice_parse_failures: int = 0 + + +@dataclass +class CameraStreamInferState: + """单路视频上的时间窗投票(与离线算法一致)。""" + + votes: list[tuple[float, str, ClsTop3]] = field(default_factory=list) + stream_t0: float | None = None + #: 与 ``stream_t0`` 同一次初始化时的 ``time.time()``,与 monotonic 流逝秒相加得到墙钟时间戳 + stream_wall_start: float | None = None + next_bucket: int = 0 + + +@dataclass +class SurgerySessionState: + candidate_consumables: list[str] + #: 分类类名(归一化) -> 业务物品 id(Excel 产品编码或名称)。 + name_to_code: dict[str, str] = field(default_factory=dict) + camera_infer: dict[str, CameraStreamInferState] = field(default_factory=dict) + details: list[SurgeryConsumptionStored] = field(default_factory=list) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + ready: asyncio.Event = field(default_factory=asyncio.Event) + last_detail_monotonic: dict[str, float] = field(default_factory=dict) + #: 仅含 status=pending 的确认任务 id,FIFO。 + pending_fifo: list[str] = field(default_factory=list) + pending_by_id: dict[str, PendingConsumableConfirmation] = field(default_factory=dict) + last_pending_prompt_snippet: str | None = None + #: 最近一次语音确认 ASR 文本(成功识别时写入)。 + last_asr_text: str | None = None + #: 最近一次语音确认错误说明(ASR/解析失败等)。 + last_voice_error: str | None = None + #: 视觉时间窗落盘用量累计,供停录时写汇总(item_id -> 首次名称, 次数)。 + consumption_log_totals: dict[str, tuple[str, int]] = field(default_factory=dict) + + +@dataclass +class RunningSurgery: + stop_event: asyncio.Event + state: SurgerySessionState + tasks: list[asyncio.Task[None]] + + +class SurgerySessionRegistry: + """活跃手术会话的内存索引;实现 ``PendingConfirmationStore`` 协议。 + + 持有 ``_active`` 与 ``_manager_lock``;暴露只读查询与原子写入方法。 + 生命周期归 ``CameraSessionManager`` 负责,新增/停止会话都走本类。 + """ + + def __init__(self, *, settings: Settings) -> None: + self._s = settings + self._active: dict[str, RunningSurgery] = {} + self._manager_lock = asyncio.Lock() + + @property + def manager_lock(self) -> asyncio.Lock: + return self._manager_lock + + def has_active(self, surgery_id: str) -> bool: + return surgery_id in self._active + + def get_running(self, surgery_id: str) -> RunningSurgery | None: + return self._active.get(surgery_id) + + def active_ids(self) -> list[str]: + return list(self._active.keys()) + + async def register(self, surgery_id: str, running: RunningSurgery) -> None: + async with self._manager_lock: + self._active[surgery_id] = running + + async def unregister(self, surgery_id: str) -> RunningSurgery | None: + async with self._manager_lock: + return self._active.pop(surgery_id, None) + + def live_consumption_if_active( + self, surgery_id: str + ) -> list[SurgeryConsumptionStored] | None: + run = self._active.get(surgery_id) + if run is None: + return None + if not run.state.ready.is_set(): + return None + rows = list(run.state.details) + if not rows: + return None + return rows + + def record_voice_trace( + self, + surgery_id: str, + *, + asr_text: str | None, + error: str | None, + ) -> None: + run = self._active.get(surgery_id) + if run is None: + return + st = run.state + st.last_asr_text = asr_text + st.last_voice_error = error + + def get_pending_confirmation_by_id( + self, + surgery_id: str, + confirmation_id: str, + ) -> PendingConsumableConfirmation | None: + run = self._active.get(surgery_id) + if run is None: + return None + p = run.state.pending_by_id.get(confirmation_id) + if p is None or p.status != "pending": + return None + return p + + def get_surgery_candidate_consumables(self, surgery_id: str) -> list[str]: + run = self._active.get(surgery_id) + if run is None: + return [] + return list(run.state.candidate_consumables) + + async def record_voice_parse_failure( + self, surgery_id: str, confirmation_id: str + ) -> tuple[int, int]: + run = self._active.get(surgery_id) + if run is None: + return 0, 0 + st = run.state + max_r = int(self._s.voice_confirm_max_failed_parse_rounds) + async with st.lock: + p = st.pending_by_id.get(confirmation_id) + if p is None or p.status != "pending": + return 0, 0 + p.voice_parse_failures += 1 + remaining = max(0, max_r - p.voice_parse_failures) + return p.voice_parse_failures, remaining + + def next_pending_confirmation( + self, surgery_id: str + ) -> PendingConsumableConfirmation | None: + run = self._active.get(surgery_id) + if run is None: + return None + st = run.state + for cid in st.pending_fifo: + p = st.pending_by_id.get(cid) + if p is not None and p.status == "pending": + return p + return None + + async def resolve_pending_confirmation( + self, + surgery_id: str, + confirmation_id: str, + *, + chosen_label: str | None, + rejected: bool, + ) -> None: + run = self._active.get(surgery_id) + if run is None: + raise SurgeryPipelineError( + "CONFIRMATION_NOT_ACTIVE", + "该手术当前不在进行中,无法提交确认。", + ) + st = run.state + async with st.lock: + pending = st.pending_by_id.get(confirmation_id) + if pending is None: + raise SurgeryPipelineError( + "CONFIRMATION_NOT_FOUND", + "未找到该待确认项或已处理。", + ) + if pending.status != "pending": + raise SurgeryPipelineError( + "CONFIRMATION_ALREADY_RESOLVED", + "该待确认项已处理。", + ) + if rejected and chosen_label: + raise SurgeryPipelineError( + "CONFIRMATION_INVALID", + "拒绝确认时不应同时提供 chosen_label。", + ) + if not rejected and not chosen_label: + raise SurgeryPipelineError( + "CONFIRMATION_INVALID", + "请提供 chosen_label 或设置 rejected=true。", + ) + allowed_pending = {lbl.strip() for lbl, _ in pending.options if lbl.strip()} + allowed_surgery = {c.strip() for c in st.candidate_consumables if c.strip()} + if rejected: + pending.status = "rejected" + st.details = [ + d + for d in st.details + if d.pending_confirmation_id != confirmation_id + and d.item_id != f"pending:{confirmation_id}" + ] + else: + label = chosen_label.strip() if chosen_label else "" + if label not in allowed_pending and label not in allowed_surgery: + raise SurgeryPipelineError( + "CONFIRMATION_INVALID", + f"所选耗材不在本台手术候选清单或本次追问选项中:{chosen_label!r}", + ) + pending.status = "confirmed" + item_id = resolve_consumption_item_id(label, "", st.name_to_code) + resolved = SurgeryConsumptionStored( + item_id=item_id, + item_name=label, + qty=1, + doctor_id=self._s.video_voice_confirm_doctor_id, + timestamp=datetime.now(timezone.utc), + source="voice", + pending_confirmation_id=None, + ) + replaced = False + for i, d in enumerate(st.details): + if d.pending_confirmation_id == confirmation_id: + st.details[i] = resolved + replaced = True + break + if not replaced: + for i, d in enumerate(st.details): + if d.item_id == f"pending:{confirmation_id}": + st.details[i] = resolved + replaced = True + break + if not replaced: + self._append_confirmed_detail_locked( + state=st, + item_id=item_id, + item_name=label, + doctor_id=self._s.video_voice_confirm_doctor_id, + source="voice", + ) + self._finalize_voice_confirmed_consumption_log( + state=st, + surgery_id=surgery_id, + chosen_label=label, + ) + try: + idx = st.pending_fifo.index(confirmation_id) + st.pending_fifo.pop(idx) + except ValueError: + pass + st.pending_by_id.pop(confirmation_id, None) + + def _finalize_voice_confirmed_consumption_log( + self, + *, + state: SurgerySessionState, + surgery_id: str, + chosen_label: str, + ) -> None: + """待确认流程在语音落锤后:汇总 +1 最终耗材,并追加 TSV 正式行。""" + cl = (chosen_label or "").strip() + if not cl: + return + _, key_chosen = resolve_consumption_ids(cl, "", state.name_to_code) + tot = state.consumption_log_totals + if key_chosen not in tot: + tot[key_chosen] = (cl, 0) + nm, q = tot[key_chosen] + tot[key_chosen] = (nm, q + 1) + append_consumption_voice_resolution_line( + surgery_id=surgery_id, + name_to_code=state.name_to_code, + chosen_label=cl, + doctor_id=self._s.video_voice_confirm_doctor_id, + wall_epoch=time.time(), + tsv_enabled=self._s.consumption_tsv_log_enabled, + ) + + def _append_confirmed_detail_locked( + self, + *, + state: SurgerySessionState, + item_id: str, + item_name: str, + doctor_id: str, + source: str, + ) -> None: + """在已持有 ``state.lock`` 时追加一条消耗明细。""" + now_m = time.monotonic() + cooldown = self._s.video_detail_cooldown_sec + prev = state.last_detail_monotonic.get(item_id) + if prev is not None and (now_m - prev) < cooldown: + return + state.last_detail_monotonic[item_id] = now_m + state.details.append( + SurgeryConsumptionStored( + item_id=item_id, + item_name=item_name, + qty=1, + doctor_id=doctor_id, + timestamp=datetime.now(timezone.utc), + source=source, + pending_confirmation_id=None, + ) + ) + + def _append_pending_detail_locked( + self, + *, + state: SurgerySessionState, + confirmation_id: str, + doctor_id: str, + ) -> None: + pid = f"pending:{confirmation_id}" + state.details.append( + SurgeryConsumptionStored( + item_id=pid, + item_name="待确认", + qty=1, + doctor_id=doctor_id, + timestamp=datetime.now(timezone.utc), + source="pending_confirmation", + pending_confirmation_id=confirmation_id, + ) + ) + + async def append_pending_consumption_detail( + self, + *, + state: SurgerySessionState, + confirmation_id: str, + doctor_id: str, + ) -> None: + async with state.lock: + self._append_pending_detail_locked( + state=state, + confirmation_id=confirmation_id, + doctor_id=doctor_id, + ) + + async def append_confirmed_detail( + self, + *, + state: SurgerySessionState, + item_id: str, + item_name: str, + doctor_id: str, + source: str, + ) -> None: + async with state.lock: + self._append_confirmed_detail_locked( + state=state, + item_id=item_id, + item_name=item_name, + doctor_id=doctor_id, + source=source, + ) + + async def enqueue_pending_confirmation( + self, + state: SurgerySessionState, + ranked: list[PredictionCandidate], + *, + top_key: str, + top_confidence: float, + ) -> str | None: + """向 pending FIFO 追加一条待人工确认项;返回分配的 confirmation_id;冷却期内则返回 None。""" + opts = [(c.label.strip(), float(c.confidence)) for c in ranked if c.label.strip()] + if not opts: + return None + now_m = time.monotonic() + cooldown = self._s.video_detail_cooldown_sec + dedupe_key = f"pending_confirm:{top_key}:{opts[0][0]}" + async with state.lock: + prev = state.last_detail_monotonic.get(dedupe_key) + if prev is not None and (now_m - prev) < cooldown: + return None + state.last_detail_monotonic[dedupe_key] = now_m + + confirm_id = str(uuid.uuid4()) + prompt = build_prompt_text(opts) + pending = PendingConsumableConfirmation( + id=confirm_id, + status="pending", + options=list(opts), + prompt_text=prompt, + created_at=datetime.now(timezone.utc), + model_top1_label=top_key, + model_top1_confidence=top_confidence, + ) + state.pending_by_id[confirm_id] = pending + state.pending_fifo.append(confirm_id) + state.last_pending_prompt_snippet = prompt[:200] + return confirm_id diff --git a/app/services/video/stream_worker.py b/app/services/video/stream_worker.py new file mode 100644 index 0000000..30586de --- /dev/null +++ b/app/services/video/stream_worker.py @@ -0,0 +1,121 @@ +"""单路 RTSP 拉流 worker:负责打开、重连、读帧分发。 + +从 ``CameraSessionManager._camera_worker`` 抽出,保持同样的行为: +- 打开失败 → 退避 → 重试。 +- 连续读帧失败达到阈值 → 释放连接 → 退避 → 重试。 +- 读到可用帧后交给上游 ``frame_handler``,由其决定是否推理 / 跳帧。 + +不知道手术会话、推理结果或数据库。日志中出现 RTSP URL 时会脱敏 user:password。 +""" + +from __future__ import annotations + +import asyncio +import re +from typing import Awaitable, Callable + +from loguru import logger + +from app.config import Settings +from app.services.video.rtsp_capture import RtspCapture + + +FrameHandler = Callable[[object], Awaitable[None]] + +_RTSP_CRED_RE = re.compile(r"(?Prtsp://)(?P[^@/\s]+@)") + + +def redact_rtsp_url(url: str | None) -> str: + """把 ``rtsp://user:pwd@host/...`` 脱敏为 ``rtsp://***@host/...``。""" + if not url: + return "" + return _RTSP_CRED_RE.sub(r"\g***@", url) + + +class CameraStreamWorker: + """以 async 循环封装单路 RTSP 的重连/读帧,交由 handler 处理帧。""" + + def __init__( + self, + *, + settings: Settings, + surgery_id: str, + camera_id: str, + url: str, + ) -> None: + self._s = settings + self._surgery_id = surgery_id + self._camera_id = camera_id + self._url = url + + async def run( + self, + *, + stream_ready: asyncio.Event, + stop_event: asyncio.Event, + frame_handler: FrameHandler, + ) -> None: + cap: RtspCapture | None = None + consecutive_failures = 0 + first_ready = True + safe_url = redact_rtsp_url(self._url) + + try: + while not stop_event.is_set(): + if cap is None: + try: + cap = RtspCapture( + self._url, open_timeout_sec=self._s.video_open_timeout_sec + ) + await asyncio.to_thread(cap.open) + consecutive_failures = 0 + if first_ready: + stream_ready.set() + first_ready = False + logger.info( + "RTSP stream opened camera={} surgery={} url={}", + self._camera_id, + self._surgery_id, + safe_url, + ) + except Exception as exc: + logger.warning( + "RTSP open failed camera={} surgery={} url={}: {}", + self._camera_id, + self._surgery_id, + safe_url, + exc, + ) + if cap is not None: + await asyncio.to_thread(cap.release) + cap = None + await asyncio.sleep(self._s.video_reconnect_backoff_seconds) + continue + + ok, frame = await asyncio.to_thread(cap.read) + if not ok or frame is None: + consecutive_failures += 1 + if ( + consecutive_failures + >= self._s.video_read_failure_reconnect_threshold + ): + logger.warning( + "RTSP reconnect camera={} surgery={} url={} after {} read failures", + self._camera_id, + self._surgery_id, + safe_url, + consecutive_failures, + ) + await asyncio.to_thread(cap.release) + cap = None + consecutive_failures = 0 + await asyncio.sleep(self._s.video_reconnect_backoff_seconds) + else: + await asyncio.sleep(0.05) + continue + + consecutive_failures = 0 + await frame_handler(frame) + finally: + if cap is not None: + await asyncio.to_thread(cap.release) diff --git a/app/services/voice_audit_emitter.py b/app/services/voice_audit_emitter.py new file mode 100644 index 0000000..8416b6d --- /dev/null +++ b/app/services/voice_audit_emitter.py @@ -0,0 +1,166 @@ +"""统一语音确认的「审计 + trace + 抛错」三段式。 + +`VoiceConfirmationService` 过去在 `resolve_from_wav` / `resolve_from_recognized_text` 各分支 +中重复执行 `_persist_audit + record_voice_trace + emit_voice_event + raise +SurgeryPipelineError` 三件套,本类把它们聚合成一个方法,便于线性化主流程。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from loguru import logger +from sqlalchemy.ext.asyncio import async_sessionmaker + +from app.config import Settings +from app.repositories.voice_audits import VoiceAuditRepository +from app.services.voice_file_log import emit_voice_event +from app.surgery_errors import SurgeryPipelineError + +VoiceSource = Literal["wav", "text", "n/a"] + + +@dataclass(frozen=True) +class VoiceAuditContext: + """审计所需的「音频侧」上下文快照。""" + + audio_object_key: str | None = None + audio_content_type: str | None = None + audio_size_bytes: int | None = None + audio_sha256: str | None = None + + +class VoiceAuditEmitter: + def __init__( + self, + *, + settings: Settings, + audits: VoiceAuditRepository, + session_factory: async_sessionmaker, + ) -> None: + self._s = settings + self._audits = audits + self._session_factory = session_factory + + async def _persist_audit( + self, + *, + surgery_id: str, + confirmation_id: str, + status: str, + ctx: VoiceAuditContext, + asr_text: str | None, + resolved_label: str | None, + options_snapshot_json: str | None, + error_message: str | None, + ) -> None: + try: + async with self._session_factory() as session: + async with session.begin(): + await self._audits.save_audit( + session, + surgery_id=surgery_id, + confirmation_id=confirmation_id, + status=status, + audio_object_key=ctx.audio_object_key, + audio_content_type=ctx.audio_content_type, + audio_size_bytes=ctx.audio_size_bytes, + audio_sha256=ctx.audio_sha256, + asr_text=asr_text, + resolved_label=resolved_label, + options_snapshot_json=options_snapshot_json, + error_message=error_message, + ) + except Exception as exc: + logger.error("Persist voice audit failed: {}", exc) + + async def fail( + self, + *, + source: VoiceSource, + status: str, + code: str, + message: str, + surgery_id: str, + confirmation_id: str, + ctx: VoiceAuditContext | None = None, + asr_text: str | None = None, + options_snapshot_json: str | None = None, + record_session_trace: bool = True, + session_trace_recorder=None, # Callable[[str | None, str | None], None] + include_extra: dict[str, object] | None = None, + persist_audit: bool = True, + emit_trace: bool = True, + ) -> SurgeryPipelineError: + """统一失败路径:audit + trace + session trace + 返回待抛错。 + + 调用方使用 `raise await emitter.fail(...)` 完成抛出。 + """ + ctx = ctx or VoiceAuditContext() + if persist_audit: + await self._persist_audit( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + status=status, + ctx=ctx, + asr_text=asr_text, + resolved_label=None, + options_snapshot_json=options_snapshot_json, + error_message=message, + ) + if record_session_trace and session_trace_recorder is not None: + try: + session_trace_recorder(asr_text, message) + except Exception as exc: + logger.debug("session trace recorder failed: {}", exc) + if emit_trace: + emit_voice_event( + self._s, + surgery_id=surgery_id, + source=source, + status=status, + confirmation_id=confirmation_id, + asr_text=asr_text, + error_message=message, + audio_object_key=ctx.audio_object_key, + ) + if include_extra is not None: + return SurgeryPipelineError(code, message, extra=include_extra) + return SurgeryPipelineError(code, message) + + async def success( + self, + *, + source: VoiceSource, + status: str, + surgery_id: str, + confirmation_id: str, + ctx: VoiceAuditContext | None = None, + asr_text: str | None, + resolved_label: str | None, + rejected: bool, + options_snapshot_json: str | None, + ) -> None: + ctx = ctx or VoiceAuditContext() + await self._persist_audit( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + status=status, + ctx=ctx, + asr_text=asr_text, + resolved_label=resolved_label, + options_snapshot_json=options_snapshot_json, + error_message=None, + ) + emit_voice_event( + self._s, + surgery_id=surgery_id, + source=source, + status=status, + confirmation_id=confirmation_id, + asr_text=asr_text, + resolved_label=resolved_label, + rejected=rejected, + audio_object_key=ctx.audio_object_key, + ) diff --git a/app/services/voice_confirm.py b/app/services/voice_confirm.py index f05a26a..f4b824c 100644 --- a/app/services/voice_confirm.py +++ b/app/services/voice_confirm.py @@ -1,19 +1,6 @@ from __future__ import annotations -import asyncio -import os -import platform import re -import shutil -import subprocess -import tempfile -from dataclasses import dataclass - -from fastapi.concurrency import run_in_threadpool -from loguru import logger - -from app.config import Settings -from app.services.baidu_speech import BaiduSpeechNotConfiguredError, BaiduSpeechService _CN_DIGITS = { @@ -200,199 +187,3 @@ def build_prompt_text(options: list[tuple[str, float]]) -> str: for i, (name, _conf) in enumerate(options, start=1): parts.append(f"第{i}个,{name}。") return "".join(parts) - - -@dataclass -class VoiceAttemptResult: - chosen_label: str | None - asr_text: str | None - error: str | None - - -class VoiceConfirmationOrchestrator: - """服务端 TTS 播报 + ffmpeg 采集 + 百度 ASR + 文本解析。""" - - def __init__(self, settings: Settings, baidu: BaiduSpeechService) -> None: - self._s = settings - self._baidu = baidu - self._lock = asyncio.Lock() - - def _ffplay_path(self) -> str | None: - return shutil.which("ffplay") - - def _ffmpeg_path(self) -> str | None: - return shutil.which("ffmpeg") - - def _record_pcm_ffmpeg(self, seconds: float) -> tuple[bytes | None, str | None]: - ffmpeg = self._ffmpeg_path() - if not ffmpeg: - return None, "ffmpeg not found in PATH" - system = platform.system() - if system == "Darwin": - dev = self._s.voice_ffmpeg_input.strip() or ":0" - input_args = ["-f", "avfoundation", "-i", dev] - else: - dev = self._s.voice_ffmpeg_input.strip() or "default" - input_args = ["-f", "alsa", "-i", dev] - - cmd = [ - ffmpeg, - "-nostdin", - "-loglevel", - "error", - "-y", - *input_args, - "-t", - str(seconds), - "-ar", - "16000", - "-ac", - "1", - "-f", - "s16le", - "-acodec", - "pcm_s16le", - "pipe:1", - ] - try: - proc = subprocess.run( - cmd, - capture_output=True, - timeout=seconds + 5.0, - check=False, - ) - except subprocess.TimeoutExpired: - return None, "ffmpeg record timeout" - if proc.returncode != 0: - err = (proc.stderr or b"").decode("utf-8", errors="replace") - return None, f"ffmpeg failed: {err or proc.returncode}" - return proc.stdout, None - - def _play_mp3_file(self, path: str) -> str | None: - ffplay = self._ffplay_path() - if not ffplay: - return "ffplay not found in PATH" - try: - proc = subprocess.run( - [ - ffplay, - "-nodisp", - "-autoexit", - "-loglevel", - "quiet", - path, - ], - capture_output=True, - timeout=120.0, - check=False, - ) - except subprocess.TimeoutExpired: - return "ffplay timeout" - if proc.returncode != 0: - return f"ffplay exit {proc.returncode}" - return None - - def _synthesize_to_temp_mp3(self, text: str) -> tuple[str | None, str | None]: - try: - audio = self._baidu.synthesis( - text, - "zh", - 1, - {"spd": 5, "pit": 5, "vol": 9, "per": 0}, - ) - except BaiduSpeechNotConfiguredError as exc: - return None, str(exc) - if isinstance(audio, dict): - return None, f"TTS error: {audio!r}" - tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) - try: - tmp.write(audio) - tmp.flush() - path = tmp.name - finally: - tmp.close() - return path, None - - async def speak_prompt(self, text: str) -> None: - """仅百度 TTS + ffplay 播报,不录音。供待确认入队时提示手术室。""" - if not (text or "").strip(): - return - if not self._s.voice_tts_on_pending_enqueued: - return - if not self._s.voice_confirmation_enabled: - return - if not self._baidu.configured: - logger.debug("speak_prompt skipped: baidu_speech not configured") - return - async with self._lock: - mp3_path, err = await run_in_threadpool(self._synthesize_to_temp_mp3, text) - if err or not mp3_path: - logger.warning("TTS synthesis failed: {}", err) - return - try: - play_err = await run_in_threadpool(self._play_mp3_file, mp3_path) - if play_err: - logger.warning("TTS play failed: {}", play_err) - finally: - try: - os.unlink(mp3_path) - except OSError: - pass - - async def run_confirmation( - self, - *, - surgery_id: str, - options: list[tuple[str, float]], - ) -> VoiceAttemptResult: - if not self._s.voice_confirmation_enabled: - return VoiceAttemptResult(None, None, "voice_confirmation_disabled") - if not options: - return VoiceAttemptResult(None, None, "no_options") - if not self._baidu.configured: - return VoiceAttemptResult(None, None, "baidu_speech_not_configured") - - labels = [o[0] for o in options] - prompt = build_prompt_text(options) - logger.info("Voice confirm surgery={} prompt_len={}", surgery_id, len(prompt)) - - async with self._lock: - mp3_path, err = await run_in_threadpool(self._synthesize_to_temp_mp3, prompt) - if err or not mp3_path: - return VoiceAttemptResult(None, None, err or "tts_failed") - try: - play_err = await run_in_threadpool(self._play_mp3_file, mp3_path) - if play_err: - return VoiceAttemptResult(None, None, play_err) - finally: - try: - os.unlink(mp3_path) - except OSError: - pass - - pcm, rec_err = await run_in_threadpool( - self._record_pcm_ffmpeg, float(self._s.voice_record_seconds) - ) - if rec_err or not pcm: - return VoiceAttemptResult(None, None, rec_err or "empty_audio") - - asr_payload = await run_in_threadpool(self._baidu.asr, pcm, "pcm", 16000, None) - if not isinstance(asr_payload, dict): - return VoiceAttemptResult(None, None, "asr_invalid_response") - if asr_payload.get("err_no") != 0: - return VoiceAttemptResult( - None, - None, - f"asr_err_{asr_payload.get('err_no')}: {asr_payload.get('err_msg')}", - ) - results = asr_payload.get("result") - text: str | None = None - if isinstance(results, list) and results: - text = str(results[0]) - elif isinstance(results, str): - text = results - if not text: - return VoiceAttemptResult(None, None, "asr_empty_text") - - chosen = parse_voice_choice(text, labels) - return VoiceAttemptResult(chosen, text, None) diff --git a/app/services/voice_file_log.py b/app/services/voice_file_log.py index 86992d6..010804b 100644 --- a/app/services/voice_file_log.py +++ b/app/services/voice_file_log.py @@ -88,6 +88,45 @@ def append_voice_tsv_line(surgery_id: str, line: str, settings: Settings) -> Non f.write(line) +class VoiceTextLogWriter: + """注入式 voice 日志写入器,封装 `init_file` / `emit_event`。 + + 行为等价于模块级函数;保留模块级函数以兼容既有调用点。 + """ + + def __init__(self, app_settings: Settings) -> None: + self._s = app_settings + + def init_file(self, surgery_id: str) -> None: + init_voice_log_file(surgery_id, self._s) + + def emit_event( + self, + *, + surgery_id: str, + source: str, + status: str, + confirmation_id: str, + asr_text: str | None = None, + resolved_label: str | None = None, + rejected: str | bool | None = None, + error_message: str | None = None, + audio_object_key: str | None = None, + ) -> None: + emit_voice_event( + self._s, + surgery_id=surgery_id, + source=source, + status=status, + confirmation_id=confirmation_id, + asr_text=asr_text, + resolved_label=resolved_label, + rejected=rejected, + error_message=error_message, + audio_object_key=audio_object_key, + ) + + def emit_voice_event( settings: Settings, *, diff --git a/app/services/voice_resolution.py b/app/services/voice_resolution.py index 261b183..f651250 100644 --- a/app/services/voice_resolution.py +++ b/app/services/voice_resolution.py @@ -1,4 +1,8 @@ -"""Resolve pending consumable confirmation from uploaded WAV: MinIO + Baidu ASR + parse.""" +"""Resolve pending consumable confirmation from uploaded WAV: MinIO + Baidu ASR + parse. + +本模块把语音识别流程线性化为一系列「阶段」,每个失败阶段都走 `VoiceAuditEmitter.fail` +统一写审计 + trace + 抛 `SurgeryPipelineError`。成功路径走 `emitter.success`。 +""" from __future__ import annotations @@ -8,14 +12,16 @@ from dataclasses import dataclass from fastapi.concurrency import run_in_threadpool from loguru import logger +from sqlalchemy.ext.asyncio import async_sessionmaker + from app.config import Settings -from app.services.voice_file_log import emit_voice_event from app.database import AsyncSessionLocal from app.repositories.voice_audits import VoiceAuditRepository from app.services.audio_wav import WavDecodeError, wav_bytes_to_pcm16k_mono_s16le from app.services.baidu_speech import BaiduSpeechNotConfiguredError, BaiduSpeechService from app.services.minio_audio_storage import MinioAudioStorageService, StoredAudio -from app.services.video.session_manager import CameraSessionManager +from app.services.pending_confirmation_port import PendingConfirmationStore +from app.services.voice_audit_emitter import VoiceAuditContext, VoiceAuditEmitter from app.services.voice_confirm import ( is_rejection_phrase, match_voice_choice_against_candidates, @@ -39,45 +45,29 @@ class VoiceConfirmationService: def __init__( self, settings: Settings, - sessions: CameraSessionManager, + sessions: PendingConfirmationStore, baidu: BaiduSpeechService, minio: MinioAudioStorageService, audits: VoiceAuditRepository, + session_factory: async_sessionmaker | None = None, + audit_emitter: VoiceAuditEmitter | None = None, ) -> None: self._s = settings self._sessions = sessions self._baidu = baidu self._minio = minio self._audits = audits - - def _emit_voice_trace( - self, - *, - source: str, - status: str, - surgery_id: str, - confirmation_id: str, - asr_text: str | None = None, - resolved_label: str | None = None, - rejected: bool | str | None = None, - error_message: str | None = None, - audio_object_key: str | None = None, - ) -> None: - emit_voice_event( - self._s, - surgery_id=surgery_id, - source=source, - status=status, - confirmation_id=confirmation_id, - asr_text=asr_text, - resolved_label=resolved_label, - rejected=rejected, - error_message=error_message, - audio_object_key=audio_object_key, + self._session_factory: async_sessionmaker = session_factory or AsyncSessionLocal + self._emitter = audit_emitter or VoiceAuditEmitter( + settings=settings, + audits=audits, + session_factory=self._session_factory, ) + # ------------------------------------------------------------------ + # TTS:保持对外接口不变 + # ------------------------------------------------------------------ def synthesize_prompt_to_mp3(self, text: str) -> bytes: - """百度在线语音合成,供浏览器直接播放,与 `voice_confirm._synthesize_to_temp_mp3` 同参。""" t = (text or "").strip() if not t: raise SurgeryPipelineError("TTS_TEXT_EMPTY", "提示文本为空。") @@ -94,6 +84,17 @@ class VoiceConfirmationService: raise SurgeryPipelineError("TTS_ERROR", f"百度 TTS 失败: {r!r}") return r + def _session_trace(self, surgery_id: str): + def _recorder(asr_text: str | None, error: str | None) -> None: + self._sessions.record_voice_trace( + surgery_id, asr_text=asr_text, error=error + ) + + return _recorder + + # ------------------------------------------------------------------ + # 主入口 + # ------------------------------------------------------------------ async def resolve_from_wav( self, *, @@ -103,74 +104,65 @@ class VoiceConfirmationService: filename: str, content_type: str | None, ) -> VoiceResolveResult: - _ = filename # reserved for logging / future MIME sniff + _ = filename # reserved for future MIME sniff + # 1) validate_size if len(wav_bytes) > self._s.voice_upload_max_bytes: - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="invalid_audio", - audio_object_key=None, - audio_content_type=content_type, - audio_size_bytes=len(wav_bytes), - audio_sha256=None, - asr_text=None, - resolved_label=None, - options_snapshot_json=None, - error_message="音频超过大小限制", - ) - self._emit_voice_trace( + raise await self._emitter.fail( source="wav", status="invalid_audio", + code="VOICE_AUDIO_INVALID", + message=( + f"音频大小超过限制(最大 {self._s.voice_upload_max_bytes} 字节)。" + ), surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="音频超过大小限制", - ) - raise SurgeryPipelineError( - "VOICE_AUDIO_INVALID", - f"音频大小超过限制(最大 {self._s.voice_upload_max_bytes} 字节)。", + ctx=VoiceAuditContext( + audio_content_type=content_type, + audio_size_bytes=len(wav_bytes), + ), + # 仅对大小越界的情况按原实现不记 session trace + record_session_trace=False, ) + # 2) ensure_providers_configured if not self._minio.configured: - self._emit_voice_trace( + raise await self._emitter.fail( source="wav", status="minio_not_configured", + code="MINIO_NOT_CONFIGURED", + message="服务端未配置 MinIO,无法保存语音追溯文件。", surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="服务端未配置 MinIO,无法保存语音追溯文件。", + persist_audit=False, + record_session_trace=False, ) - raise SurgeryPipelineError( - "MINIO_NOT_CONFIGURED", - "服务端未配置 MinIO,无法保存语音追溯文件。", - ) - if not self._baidu.configured: - self._emit_voice_trace( + raise await self._emitter.fail( source="wav", status="baidu_not_configured", + code="BAIDU_NOT_CONFIGURED", + message="服务端未配置百度语音,无法进行语音识别。", surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="服务端未配置百度语音,无法进行语音识别。", - ) - raise SurgeryPipelineError( - "BAIDU_NOT_CONFIGURED", - "服务端未配置百度语音,无法进行语音识别。", + persist_audit=False, + record_session_trace=False, ) + # 3) fetch_pending pending = self._sessions.get_pending_confirmation_by_id( surgery_id, confirmation_id ) if pending is None: - self._emit_voice_trace( + raise await self._emitter.fail( source="wav", status="confirmation_not_found", + code="CONFIRMATION_NOT_FOUND", + message="未找到该待确认项或已处理。", surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="未找到该待确认项或已处理。", - ) - raise SurgeryPipelineError( - "CONFIRMATION_NOT_FOUND", - "未找到该待确认项或已处理。", + persist_audit=False, + record_session_trace=False, ) option_labels = [a.strip() for a, _ in pending.options if a.strip()] @@ -178,305 +170,86 @@ class VoiceConfirmationService: [{"label": a, "confidence": b} for a, b in pending.options], ensure_ascii=False, ) + session_trace = self._session_trace(surgery_id) - stored: StoredAudio | None = None - try: - await run_in_threadpool(self._minio.ensure_bucket) - stored = await run_in_threadpool( - lambda: self._minio.upload_voice_wav( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - data=wav_bytes, - content_type=content_type, - ) - ) - except Exception as exc: - logger.warning("MinIO upload failed: {}", exc) - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="upload_failed", - audio_object_key=None, - audio_content_type=content_type, - audio_size_bytes=len(wav_bytes), - audio_sha256=None, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=str(exc), - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=str(exc)) - self._emit_voice_trace( - source="wav", - status="upload_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=str(exc), - ) - raise SurgeryPipelineError( - "MINIO_UPLOAD_FAILED", - f"语音文件上传失败:{exc}", - ) from exc + # 4) upload_wav + stored = await self._upload_wav( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + wav_bytes=wav_bytes, + content_type=content_type, + options_snapshot=options_snapshot, + session_trace=session_trace, + ) - try: - pcm = await run_in_threadpool(wav_bytes_to_pcm16k_mono_s16le, wav_bytes) - except WavDecodeError as exc: - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="invalid_audio", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=str(exc), - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=str(exc)) - self._emit_voice_trace( - source="wav", - status="invalid_audio", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=str(exc), - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError( - "VOICE_AUDIO_INVALID", - f"无法解析 WAV 音频:{exc}", - ) from exc + audio_ctx = VoiceAuditContext( + audio_object_key=stored.object_key, + audio_content_type=content_type, + audio_size_bytes=stored.size_bytes, + audio_sha256=stored.sha256_hex, + ) - try: - asr_payload = await run_in_threadpool( - self._baidu.asr, pcm, "pcm", 16000, None - ) - except BaiduSpeechNotConfiguredError as exc: - self._emit_voice_trace( - source="wav", - status="baidu_not_configured", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=str(exc), - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError( - "BAIDU_NOT_CONFIGURED", - str(exc), - ) from exc - except Exception as exc: - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="asr_failed", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=str(exc), - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=str(exc)) - self._emit_voice_trace( - source="wav", - status="asr_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=str(exc), - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError( - "VOICE_ASR_FAILED", - f"语音识别调用失败:{exc}", - ) from exc + # 5) decode_pcm + pcm = await self._decode_pcm( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + wav_bytes=wav_bytes, + ctx=audio_ctx, + options_snapshot=options_snapshot, + session_trace=session_trace, + ) - if not isinstance(asr_payload, dict): - msg = "ASR 返回格式异常" - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="asr_failed", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=msg, - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=msg) - self._emit_voice_trace( - source="wav", - status="asr_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=msg, - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError("VOICE_ASR_FAILED", msg) - - if asr_payload.get("err_no") != 0: - msg = ( - f"asr_err_{asr_payload.get('err_no')}: " - f"{asr_payload.get('err_msg')}" - ) - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="asr_failed", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=msg, - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=msg) - self._emit_voice_trace( - source="wav", - status="asr_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=msg, - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError("VOICE_ASR_FAILED", msg) - - results = asr_payload.get("result") - text: str | None = None - if isinstance(results, list) and results: - text = str(results[0]) - elif isinstance(results, str): - text = results - text = (text or "").strip() - - if not text: - msg = "语音识别结果为空" - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="asr_failed", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=msg, - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error=msg) - self._emit_voice_trace( - source="wav", - status="asr_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - error_message=msg, - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError("VOICE_ASR_FAILED", msg) + # 6) call_asr + asr_payload = await self._call_asr( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + pcm=pcm, + ctx=audio_ctx, + options_snapshot=options_snapshot, + session_trace=session_trace, + ) + # 7) extract_text + text = await self._extract_text_from_asr( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + asr_payload=asr_payload, + ctx=audio_ctx, + options_snapshot=options_snapshot, + session_trace=session_trace, + ) self._sessions.record_voice_trace(surgery_id, asr_text=text, error=None) - rejected = is_rejection_phrase(text) - chosen: str | None = None - if not rejected: - chosen = parse_voice_choice(text, option_labels) - if chosen is None: - surgery_candidates = self._sessions.get_surgery_candidate_consumables( - surgery_id - ) - chosen = match_voice_choice_against_candidates( - text, surgery_candidates - ) - - if not rejected and not chosen: - _, retry_remaining = await self._sessions.record_voice_parse_failure( - surgery_id, confirmation_id - ) - base = ( - "无法从语音中匹配候选项或本台手术候选清单中的耗材名称," - "请重试或说「不是」否认全部。" - ) - if retry_remaining > 0: - msg = ( - f"{base} 本次未听清或未能解析," - f"您还可重试 {retry_remaining} 次," - "请说「第一个」「第二个」等序号或候选项全名。" - ) - else: - msg = ( - f"{base} 本轮重试机会已用完," - "请再清晰地说序号/全名,或说「不是」否认全部。" - ) - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="parse_failed", - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=text, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=msg, - ) - self._sessions.record_voice_trace(surgery_id, asr_text=text, error=msg) - self._emit_voice_trace( - source="wav", - status="parse_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - asr_text=text, - error_message=msg, - audio_object_key=stored.object_key, - ) - raise SurgeryPipelineError( - "VOICE_PARSE_FAILED", - msg, - extra={ - "confirmation_id": confirmation_id, - "retry_remaining": retry_remaining, - }, - ) + # 8) parse_choice + rejected, chosen = await self._parse_choice_or_fail( + source="wav", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + text=text, + option_labels=option_labels, + options_snapshot=options_snapshot, + ctx=audio_ctx, + session_trace=session_trace, + ) + # 9) persist_success(含 session 内的 resolve_pending_confirmation) await self._sessions.resolve_pending_confirmation( surgery_id, confirmation_id, chosen_label=chosen, rejected=rejected, ) - final_status = "rejected" if rejected else "recognized" - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status=final_status, - audio_object_key=stored.object_key, - audio_content_type=content_type, - audio_size_bytes=stored.size_bytes, - audio_sha256=stored.sha256_hex, - asr_text=text, - resolved_label=chosen if not rejected else None, - options_snapshot_json=options_snapshot, - error_message=None, - ) - self._emit_voice_trace( + await self._emitter.success( source="wav", status=final_status, surgery_id=surgery_id, confirmation_id=confirmation_id, + ctx=audio_ctx, asr_text=text, resolved_label=chosen if not rejected else None, rejected=rejected, - audio_object_key=stored.object_key, + options_snapshot_json=options_snapshot, ) if rejected: @@ -502,21 +275,20 @@ class VoiceConfirmationService: confirmation_id: str, recognized_text: str, ) -> VoiceResolveResult: - """浏览器 Web Speech 等客户端本机识别后的文本,不经 MinIO/百度 ASR,解析规则与 `resolve_from_wav` 一致。""" + """浏览器本机识别文本,不经 MinIO/百度 ASR,解析规则与 `resolve_from_wav` 一致。""" pending = self._sessions.get_pending_confirmation_by_id( surgery_id, confirmation_id ) if pending is None: - self._emit_voice_trace( + raise await self._emitter.fail( source="text", status="confirmation_not_found", + code="CONFIRMATION_NOT_FOUND", + message="未找到该待确认项或已处理。", surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="未找到该待确认项或已处理。", - ) - raise SurgeryPipelineError( - "CONFIRMATION_NOT_FOUND", - "未找到该待确认项或已处理。", + persist_audit=False, + record_session_trace=False, ) option_labels = [a.strip() for a, _ in pending.options if a.strip()] @@ -524,93 +296,46 @@ class VoiceConfirmationService: [{"label": a, "confidence": b} for a, b in pending.options], ensure_ascii=False, ) + session_trace = self._session_trace(surgery_id) text = (recognized_text or "").strip() if not text: - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="client_stt_empty", - audio_object_key=None, - audio_content_type=None, - audio_size_bytes=None, - audio_sha256=None, - asr_text=None, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message="客户端识别文本为空", - ) - self._sessions.record_voice_trace(surgery_id, asr_text=None, error="empty text") - self._emit_voice_trace( + raise await self._emitter.fail( source="text", status="client_stt_empty", + code="VOICE_TEXT_EMPTY", + message="客户端识别文本为空", surgery_id=surgery_id, confirmation_id=confirmation_id, - error_message="客户端识别文本为空", + options_snapshot_json=options_snapshot, + session_trace_recorder=lambda _a, _m: self._sessions.record_voice_trace( + surgery_id, asr_text=None, error="empty text" + ), ) - raise SurgeryPipelineError("VOICE_TEXT_EMPTY", "recognized_text 为空。") self._sessions.record_voice_trace(surgery_id, asr_text=text, error=None) - rejected = is_rejection_phrase(text) - chosen: str | None = None - if not rejected: - chosen = parse_voice_choice(text, option_labels) - if chosen is None: - surgery_candidates = self._sessions.get_surgery_candidate_consumables( - surgery_id - ) - chosen = match_voice_choice_against_candidates(text, surgery_candidates) - - if not rejected and not chosen: - _, retry_remaining = await self._sessions.record_voice_parse_failure( - surgery_id, confirmation_id - ) - base = ( + rejected, chosen = await self._parse_choice_or_fail( + source="text", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + text=text, + option_labels=option_labels, + options_snapshot=options_snapshot, + ctx=VoiceAuditContext(), + session_trace=session_trace, + parse_status_on_failure="client_stt_parse_failed", + parse_message_prefix=( "无法从文本中匹配候选项或本台手术候选清单中的耗材名称," "请重试或说「不是」否认全部。" - ) - if retry_remaining > 0: - msg = ( - f"{base} 本次未能解析," - f"您还可重试 {retry_remaining} 次," - "请输入「第一个」「第二个」等或候选项全名。" - ) - else: - msg = ( - f"{base} 本轮重试机会已用完," - "请再输入序号/全名,或说「不是」否认全部。" - ) - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status="client_stt_parse_failed", - audio_object_key=None, - audio_content_type=None, - audio_size_bytes=None, - audio_sha256=None, - asr_text=text, - resolved_label=None, - options_snapshot_json=options_snapshot, - error_message=msg, - ) - self._sessions.record_voice_trace(surgery_id, asr_text=text, error=msg) - self._emit_voice_trace( - source="text", - status="client_stt_parse_failed", - surgery_id=surgery_id, - confirmation_id=confirmation_id, - asr_text=text, - error_message=msg, - ) - raise SurgeryPipelineError( - "VOICE_PARSE_FAILED", - msg, - extra={ - "confirmation_id": confirmation_id, - "retry_remaining": retry_remaining, - }, - ) + ), + parse_retry_hint_still=( + "请输入「第一个」「第二个」等或候选项全名。" + ), + parse_retry_hint_exhausted=( + "请再输入序号/全名,或说「不是」否认全部。" + ), + ) await self._sessions.resolve_pending_confirmation( surgery_id, @@ -618,22 +343,8 @@ class VoiceConfirmationService: chosen_label=chosen, rejected=rejected, ) - final_status = "rejected" if rejected else "recognized" - await self._persist_audit( - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status=final_status, - audio_object_key=None, - audio_content_type=None, - audio_size_bytes=None, - audio_sha256=None, - asr_text=text, - resolved_label=chosen if not rejected else None, - options_snapshot_json=options_snapshot, - error_message=None, - ) - self._emit_voice_trace( + await self._emitter.success( source="text", status=final_status, surgery_id=surgery_id, @@ -641,6 +352,7 @@ class VoiceConfirmationService: asr_text=text, resolved_label=chosen if not rejected else None, rejected=rejected, + options_snapshot_json=options_snapshot, ) if rejected: @@ -659,37 +371,233 @@ class VoiceConfirmationService: message="已确认并记一条消耗。", ) - async def _persist_audit( + # ------------------------------------------------------------------ + # 内部阶段 + # ------------------------------------------------------------------ + async def _upload_wav( self, *, surgery_id: str, confirmation_id: str, - status: str, - audio_object_key: str | None, - audio_content_type: str | None, - audio_size_bytes: int | None, - audio_sha256: str | None, - asr_text: str | None, - resolved_label: str | None, - options_snapshot_json: str | None, - error_message: str | None, - ) -> None: + wav_bytes: bytes, + content_type: str | None, + options_snapshot: str, + session_trace, + ) -> StoredAudio: try: - async with AsyncSessionLocal() as session: - async with session.begin(): - await self._audits.save_audit( - session, - surgery_id=surgery_id, - confirmation_id=confirmation_id, - status=status, - audio_object_key=audio_object_key, - audio_content_type=audio_content_type, - audio_size_bytes=audio_size_bytes, - audio_sha256=audio_sha256, - asr_text=asr_text, - resolved_label=resolved_label, - options_snapshot_json=options_snapshot_json, - error_message=error_message, - ) + await run_in_threadpool(self._minio.ensure_bucket) + return await run_in_threadpool( + lambda: self._minio.upload_voice_wav( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + data=wav_bytes, + content_type=content_type, + ) + ) except Exception as exc: - logger.error("Persist voice audit failed: {}", exc) + logger.warning("MinIO upload failed: {}", exc) + raise await self._emitter.fail( + source="wav", + status="upload_failed", + code="MINIO_UPLOAD_FAILED", + message=f"语音文件上传失败:{exc}", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=VoiceAuditContext( + audio_content_type=content_type, + audio_size_bytes=len(wav_bytes), + ), + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) from exc + + async def _decode_pcm( + self, + *, + surgery_id: str, + confirmation_id: str, + wav_bytes: bytes, + ctx: VoiceAuditContext, + options_snapshot: str, + session_trace, + ) -> bytes: + try: + return await run_in_threadpool(wav_bytes_to_pcm16k_mono_s16le, wav_bytes) + except WavDecodeError as exc: + raise await self._emitter.fail( + source="wav", + status="invalid_audio", + code="VOICE_AUDIO_INVALID", + message=f"无法解析 WAV 音频:{exc}", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) from exc + + async def _call_asr( + self, + *, + surgery_id: str, + confirmation_id: str, + pcm: bytes, + ctx: VoiceAuditContext, + options_snapshot: str, + session_trace, + ) -> object: + try: + return await run_in_threadpool(self._baidu.asr, pcm, "pcm", 16000, None) + except BaiduSpeechNotConfiguredError as exc: + raise await self._emitter.fail( + source="wav", + status="baidu_not_configured", + code="BAIDU_NOT_CONFIGURED", + message=str(exc), + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + persist_audit=False, + record_session_trace=False, + ) from exc + except Exception as exc: + raise await self._emitter.fail( + source="wav", + status="asr_failed", + code="VOICE_ASR_FAILED", + message=f"语音识别调用失败:{exc}", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) from exc + + async def _extract_text_from_asr( + self, + *, + surgery_id: str, + confirmation_id: str, + asr_payload: object, + ctx: VoiceAuditContext, + options_snapshot: str, + session_trace, + ) -> str: + if not isinstance(asr_payload, dict): + raise await self._emitter.fail( + source="wav", + status="asr_failed", + code="VOICE_ASR_FAILED", + message="ASR 返回格式异常", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) + if asr_payload.get("err_no") != 0: + msg = ( + f"asr_err_{asr_payload.get('err_no')}: " + f"{asr_payload.get('err_msg')}" + ) + raise await self._emitter.fail( + source="wav", + status="asr_failed", + code="VOICE_ASR_FAILED", + message=msg, + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) + + results = asr_payload.get("result") + text: str | None = None + if isinstance(results, list) and results: + text = str(results[0]) + elif isinstance(results, str): + text = results + text = (text or "").strip() + if not text: + raise await self._emitter.fail( + source="wav", + status="asr_failed", + code="VOICE_ASR_FAILED", + message="语音识别结果为空", + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + ) + return text + + async def _parse_choice_or_fail( + self, + *, + source, + surgery_id: str, + confirmation_id: str, + text: str, + option_labels: list[str], + options_snapshot: str, + ctx: VoiceAuditContext, + session_trace, + parse_status_on_failure: str = "parse_failed", + parse_message_prefix: str = ( + "无法从语音中匹配候选项或本台手术候选清单中的耗材名称," + "请重试或说「不是」否认全部。" + ), + parse_retry_hint_still: str = ( + "请说「第一个」「第二个」等序号或候选项全名。" + ), + parse_retry_hint_exhausted: str = ( + "请再清晰地说序号/全名,或说「不是」否认全部。" + ), + ) -> tuple[bool, str | None]: + rejected = is_rejection_phrase(text) + chosen: str | None = None + if not rejected: + chosen = parse_voice_choice(text, option_labels) + if chosen is None: + candidates = self._sessions.get_surgery_candidate_consumables(surgery_id) + chosen = match_voice_choice_against_candidates(text, candidates) + + if rejected or chosen: + return rejected, chosen + + _, retry_remaining = await self._sessions.record_voice_parse_failure( + surgery_id, confirmation_id + ) + if retry_remaining > 0: + if source == "wav": + suffix = ( + f" 本次未听清或未能解析,您还可重试 {retry_remaining} 次," + f"{parse_retry_hint_still}" + ) + else: + suffix = ( + f" 本次未能解析,您还可重试 {retry_remaining} 次," + f"{parse_retry_hint_still}" + ) + else: + suffix = f" 本轮重试机会已用完,{parse_retry_hint_exhausted}" + msg = parse_message_prefix + suffix + + raise await self._emitter.fail( + source=source, + status=parse_status_on_failure, + code="VOICE_PARSE_FAILED", + message=msg, + surgery_id=surgery_id, + confirmation_id=confirmation_id, + ctx=ctx, + asr_text=text, + options_snapshot_json=options_snapshot, + session_trace_recorder=session_trace, + include_extra={ + "confirmation_id": confirmation_id, + "retry_remaining": retry_remaining, + }, + ) diff --git a/docs/client-api-integration.md b/docs/client-api-integration.md deleted file mode 100644 index 0fbd903..0000000 --- a/docs/client-api-integration.md +++ /dev/null @@ -1,213 +0,0 @@ -# 手术室监控服务 — 客户端 HTTP API 对接说明 - -本文档面向**集成我方 FastAPI 后端的客户系统**(HIS、手麻、护理工作站、自研终端等),说明如何通过 HTTP 调用完成「手术开始 / 结束」「耗材结果查询」「低置信度耗材语音确认」等能力。 - -> **说明**:仓库内 `scripts/demo_client/` 等演示页面仅用于**内部联调与测试**,不代表生产对接规范;生产环境请按本文档与 OpenAPI 契约自行实现客户端。 - ---- - -## 1. 服务与发现 - - -| 项目 | 说明 | -| ------- | ------------------------------------------------------------- | -| 默认监听 | 应用默认 `0.0.0.0:38080`(以实际部署为准,可能经反向代理改写路径或端口) | -| 基础路径 | 路由**无全局前缀**;下文路径均为相对服务根路径 | -| OpenAPI | 服务启动后可访问 `**/docs`(Swagger UI)**、`**/redoc`** 获取实时 Schema 与试调 | -| 健康检查 | `GET /health`:探活与数据库连通性(降级时可能返回 503) | -| 跨域 CORS | 仅当服务端开启演示用 CORS 配置时对浏览器页面生效;**服务端对接通常不受 CORS 限制** | - - -认证方式以部署约定为准;当前公开路由未在文档层强制 API Key,若贵方环境增加了网关鉴权,请在请求头中按网关要求携带。 - ---- - -## 2. 客户端 API 一览(`/client/...`) - -所有「客户集成」接口均位于 `**/client`** 命名空间下。 - - -| 方法 | 路径 | 摘要 | -| ------ | ------------------------------------------------------------------------------- | ------------------------- | -| `POST` | `/client/surgeries/start` | 开始手术:确认摄像头开录成功后返回 | -| `POST` | `/client/surgeries/end` | 结束手术:确认停录成功后返回 | -| `GET` | `/client/surgeries/{surgery_id}/result` | 查询该台手术的耗材明细与汇总 | -| `GET` | `/client/surgeries/{surgery_id}/pending-confirmation` | 拉取一条待确认项(含 TTS 音频 Base64) | -| `POST` | `/client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve` | 上传医生语音 WAV,完成确认或否认 | - - -路径参数 `**surgery_id`**:固定 **6 位数字**(正则 `^\d{6}$`)。 - ---- - -## 3. 端到端业务流程(推荐时序) - -以下为客户系统与手术室监控服务之间的**推荐调用顺序**与并行关系。 - -```mermaid -sequenceDiagram - participant Client as 客户系统 - participant API as 监控服务 API - - Client->>API: POST /client/surgeries/start - Note over Client,API: body: surgery_id, camera_ids, candidate_consumables - API-->>Client: 200 accepted(开录已确认) - - par 术中轮询 - loop 按需轮询 - Client->>API: GET .../result - API-->>Client: 200 明细+汇总 或 503 RESULT_NOT_READY - end - loop 语音确认闭环(若启用) - Client->>API: GET .../pending-confirmation - API-->>Client: 200 待确认+MP3 或 404 无待确认 - opt 有待确认 - Client->>Client: 播放 prompt_audio_mp3_base64 - Client->>API: POST .../resolve(multipart audio) - API-->>Client: 200 accepted - end - end - end - - Client->>API: POST /client/surgeries/end - API-->>Client: 200 accepted(停录已确认) - - Client->>API: GET .../result - API-->>Client: 200 归档后持久化结果(若可查) -``` - - - -### 3.1 手术生命周期(状态视角) - -```mermaid -flowchart LR - A[未开始] -->|POST start 成功| B[录制中 / 推理中] - B -->|GET result 可 200| C[可查消耗] - B -->|GET pending 可 200| D[有待确认] - D -->|POST resolve| B - B -->|POST end 成功| E[已结束] - E -->|GET result| C -``` - - - ---- - -## 4. 接口说明与请求/响应要点 - -### 4.1 `POST /client/surgeries/start` - -**Content-Type**:`application/json` - -**请求体(摘要)** - - -| 字段 | 类型 | 说明 | -| ----------------------- | -------- | ------------------------------------------------------------- | -| `surgery_id` | string | 6 位数字手术号 | -| `camera_ids` | string[] | 至少 1 个;需与贵方 RTSP/SDK 映射中的 **摄像头 ID** 一致 | -| `candidate_consumables` | string[] | 可选;**本台手术可能用到的耗材名称清单**。服务端仅对该清单内耗材做自动记账与待确认追问;为空则只做拉流推理、不写入消耗 | - - -**成功(200)**:`SurgeryApiResponse` — `status` 一般为 `accepted`,表示**服务端已确认开录完成**。 - -**失败**:常见 `**503 Service Unavailable`**,`detail` 内含业务 `code`(如录制子系统未就绪、开录未确认等)。开录/停录类错误会按服务端配置**自动重试**;仍失败则返回最后一次错误信息。 - -### 4.2 `POST /client/surgeries/end` - -**Content-Type**:`application/json` - -**请求体**:`{ "surgery_id": "123456" }` - -**成功(200)**:停录已确认。 - -**失败**:`503` 同 start,表示未能在确认摄像头全部停录后完成请求。 - -### 4.3 `GET /client/surgeries/{surgery_id}/result` - -**幂等只读**。术中返回当前内存已记账结果;结束后返回数据库持久化结果(以服务端实现为准)。 - -**成功(200)**:`SurgeryResultResponse` - -- `details`:多行消耗明细,字段顺序为 `**item_id` → `item_name` → `qty` → `doctor_id` → `timestamp`** -- `summary`:按 `item_id` 汇总的 `total_quantity` - -`**503`**:`detail.code === "RESULT_NOT_READY"` — 尚无该手术可查询结果(未开始、未成功开录或暂无可返回数据)。 - -### 4.4 `GET /client/surgeries/{surgery_id}/pending-confirmation` - -用于**低置信度识别**的人工确认闭环(需服务端启用语音确认及相关配置)。 - -**成功(200)**:包含 `confirmation_id`、`prompt_text`、候选项 `options`、`prompt_audio_mp3_base64`(与话术一致的 **MP3 标准 Base64**,无换行)、模型 Top1 等。 - -`**404`**:当前无待确认或手术未在进行(`NO_PENDING_CONFIRMATION`)。 - -`**422` / `503`**:话术为空、音频无效、ASR/TTS/MinIO/百度语音等异常时返回,详见响应 `detail.code`。 - -### 4.5 `POST /client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve` - -**Content-Type**:`multipart/form-data` - -**表单字段** - - -| 字段 | 说明 | -| ------- | -------------------------------------------------------- | -| `audio` | 单个 `**.wav`** 文件;建议 16 kHz 单声道 PCM;其他格式服务端可能尝试 ffmpeg 转码 | - - -**成功(200)**:`SurgeryPendingConfirmationResolveResponse` — 含 `resolved_label`、`rejected`、`asr_text`、`audio_object_key` 等。 - -**常见 HTTP 状态**:`404`(确认项不存在或已失效)、`409`(已处理)、`422`(音频/解析问题)、`503`(外部依赖故障)。 - ---- - -## 5. 错误约定 - -FastAPI 对 `HTTPException` 的 JSON 外形一般为: - -```json -{ - "detail": { - "code": "业务错误码", - "message": "人类可读说明", - "surgery_id": "123456" - } -} -``` - -部分错误可能在 `detail` 中附带额外字段(如重试剩余次数),请以实际响应与 OpenAPI 为准。 - -**校验错误(422)**:请求体或路径不符合 Pydantic 校验时,FastAPI 可能返回标准 `422` 验证错误结构(与上表「对象型 detail」不同),请客户端分别解析。 - ---- - -## 6. 与内部演示能力的关系(非客户必选) - -以下路由用于**内部一键联调**,客户生产系统**无需依赖**: - - -| 路径 | 说明 | -| ------------------------------------------- | -------------------------------------------------- | -| `GET /internal/demo/orchestrator-status` | 探测演示编排是否开启、RTSP 配置文件是否设置等 | -| `POST /internal/demo/orchestrate-and-start` | 仅在服务端开启 `DEMO_ORCHESTRATOR_ENABLED` 时注册;用于演示环境串联开录 | - - -客户正式对接应直接调用 `**/client/...`**,并在贵方环境中配置真实的 `camera_ids` 与视频后端映射。 - ---- - -## 7. 联调建议 - -1. 先调用 `**GET /health`** 确认服务与数据库可用。 -2. 用 `**POST /client/surgeries/start`** 验证 `camera_ids` 与现场 RTSP/SDK 配置一致,避免 503。 -3. `**candidate_consumables**` 与实际手术耗材名称尽量与院内目录或模型标签对齐,减少待确认次数。 -4. 结果查询 `**503**` 时建议**退避重试**(术中数据尚未就绪属正常现象)。 -5. 以 `**/docs`** 导出或对照契约测试,与贵方 CI 中的契约测试对齐。 - ---- - -## 8. 文档修订 - -接口行为以部署实例的 **OpenAPI(`/docs`)** 与代码为准;字段含义补充见仓库内 `app/schemas.py` 中各模型的 `description`。 \ No newline at end of file diff --git a/docs/staging-regression-checklist.md b/docs/staging-regression-checklist.md index 9378576..97d24be 100644 --- a/docs/staging-regression-checklist.md +++ b/docs/staging-regression-checklist.md @@ -4,42 +4,37 @@ ## 环境 -- [ ] `GET /health` 返回 `200`,`database: connected` -- [ ] 环境变量:`VIDEO_RTSP_URLS_JSON` 或 `VIDEO_RTSP_URLS_JSON_FILE` 与客户端 `camera_ids` 一致 -- [ ] `MINIO_*`、`BAIDU_SPEECH_*` 已配置(语音确认链路) -- [ ] 模型权重路径可读(容器内挂载 `app/resources/*.pt`) +- `GET /health` 返回 `200`,`database: connected` +- 环境变量:`VIDEO_RTSP_URLS_JSON` 或 `VIDEO_RTSP_URLS_JSON_FILE` 与客户端 `camera_ids` 一致 +- `MINIO_`*、`BAIDU_SPEECH_*` 已配置(语音确认链路) +- 模型权重路径可读(容器内挂载 `app/resources/*.pt`) ## 主流程 -1. **开始手术** `POST /client/surgeries/start` - - [ ] 请求体含 6 位 `surgery_id`、`camera_ids`、`candidate_consumables`(非空才会记账) - - [ ] 返回 `200`,日志中各路 RTSP 首帧就绪 - -2. **进行中查询(可选)** `GET /client/surgeries/{id}/result` - - [ ] 在已有至少一条明细后返回 `200`;仅开录尚无明细时可能 `503 RESULT_NOT_READY`(与实现一致) - -3. **低置信追问** - - [ ] `GET /client/surgeries/{id}/pending-confirmation` 有任务时 `200`,含 `prompt_text`、`options` - - [ ] 客户端对 `prompt_text` **TTS 播报**,采集医生回答为 **WAV** - - [ ] `POST .../pending-confirmation/{confirmation_id}/resolve`,`multipart` 字段名 `audio` - - [ ] 确认后明细中出现 `source=voice`;否认不增加明细 - -4. **结束手术** `POST /client/surgeries/end` - - [ ] 返回 `200`,摄像头任务停止 - -5. **最终结果** `GET /client/surgeries/{id}/result` - - [ ] 返回 `200`,`details` / `summary` 与术中所见一致 - -6. **数据库** - - [ ] `surgery_final_results` / `surgery_result_details` 有对应 `surgery_id` - - [ ] `voice_confirmation_audits` 在语音确认路径有追溯行(成功/失败分支视联调覆盖而定) +1. **开始手术** `POST /client/surgeries/start` + - 请求体含 6 位 `surgery_id`、`camera_ids`;`candidate_consumables` 可空(空则全量目录/模型类名) + - 返回 `200`,日志中各路 RTSP 首帧就绪 +2. **进行中查询(可选)** `GET /client/surgeries/{id}/result` + - 至少一条明细时 `200`;无明细(开录后尚无、已归档但零消耗等)为 `503 RESULT_NOT_READY` +3. **低置信追问** + - `GET /client/surgeries/{id}/pending-confirmation` 有任务时 `200`,含 `prompt_text`、`options` + - 客户端对 `prompt_text` **TTS 播报**,采集医生回答为 **WAV** + - `POST .../pending-confirmation/{confirmation_id}/resolve`,`multipart` 字段名 `audio` + - 确认后明细中出现 `source=voice`;否认不增加明细 +4. **结束手术** `POST /client/surgeries/end` + - 返回 `200`,摄像头任务停止 +5. **最终结果** `GET /client/surgeries/{id}/result` + - 有明细时 `200`,`details` / `summary` 与术中所见一致;整台手术无任何明细时为 `503 RESULT_NOT_READY` +6. **数据库** + - `surgery_final_results` / `surgery_result_details` 有对应 `surgery_id` + - `voice_confirmation_audits` 在语音确认路径有追溯行(成功/失败分支视联调覆盖而定) ## 失败与重试(抽样) -- [ ] RTSP 不可达:`start` 最终 `503`,消息含开录失败说明 -- [ ] MinIO 不可用:`resolve` 返回 `503` 或业务码 `MINIO_*` -- [ ] 停录后写库失败:服务日志提示归档;后台重试或修复 DB 后可再次 `start` 同号前会先尝试落归档(见接口说明) +- RTSP 不可达:`start` 最终 `503`,消息含开录失败说明 +- MinIO 不可用:`resolve` 返回 `503` 或业务码 `MINIO_`* +- 停录后写库失败:服务日志提示归档;后台重试或修复 DB 后可再次 `start` 同号前会先尝试落归档(见接口说明) ## 与文档 -- 客户端集成以 **OpenAPI**(`/docs`)与 [客户端手术通信接口说明](./客户端手术通信接口说明.md) 为准;**待确认 resolve 为 multipart WAV**,非 JSON `chosen_label`。 +- 客户端集成以 **OpenAPI**(`/docs`)与 [客户端手术通信接口说明](./客户端手术通信接口说明.md) 为准;**待确认 resolve 为 multipart WAV**,非 JSON `chosen_label`。 \ No newline at end of file diff --git a/docs/video-backends.md b/docs/video-backends.md index 6094a8d..b7b5fcb 100644 --- a/docs/video-backends.md +++ b/docs/video-backends.md @@ -8,10 +8,10 @@ ## RTSP 模式(默认) -1. 配置 **`camera_id` → RTSP URL** 映射,任选其一或组合使用: - - **`VIDEO_RTSP_URLS_JSON_FILE`**:指向 UTF-8 JSON 文件(对象键为与请求一致的 `camera_id`)。仓库示例:[`app/resources/camera_rtsp_urls.sample.json`](../app/resources/camera_rtsp_urls.sample.json)(示例 ID:`or-cam-01`、`or-cam-02`)。 - - **`VIDEO_RTSP_URLS_JSON`**:内联 JSON 字符串;与文件合并时**覆盖同键**。 - - **`VIDEO_RTSP_URL_TEMPLATE`**:单模板,可用 `{camera_id}`。 +1. 配置 `**camera_id` → RTSP URL** 映射,任选其一或组合使用: + - `**VIDEO_RTSP_URLS_JSON_FILE`**:指向 UTF-8 JSON 文件(对象键为与请求一致的 `camera_id`)。仓库示例:`[app/resources/camera_rtsp_urls.sample.json](../app/resources/camera_rtsp_urls.sample.json)`(示例 ID:`or-cam-01`、`or-cam-02`)。 + - `**VIDEO_RTSP_URLS_JSON**`:内联 JSON 字符串;与文件合并时**覆盖同键**。 + - `**VIDEO_RTSP_URL_TEMPLATE`**:单模板,可用 `{camera_id}`。 2. 调用 `POST /client/surgeries/start` 时,`camera_ids` 必须能在上述配置中解析出 RTSP 地址。 3. **开录确认**:每路摄像头在超时内成功打开并读到**首帧**后,才认为该路已开录。 @@ -26,20 +26,20 @@ SDK **不作为构建期依赖**:将厂商提供的 Linux x86_64 动态库挂 行为概要: 1. 进程内对 `NET_DVR_Init` 使用引用计数;每路使用 SDK 的工作线程在登录后 `NET_DVR_Logout`,线程结束时配对 `NET_DVR_Cleanup`。 -2. 若 `HIKVISION_SDK_FALLBACK_TO_RTSP=true`(默认),在**无法加载动态库**、**登录失败**或**未配置凭据**时,自动回退到 `VIDEO_RTSP_*` 映射拉流。 +2. 若 `HIKVISION_SDK_FALLBACK_TO_RTSP=true`(默认),在**无法加载动态库**、**登录失败**或**未配置凭据**时,自动回退到 `VIDEO_RTSP_`* 映射拉流。 **注意**:`NET_DVR_Login_V30` 的设备信息结构体在不同 SDK 版本上可能存在差异;若登录异常,请优先使用 RTSP 回退或按厂商文档校对 ctypes 绑定。 ## 推理与结果查询 - 开录后按 `VIDEO_INFERENCE_INTERVAL_SEC` 抽帧,依次调用耗材分类与撕扯动作模型。 -- **候选耗材清单**(开始手术请求体中的 `candidate_consumables`)为**硬约束**:若为空,服务端**不会**写入任何消耗明细(仅拉流推理);非空时仅允许清单内标签自动记账。 +- **候选耗材清单**(`candidate_consumables`):非空时**仅**清单内名称参与自动记账与待确认;**缺省或 `[]`** 时,用耗材目录 Excel **全部商品名**作为候选;无目录则用分类模型**全部类名**。 - 当分类 Top1 置信度 **≥** `VIDEO_AUTO_CONFIRM_CONFIDENCE`(**默认 0.9**)且标签在候选清单内时,自动写入一条 `source=vision` 的消耗明细;**低于**该线的识别需人工确认(在语音下沿之上且能展示候选项时入队)。 -- 置信度在 \[`VIDEO_VOICE_CONFIRM_MIN_CONFIDENCE`, `VIDEO_AUTO_CONFIRM_CONFIDENCE`\) 等区间且存在可向医生展示的候选时,会生成**待确认**任务;客户端 `GET /client/surgeries/{surgery_id}/pending-confirmation`,确认后 `POST .../pending-confirmation/{id}/resolve` 等。 -- 已有至少一条消耗明细后,`GET /client/surgeries/{surgery_id}/result` 返回 200;若已开录但尚未产生任何明细,返回 503 `RESULT_NOT_READY`。 +- 置信度在 `VIDEO_VOICE_CONFIRM_MIN_CONFIDENCE`, `VIDEO_AUTO_CONFIRM_CONFIDENCE` 等区间且存在可向医生展示的候选时,会生成**待确认**任务;客户端 `GET /client/surgeries/{surgery_id}/pending-confirmation`,确认后 `POST .../pending-confirmation/{id}/resolve` 等。 +- `GET /client/surgeries/{surgery_id}/result` 仅在存在**至少一条**消耗明细时返回 200;无明细(已开录但尚未记账、已结束但零消耗、或尚无归档等)返回 503 `RESULT_NOT_READY`。 - 同类物品写入受 `VIDEO_DETAIL_COOLDOWN_SEC` 节流。 - RTSP 读帧连续失败达到 `VIDEO_READ_FAILURE_RECONNECT_THRESHOLD` 时会 `release` 并尝试重连,间隔 `VIDEO_RECONNECT_BACKOFF_SECONDS`。 ## 相关环境变量 -详见仓库根目录 `.env.example` 中「视频:RTSP + 可选海康 HCNetSDK」一节。 +详见仓库根目录 `.env.example` 中「视频:RTSP + 可选海康 HCNetSDK」一节。 \ No newline at end of file diff --git a/docs/客户端手术通信接口说明.md b/docs/客户端手术通信接口说明.md index 660861b..7fd97f6 100644 --- a/docs/客户端手术通信接口说明.md +++ b/docs/客户端手术通信接口说明.md @@ -1,86 +1,320 @@ -# 手术室监控服务 · 客户端手术通信接口说明 + -| 能力 | 说明 | -| ----- | ------------------------------------------------------ | -| 开始手术 | 请求开始手术;服务端启动摄像头录制,**仅在确认开录完成后**返回 HTTP 200。 | -| 结束手术 | 请求结束手术;服务端停止摄像头录制,**仅在确认停录完成后**返回 HTTP 200。 | -| 查询结果 | 根据手术 6 位号查询消耗明细与汇总;**仅在已开录且至少已有一条消耗明细后**返回 HTTP 200。 | -| 待确认耗材 | 低置信度时服务端排队一条待确认任务;客户端拉取话术(TTS)并在医生确认后回传,**不阻塞**后续视频推理。 | +# 手术室监控服务:客户端手术通信接口说明 +面向对接本 FastAPI 服务的客户端(HIS、手麻、工作站等)。字段、状态码与返回模型以 `/docs` 和 `/openapi.json` 为准,本文用于摘要和流程说明。 -**约定:** +> [!summary] 常用响应模型 +> - `SurgeryApiResponse` +> - `SurgeryResultResponse` +> - `SurgeryPendingConfirmationResponse` +> - 业务错误外形见 `SurgeryClientErrorResponse` +> - 内部演示页:`scripts/demo_client/`(仅供演示,不作为对外契约) -- **开始 / 结束** 使用 `POST`,请求体为 **JSON**(`Content-Type: application/json`)。 -- **查询结果** 使用 `GET`,**无请求体**;手术号放在 **URL 路径** 中(见 4.3),符合「只读资源用 GET」的惯例。 -- 手术标识 `**surgery_id`**:必须为 **恰好 6 位数字**(正则 `^\d{6}$`),例如 `123456`。 +## 能力概览 ---- +- 探活:`GET /health`,用于检查进程和数据库状态,详见 5.1 节。 +- 开始手术:`POST /client/surgeries/start`,只有在开录确认成功后才返回 `200`。 +- 结束手术:`POST /client/surgeries/end`,只有在停录确认成功后才返回 `200`。 +- 查询结果:`GET /client/surgeries/{surgery_id}/result`,至少存在一条消耗明细时返回 `200`;否则返回 `503`,常见错误码为 `RESULT_NOT_READY`。 +- 待确认播报:`GET /client/surgeries/{surgery_id}/pending-confirmation`,拉取队首低置信度任务,返回话术文本和 MP3 Base64。 +- 待确认答复:`POST /client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve`,上传医生答复的 WAV 录音,服务端完成 ASR 后入账或关闭。该录音与播报音频无关。 -## 2. 基础信息 +> [!important] HTTP 约定 +> - `start` 和 `end` 使用 `POST` + `application/json` +> - `result` 使用 `GET` +> - `resolve` 使用 `POST` + `multipart/form-data` +> - `surgery_id` 固定为 6 位数字,正则为 `^\d{6}$` +> - `resolve` 路径中的 `confirmation_id` 必须与待确认接口返回值一致 +> - `camera_ids` 必须与第 2 节清单及运维配置完全一致 +## 1. 服务与基础信息 -| 项目 | 说明 | -| ---------- | ------------------------------------------------------- | -| 协议 | HTTP/HTTPS | -| 请求体格式 | 开始/结束:`application/json`;查询结果:无 body | -| 响应体格式 | JSON | -| 路径前缀 | 服务端根路径下直接挂载,例如 `https://<主机>:<端口>/client/surgeries/...` | -| 默认服务端口(开发) | `38080`(以实际部署为准) | +- 协议:`HTTP/HTTPS` +- 端口:`38080`,生产环境以实际入口为准 +- 路由:无全局前缀;业务接口位于 `/client/...`,健康检查位于 `/health` +- `start` / `end` 请求体:JSON +- `resolve` 请求体:`multipart/form-data`,字段名为 `audio` +- 在线文档:`/docs`、`/redoc` +## 2. 摄像头 ID 与 RTSP -> **说明:** 若生产环境存在网关或反向代理,请将上表中的「主机、端口、是否 HTTPS」替换为对外统一入口地址。 +RTSP 地址、账号、口令等由客户端对接工程师提供给服务端运维,运维再写入服务端环境(例如 JSON 映射或环境变量)。业务程序不在客户端保存 RTSP,客户端只在 `POST /client/surgeries/start` 中传 `camera_ids`。 ---- +配置格式示例见 `app/resources/camera_rtsp_urls.sample.json`。配置项细节见 `.env.example` 与 `docs/video-backends.md`。 -## 3. 接口列表 +**摄像头映射示例** +- `or-cam-01` + - RTSP:`rtsp://...`(由现场或 NVR 文档整理后交给运维) + - 备注:术间、机位 +- `or-cam-02` + - RTSP:`...` + - 备注:`...` -| 序号 | 方法 | 路径 | 说明 | -| --- | ------ | ------------------------------------------------------------------------------- | --------- | -| 1 | `POST` | `/client/surgeries/start` | 开始手术 | -| 2 | `POST` | `/client/surgeries/end` | 结束手术 | -| 3 | `GET` | `/client/surgeries/{surgery_id}/result` | 查询手术结果 | -| 4 | `GET` | `/client/surgeries/{surgery_id}/pending-confirmation` | 拉取一条待确认耗材 | -| 5 | `POST` | `/client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve` | 提交医生确认结果 | +> [!warning] 对接要求 +> - `camera_ids` 必须与运维配置中的 key 完全一致 +> - RTSP 不应硬编码在客户端业务程序中 +> [!tip] 联调建议 +> - 运维配置完成后,客户端使用上面清单中的 `camera_id` 调用 `start` 验证是否返回 `200` +> - 若返回 `503` 且 `detail.code = RECORDING_CANNOT_START`,优先核对 ID 拼写以及监控服务器侧网络连通性 ---- +## 3. HTTP 路由一览 -## 4. 接口详情 +1. `GET /health`:探活 +2. `POST /client/surgeries/start`:开始手术 +3. `POST /client/surgeries/end`:结束手术 +4. `GET /client/surgeries/{surgery_id}/result`:查询手术结果 +5. `GET /client/surgeries/{surgery_id}/pending-confirmation`:拉取待确认耗材 +6. `POST /client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve`:提交待确认结果(WAV) -### 4.1 开始手术 +## 4. 流程 -**用途:** 在手术开始时,由客户端向服务端上报手术编号、参与采集的摄像头,以及本台手术可能涉及的耗材清单;**服务端启动关联摄像头录制**。 +### 4.1 时序图 -**成功条件(HTTP 200):** 仅在服务端**确认摄像头已开始录制**之后,才返回 **HTTP 200**。不得在「仅收到请求、尚未开录」时返回 200。 +```mermaid +sequenceDiagram + participant Client as 客户系统 + participant API as 监控服务 API + Client->>API: POST /client/surgeries/start + Note over Client,API: body: surgery_id, camera_ids, candidate_consumables + API-->>Client: 200 accepted(开录已确认) -| 项目 | 内容 | -| --- | ------------------------- | -| 方法 | `POST` | -| 路径 | `/client/surgeries/start` | + par 术中 + loop 轮询结果 + Client->>API: GET .../result + API-->>Client: 200 或 503 RESULT_NOT_READY + end + loop 轮询待确认(若启用) + Client->>API: GET .../pending-confirmation + API-->>Client: 200 或 404 + opt 有待确认 + Client->>Client: 播放 prompt_audio_mp3_base64 + Client->>API: POST .../resolve(multipart audio) + API-->>Client: 200 accepted + end + end + end + Client->>API: POST /client/surgeries/end + API-->>Client: 200 accepted(停录已确认) -**请求体字段:** + Client->>API: GET .../result + API-->>Client: 200(持久化后可查时返回) +``` +### 4.2 状态图 -| 字段名 | 类型 | 必填 | 说明 | -| ----------------------- | ---------- | --- | ----------------------------------------------------------------------------------------- | -| `surgery_id` | `string` | 是 | 手术 6 位号,必须为 6 位数字。 | -| `camera_ids` | `string[]` | 是 | 摄像头 ID 列表,至少 1 个元素;须与服务端配置的 RTSP 映射键一致(示例见 `app/resources/camera_rtsp_urls.sample.json`)。 | -| `candidate_consumables` | `string[]` | 否 | 本台手术允许记账的耗材名称清单。**为空或未传则不会写入任何消耗**(仅拉流推理);非空时自动记账与待确认仅针对清单内名称。 | +```mermaid +flowchart LR + A[未开始] -->|start 200| B[录制 / 推理中] + B -->|result 200| C[有消耗数据可查] + B -->|pending 200| D[待确认] + D -->|resolve| B + B -->|end 200| E[已结束] + E -->|result 200| C +``` +## 5. 接口详情 -**说明:** 若该 `surgery_id` 在服务端仍存在**尚未写入数据库**的上一台手术内存归档,开始新会话前会先尝试落库;落库失败则返回 **503**(`RECORDING_CANNOT_START`),避免静默丢失数据。 +以下按“基本信息 -> 请求 -> 响应 -> 状态码”组织,与 OpenAPI 中 `tags: client` 和 `health` 一致。 -**请求示例:** +### 5.0 通用约定 + +**路径参数 `surgery_id`** + +- 长度:固定 `6` +- 字符集:仅数字 +- 正则:`^\d{6}$` + +**业务错误响应** + +多数业务失败在 `4xx` 或 `5xx` 下返回如下 JSON: + +```json +{ + "detail": { + "code": "错误码字符串", + "message": "人类可读说明", + "surgery_id": "123456" + } +} +``` + +### 5.1 探活 + +**基本信息** + +- 方法:`GET` +- 路径:`/health` +- 请求体:无 + +**响应说明** + +- `200` + - 说明:进程正常且数据库可连通 + - 响应体示例:`{"status":"ok","database":"connected"}` +- `503` + - 说明:数据库不可用(降级) + - 响应体示例:`{"status":"degraded","database":"unavailable"}` + +### 5.2 开始手术 + +**基本信息** + +- 方法:`POST` +- 路径:`/client/surgeries/start` +- Content-Type:`application/json; charset=utf-8` + +**业务说明** + +- 服务端会为 `camera_ids` 中的每个摄像头建立拉流与推理任务,只有在确认开录成功(如首帧就绪)后才返回 HTTP `200` +- 若同一 `surgery_id` 存在尚未落库的历史归档,服务端会先尝试写入数据库;失败时可能返回 `503`(如 `RECORDING_CANNOT_START`),以避免静默丢数据 +- `candidate_consumables` 为空时,服务端会展开为目录 Excel 中的全部商品名,或在未配置目录时展开为分类模型的全部类名 + +**请求体(JSON)** + +- `surgery_id` + - 类型:`string` + - 必填:是 + - 说明:6 位数字,与路径规则一致 +- `camera_ids` + - 类型:`string[]` + - 必填:是 + - 说明:至少 1 个;必须与运维配置的摄像头 ID 完全一致,见第 2 节 +- `candidate_consumables` + - 类型:`string[]` + - 必填:否 + - 说明:非空时仅这些名称参与自动记账与待确认;缺省或 `[]` 时使用全部候选 + +**响应体(200)** + +- `surgery_id` + - 类型:`string` + - 说明:与请求一致 +- `status` + - 类型:`string` + - 说明:成功时通常为 `accepted` +- `message` + - 类型:`string` + - 说明:说明文案 + +**状态码** + +- `200`:开录已确认 +- `422`:参数校验失败,例如 `surgery_id` 非 6 位或 `camera_ids` 为空数组 +- `503`:开录未确认或录制子系统故障;`detail.code` 常见为 `RECORDING_CANNOT_START` + +**请求示例** ```json { @@ -90,17 +324,7 @@ } ``` -**成功响应(HTTP 200):** 表示开录已确认。 - - -| 字段名 | 类型 | 说明 | -| ------------ | -------- | ---------------------------- | -| `surgery_id` | `string` | 回显手术 6 位号。 | -| `status` | `string` | 处理状态(例如 `accepted` 表示开录已确认)。 | -| `message` | `string` | 人类可读的说明文案。 | - - -**响应示例:** +**响应示例(200)** ```json { @@ -110,34 +334,36 @@ } ``` -**重试:** 开录调用失败时,服务端会按配置**自动重试**若干次(间隔若干秒);**全部尝试仍失败**后再返回 **HTTP 503**。环境变量:`SURGERY_RECORDING_MAX_ATTEMPTS`(默认 3,含首次)、`SURGERY_RECORDING_RETRY_DELAY_SECONDS`(默认 `1.0`)。 +### 5.3 结束手术 -**失败响应(HTTP 503):** 重试用尽仍无法在约定条件下确认开录时返回。响应体见 **§5.2**(OpenAPI 模型 `SurgeryClientErrorResponse`,错误码示例:`RECORDING_CANNOT_START`);`detail.message` 中会注明已重试次数。 +**基本信息** ---- +- 方法:`POST` +- 路径:`/client/surgeries/end` +- Content-Type:`application/json; charset=utf-8` -### 4.2 结束手术 +**业务说明** -**用途:** 在手术结束时,由客户端请求服务端结束该 `surgery_id` 对应手术:**服务端须停止关联摄像头的录制**。 +停止该 `surgery_id` 关联的全部摄像头任务,只有在确认停录完成后才返回 `200`。 -**成功条件(HTTP 200):** 仅在服务端**确认所有关联摄像头已停止录制**之后,才返回 **HTTP 200**。不得在「仅收到请求、尚未停录」时返回 200。 +**请求体(JSON)** +- `surgery_id` + - 类型:`string` + - 必填:是 + - 说明:6 位数字 -| 项目 | 内容 | -| --- | ----------------------- | -| 方法 | `POST` | -| 路径 | `/client/surgeries/end` | +**响应体(200)** +字段含义与 5.2 节一致,`message` 示例为 `摄像头录制已停止,手术已结束。` -**请求体字段:** +**状态码** +- `200`:停录已确认 +- `422`:参数校验失败 +- `503`:停录未确认或故障;`detail.code` 常见为 `RECORDING_NOT_STOPPED` -| 字段名 | 类型 | 必填 | 说明 | -| ------------ | -------- | --- | ------------------ | -| `surgery_id` | `string` | 是 | 手术 6 位号,必须为 6 位数字。 | - - -**请求示例:** +**请求示例** ```json { @@ -145,100 +371,76 @@ } ``` -**成功响应(HTTP 200):** 表示停录已完成。 +### 5.4 查询手术结果 +**基本信息** -| 字段名 | 类型 | 说明 | -| ------------ | -------- | ---------------------------- | -| `surgery_id` | `string` | 回显手术 6 位号。 | -| `status` | `string` | 处理状态(例如 `accepted` 表示停录已确认)。 | -| `message` | `string` | 人类可读的说明文案。 | +- 方法:`GET` +- 路径:`/client/surgeries/{surgery_id}/result` +- 路径参数:`surgery_id` +- 请求体:无 +**业务说明** -**响应示例:** +- 仅当存在至少一条消耗明细时返回 `200` +- 无明细(包括已归档但零消耗)、手术未开始、未成功开录或当前尚不可查时,返回 `503` +- 上述 `503` 场景的常见错误码为 `RESULT_NOT_READY` -```json -{ - "surgery_id": "123456", - "status": "accepted", - "message": "摄像头录制已停止,手术已结束。" -} -``` +**响应体(200)** -**重试:** 停录调用失败时,服务端会按配置**自动重试**(与开始手术相同的环境变量);**全部尝试仍失败**后再返回 **HTTP 503**。 +- `surgery_id` + - 类型:`string` + - 说明:手术号 +- `status` + - 类型:`string` + - 说明:成功时通常为 `completed` +- `message` + - 类型:`string` + - 说明:说明 +- `details` + - 类型:`array` + - 说明:消耗明细列表,字段见下文 +- `summary` + - 类型:`array` + - 说明:按 `item_id` 汇总的结果,字段见下文 -**失败响应(HTTP 503):** 重试用尽仍无法确认停录完成时返回。响应体见 **§5.2**(错误码示例:`RECORDING_NOT_STOPPED`);`detail.message` 中会注明已重试次数。 +**`details[]` 元素** ---- +- `item_id` + - 类型:`string` + - 说明:物品 ID;有目录时多为产品编码,否则通常与名称或模型类名一致 +- `item_name` + - 类型:`string` + - 说明:物品名称 +- `qty` + - 类型:`integer` + - 说明:本条记录数量,当前恒为 `1`;一次识别或一次人工确认只追加一条明细 +- `doctor_id` + - 类型:`string` + - 说明:记账关联的医生或系统标识 +- `timestamp` + - 类型:`string` + - 说明:ISO 8601 时间(`date-time`) -### 4.3 查询手术结果 +**`summary[]` 元素** -**用途:** 根据 `surgery_id` 查询该台手术下耗材消耗明细及按物品汇总。 +- `item_id` + - 类型:`string` + - 说明:与明细一致 +- `item_name` + - 类型:`string` + - 说明:名称,通常取该 `item_id` 首条明细中的名称 +- `total_quantity` + - 类型:`integer` + - 说明:该物品在本台手术中的合计数量,`>= 0` -**成功条件(HTTP 200):** 仅在**已开录**且**至少已有一条消耗明细**(自动识别或医生确认)之后返回 **HTTP 200** 及 `details` / `summary`。若已开录但尚无明细,返回 **503**(见 **§5.2**,错误码 `RESULT_NOT_READY`)。 +**状态码** +- `200`:至少有一条明细 +- `422`:`surgery_id` 路径不符合约束 +- `503`:`RESULT_NOT_READY`,当前无可用明细或不可查 -| 项目 | 内容 | -| --- | --------------------------------------- | -| 方法 | `GET` | -| 路径 | `/client/surgeries/{surgery_id}/result` | - - -**路径参数:** - - -| 参数名 | 类型 | 必填 | 说明 | -| ------------ | -------- | --- | ------------------------------ | -| `surgery_id` | `string` | 是 | 手术 6 位号,必须为 6 位数字,出现在 URL 路径中。 | - - -**请求示例:** - -```http -GET /client/surgeries/123456/result HTTP/1.1 -Host: <主机>:<端口> -``` - -(浏览器或客户端直接访问完整 URL 即可,例如 `https://<主机>:<端口>/client/surgeries/123456/result`。) - -**成功响应(HTTP 200):** - - -| 字段名 | 类型 | 说明 | -| ------------ | -------- | ---------------------------------------------------------------- | -| `surgery_id` | `string` | 手术 6 位号。 | -| `status` | `string` | 成功时一般为 `completed`(以服务端约定为准)。 | -| `message` | `string` | 说明信息。 | -| `details` | 数组 | **消耗明细**:按事件发生,可能有多行;每行含物品、数量、医生、时间。 | -| `summary` | 数组 | **按物品汇总**:同一 `item_id` 在 `details` 中 `quantity` 的合计,便于客户端直接展示总计。 | - - -`**details[]` 中每一项(明细行):** - - -| 字段名 | 类型 | 必填 | 说明 | -| ----------- | --------- | --- | ---------------------------------------------------------------- | -| `item_id` | `string` | 是 | 物品 ID。 | -| `item_name` | `string` | 是 | 物品名称。 | -| `quantity` | `integer` | 是 | 本条记录对应的消耗数量(非负整数)。 | -| `doctor_id` | `string` | 是 | 医生 ID。 | -| `timestamp` | `string` | 是 | 记录时间,**ISO 8601**(JSON 中为 ISO 格式字符串,与 OpenAPI 中 `date-time` 一致)。 | -| `source` | `string` | 否 | `vision` 自动识别;`voice` 医生通过待确认接口确认。 | - - -`**summary[]` 中每一项(汇总行):** - - -| 字段名 | 类型 | 必填 | 说明 | -| ---------------- | --------- | --- | ------------------------------------- | -| `item_id` | `string` | 是 | 物品 ID。 | -| `item_name` | `string` | 是 | 物品名称(与明细中该 ID 首次出现时的名称一致,具体规则以服务端为准)。 | -| `total_quantity` | `integer` | 是 | 该物品在本台手术中的消耗数量**合计**。 | - - -**约定:** `summary` 应由服务端根据 `details` 按 `item_id` 汇总得到,保证与明细一致。 - -**响应示例:** +**响应示例(200)** ```json { @@ -249,136 +451,146 @@ Host: <主机>:<端口> { "item_id": "HC001", "item_name": "纱布", - "quantity": 2, + "qty": 1, "doctor_id": "D1001", "timestamp": "2026-04-21T10:30:00+08:00" - }, - { - "item_id": "HC001", - "item_name": "纱布", - "quantity": 1, - "doctor_id": "D1002", - "timestamp": "2026-04-21T11:05:00+08:00" - }, - { - "item_id": "HC002", - "item_name": "缝线", - "quantity": 1, - "doctor_id": "D1001", - "timestamp": "2026-04-21T10:45:00+08:00" } ], "summary": [ - { "item_id": "HC001", "item_name": "纱布", "total_quantity": 3 }, - { "item_id": "HC002", "item_name": "缝线", "total_quantity": 1 } + { + "item_id": "HC001", + "item_name": "纱布", + "total_quantity": 1 + } ] } ``` ---- +### 5.5 拉取待确认耗材 -### 4.4 拉取待确认耗材 +**基本信息** -**用途:** 当模型置信度不足但存在候选时,服务端将任务放入 FIFO 队列。客户端轮询本接口获取**队首**一条待确认项,使用 `prompt_text` 进行 TTS 播报,并由医生口述选择;**服务端视频推理不等待本步骤**。 +- 方法:`GET` +- 路径:`/client/surgeries/{surgery_id}/pending-confirmation` +- 路径参数:`surgery_id` +- 请求体:无 -**成功条件(HTTP 200):** 当前手术进行中且队列非空。 +**业务说明** -**失败(HTTP 404):** 无待确认项或手术未在进行。`detail.code` 示例:`NO_PENDING_CONFIRMATION`。 +- 返回当前 FIFO 队首的一条低置信度识别任务 +- `prompt_audio_mp3_base64` 与 `prompt_text` 内容一致,为标准 Base64 的 MP3 字符串(无换行) +- 客户端解码后应按 `audio/mpeg` 播放 +**响应体(200)** -| 项目 | 内容 | -| --- | ----------------------------------------------------- | -| 方法 | `GET` | -| 路径 | `/client/surgeries/{surgery_id}/pending-confirmation` | +- `surgery_id` + - 类型:`string` + - 说明:手术号 +- `confirmation_id` + - 类型:`string` + - 说明:待确认项 ID;提交 5.6 节接口时原样放入路径 +- `prompt_text` + - 类型:`string` + - 说明:播报或展示用语,与 MP3 内容一致 +- `prompt_audio_mp3_base64` + - 类型:`string` + - 说明:MP3 的 Base64 +- `options` + - 类型:`array` + - 说明:候选项列表,字段见下文 +- `model_top1_label` + - 类型:`string` + - 说明:模型原始 Top1 类名,可能不在本台候选内 +- `model_top1_confidence` + - 类型:`number` + - 说明:Top1 置信度 +- `created_at` + - 类型:`string` + - 说明:创建时间(ISO 8601) +**`options[]` 元素** -**响应字段(节选):** `confirmation_id`、`prompt_text`、`options[]`(`label` + `confidence`)、`model_top1_label`、`model_top1_confidence`、`created_at`。 +- `label` + - 类型:`string` + - 说明:展示给医生的选项名称 +- `confidence` + - 类型:`number` + - 说明:该选项对应的置信度 ---- +**状态码** -### 4.5 提交耗材确认结果 +- `200`:当前有一条待确认 +- `404`:无待确认或手术未活跃;常见错误码为 `NO_PENDING_CONFIRMATION` +- `422`:例如话术为空导致无法 TTS;错误码见响应,如 `TTS_TEXT_EMPTY` +- `503`:语音服务未配置或 TTS 失败;例如 `BAIDU_NOT_CONFIGURED`、`TTS_ERROR` -**用途:** 客户端采集医生回答的 **WAV 音频**并上传;服务端将音频存入 MinIO、调用百度 ASR 识别、解析 4.4 返回的候选项;**确认**则记一条 `source=voice` 的消耗明细,**否认**则关闭该待确认项且不记账。 +### 5.6 提交待确认结果(医生语音) +**基本信息** -| 项目 | 内容 | -| -------------- | ------------------------------------------------------------------------------- | -| 方法 | `POST` | -| 路径 | `/client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve` | -| `Content-Type` | `multipart/form-data` | +- 方法:`POST` +- 路径:`/client/surgeries/{surgery_id}/pending-confirmation/{confirmation_id}/resolve` +- Content-Type:`multipart/form-data` +**路径参数** -**请求体(multipart):** +- `surgery_id` + - 约束:6 位数字 + - 说明:同 5.0 节 +- `confirmation_id` + - 约束:长度 1 到 128 + - 说明:与 5.5 节响应中的 `confirmation_id` 一致 +**请求体(multipart)** -| 字段 | 类型 | 必填 | 说明 | -| ------- | --- | --- | -------------------------------------------------------- | -| `audio` | 文件 | 是 | 医生语音 `**.wav`**;建议 16 kHz 单声道 PCM,其他格式服务端可尝试用 ffmpeg 转码。 | +- `audio` + - 类型:`file` + - 必填:是 + - 说明:单个 `.wav` 文件;建议使用 16 kHz 单声道 PCM;非 `.wav` 扩展名会返回 `422` +**业务说明** -**成功响应(HTTP 200):** `SurgeryPendingConfirmationResolveResponse`:`resolved_label`、`rejected`、`asr_text`、`audio_object_key` 等(与 OpenAPI 一致)。 +音频上传至对象存储后执行 ASR 和候选解析。若识别为确认某个候选项,则记一条消耗;若识别为否认全部候选,则不记消耗。 -**错误:** `404`(项不存在或手术未活跃)、`409`(已处理)、`422`(空文件、非 `.wav`、ASR/解析失败等业务码见 `detail.code`)、`503`(MinIO/百度未配置或上传失败等)。 +**响应体(200)** -> **说明:** 人工追问的 **TTS 播报由客户端**根据 4.4 的 `prompt_text` 完成;服务端不要求部署扬声器/麦克风。 +- `surgery_id` + - 类型:`string` + - 说明:手术号 +- `confirmation_id` + - 类型:`string` + - 说明:待确认 ID +- `status` + - 类型:`string` + - 说明:成功时为 `accepted` +- `message` + - 类型:`string` + - 说明:说明 +- `resolved_label` + - 类型:`string | null` + - 说明:确认后的耗材名称;否认全部候选时为 `null` +- `rejected` + - 类型:`boolean` + - 说明:是否否认全部候选,不记消耗时为 `true` +- `asr_text` + - 类型:`string | null` + - 说明:语音识别文本 +- `audio_object_key` + - 类型:`string | null` + - 说明:对象存储中的原始 WAV 键,便于追溯 ---- +**状态码** -## 5. 错误与校验 +- `200`:已受理并完成解析 +- `404`:确认项不存在或手术未活跃;例如 `CONFIRMATION_NOT_FOUND` +- `409`:当前确认项已处理过;例如 `CONFIRMATION_ALREADY_RESOLVED` +- `422`:空文件、非 `.wav`、`VOICE_AUDIO_INVALID`、ASR/解析失败等,具体错误码见响应 +- `503`:MinIO、百度等依赖不可用;例如 `MINIO_NOT_CONFIGURED`、`MINIO_UPLOAD_FAILED`、`BAIDU_NOT_CONFIGURED` -### 5.1 参数校验(HTTP 422) +**cURL 示例** -当参数不符合约束时(例如 `surgery_id` 不是 6 位数字、开始手术时 `camera_ids` 为空数组等),服务端通常返回 **HTTP 422**,响应体为 FastAPI/Pydantic 风格的校验错误详情。 - -建议在客户端侧对 `surgery_id` 先做本地校验,减少无效请求。 - -### 5.2 业务未就绪(HTTP 503) - -当**成功条件未满足**(开录/停录未确认、或查询结果时算法结果尚未就绪)时,服务端返回 **HTTP 503**,响应体为 JSON,且 `**detail` 为对象**(与 OpenAPI 中的 `**SurgeryClientErrorResponse`** / `**SurgeryClientErrorDetail`** 一致): - - -| 字段 | 类型 | 说明 | -| ------------------- | -------- | ---------------------------------------------------------------------------- | -| `detail.code` | `string` | 业务错误码,如 `RECORDING_CANNOT_START`、`RECORDING_NOT_STOPPED`、`RESULT_NOT_READY`。 | -| `detail.message` | `string` | 人类可读说明。 | -| `detail.surgery_id` | `string` | 手术 6 位号。 | - - -**示例:** - -```json -{ - "detail": { - "code": "RESULT_NOT_READY", - "message": "仅在已开录且算法已产生可查询的实时计算结果后返回 HTTP 200;当前条件不满足。", - "surgery_id": "123456" - } -} +```bash +curl -sS -X POST \ + "http://<主机>:38080/client/surgeries/123456/pending-confirmation//resolve" \ + -F "audio=@/path/to/voice.wav;type=audio/wav" ``` - ---- - -## 6. 实现与演进说明(给阅读者) - -- **开始 / 结束 / 查询结果** 与录制、算法流水线的具体绑定以实现为准;**未满足约定条件时不返回 200**(见各节成功条件),与 **OpenAPI(`/docs` 或 `/openapi.json`)** 中声明的 **200 / 503 / 422** 一致。 -- **人工确认**由客户端完成 TTS 与拾音(ASR);服务端只提供结构化候选与话术,不要求部署环境具备扬声器/麦克风。 -- 接入真实子系统后,仍应保持:成功响应体与 `SurgeryApiResponse`、`SurgeryResultResponse` 模型一致;503 与 `SurgeryClientErrorResponse` 一致。 - -联调时请以 **OpenAPI 文档**(如 `/docs`)为准,本文档与之同步维护。 - ---- - -## 7. 文档修订 - - -| 版本 | 日期 | 说明 | -| --- | ---------- | ---------------------------------------------------------------- | -| 1.6 | 2026-04-21 | 待确认耗材接口;候选清单硬约束;查询结果需至少一条明细;客户端侧人工确认。 | -| 1.5 | 2026-04-21 | 开始/结束手术:录制流水线失败时重试,仍失败再 503;可配置 `SURGERY_RECORDING_`*。 | -| 1.4 | 2026-04-21 | 与 OpenAPI 对齐:开始/结束/查询的 200/503 条件及 `SurgeryClientErrorResponse`。 | -| 1.3 | 2026-04-21 | 结束手术:仅在实际停录确认后返回 HTTP 200;否则 503。 | -| 1.2 | 2026-04-21 | 查询结果响应增加 `details`(物品 id/名称/数量/医生/时间)与 `summary`(按物品汇总)。 | -| 1.1 | 2026-04-21 | 查询结果改为 `GET /client/surgeries/{surgery_id}/result`。 | -| 1.0 | 2026-04-21 | 初版,`POST /client/surgeries/start`、`POST /client/surgeries/end`。 | - - diff --git a/main.py b/main.py index 7c6d567..5ffa842 100644 --- a/main.py +++ b/main.py @@ -9,28 +9,46 @@ from loguru import logger from app.api import router as api_router from app.config import settings from app.database import check_database, engine, init_db_schema -from app.dependencies import camera_session_manager +from app.dependencies import build_container -logger.remove() -logger.add( - sys.stderr, - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}", -) + +def configure_logging() -> None: + """集中配置 loguru sink;由 create_app 显式调用,避免 import-time 副作用。""" + logger.remove() + logger.add( + sys.stderr, + format=( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <8} | " + "{name}:{function} - {message}" + ), + ) @asynccontextmanager async def lifespan(app: FastAPI): await check_database() - await init_db_schema() - logger.info("Database connection verified and schema ensured") - await camera_session_manager.start_archive_retry_loop() - yield - await camera_session_manager.shutdown() - await engine.dispose() - logger.info("Database engine disposed") + if settings.auto_create_schema: + await init_db_schema() + logger.info("Database connection verified; schema auto-created (dev mode)") + else: + logger.info( + "Database connection verified; auto_create_schema=false, " + "expecting 'alembic upgrade head' to have run" + ) + container = build_container(settings) + app.state.container = container + await container.start() + try: + yield + finally: + await container.shutdown() + await engine.dispose() + logger.info("Database engine disposed") def create_app() -> FastAPI: + configure_logging() application = FastAPI( title="Operation Room Monitor", lifespan=lifespan, @@ -69,9 +87,9 @@ app = create_app() def main() -> None: uvicorn.run( "main:app", - host="0.0.0.0", - port=38080, - reload=True, + host=settings.server_host, + port=settings.server_port, + reload=settings.server_reload, ) diff --git a/pyproject.toml b/pyproject.toml index 4c30255..a75c5e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "pytest>=8.3.0", "pytest-asyncio>=0.25.0", "aiosqlite>=0.21.0", + "alembic>=1.14.0", ] [tool.pytest.ini_options] diff --git a/scripts/demo_client/README.md b/scripts/demo_client/README.md index e1fc0d9..0810ad5 100644 --- a/scripts/demo_client/README.md +++ b/scripts/demo_client/README.md @@ -100,7 +100,7 @@ open http://localhost:38081/ - §4.1 `POST /client/surgeries/start` — 含 `surgery_id` 校验、`camera_ids` 多值输入、`candidate_consumables` 标签编辑器(初始值从 `/labels.json` 载入,可增删) - §4.2 `POST /client/surgeries/end` - §4.3 `GET /client/surgeries/{id}/result` — 以表格渲染 `details` 与 `summary` -- §4.4 `GET /client/surgeries/{id}/pending-confirmation` — 支持手动拉取与 2s 自动轮询 +- §4.4 `GET /client/surgeries/{id}/pending-confirmation` — 支持手动拉取与 **10s** 自动轮询(请求串行排队,避免与 §4.5 上传后紧接拉取竞态) - §4.5 `POST .../resolve` — 本地麦克风录音 → 16 kHz 单声道 WAV → `multipart/form-data` 上传 - **调试:无摄像头** — 两路视频选择与 `camera_id`;一键联调见上文;手跑假流见 `fake_rtsp_from_file.py` 与本文「调试:无真实摄像头」 diff --git a/scripts/demo_client/index.html b/scripts/demo_client/index.html index 13e3650..e2821c1 100644 --- a/scripts/demo_client/index.html +++ b/scripts/demo_client/index.html @@ -349,7 +349,7 @@ -

默认策略:Top1 置信度 < 0.9 且达语音下沿时多会入队待确认;≥ VIDEO_AUTO_CONFIRM_CONFIDENCE(默认 0.9)且标签在 candidate_consumables 内则直接记 vision,拉取待确认为 404。可在环境变量中调整 VIDEO_AUTO_CONFIRM_CONFIDENCE。确认时在「语音确认(录音)」上传 WAV 即可。

+

默认策略:Top1 置信度 < 0.9 且达语音下沿时多会入队待确认;≥ VIDEO_AUTO_CONFIRM_CONFIDENCE(默认 0.9)且标签在 candidate_consumables 内则直接记 vision,拉取待确认为 404。可在环境变量中调整 VIDEO_AUTO_CONFIRM_CONFIDENCE§4.1 开录返回 200 后本页会自动排入一次 §4.4 拉取;§4.5 上传成功后也会串行拉取下一条,多条待确认按服务端 FIFO 逐条处理。若在轮询 GET 尚未返回时已提交 §4.5,本页会丢弃过期响应并自动再拉一次,避免旧 confirmation_id 盖住新队首。

@@ -641,16 +641,24 @@ const detail = (body && (body.detail !== undefined)) ? body.detail : body; const errText = (typeof detail === "object" && detail !== null) ? JSON.stringify(detail, null, 2) : String(detail || body || "错误"); alert("一键开录失败 HTTP " + res.status + "\n\n" + errText); + return; } + // 开录成功后立即排入 §4.4;并使此前进行中的 pending GET 失效(避免旧 id 覆盖) + _pendingSyncSeq++; + fetchPendingOnce(); return; } const camera_ids = $("camera-ids").value.split(",").map(s => s.trim()).filter(Boolean); if (camera_ids.length === 0) { alert("camera_ids 至少要 1 个"); return; } - await apiJson("POST", "/client/surgeries/start", { + const { res } = await apiJson("POST", "/client/surgeries/start", { surgery_id: sid, camera_ids, candidate_consumables: [...tags], }); + if (res.ok) { + _pendingSyncSeq++; + fetchPendingOnce(); + } }; // ============================================================ @@ -710,11 +718,23 @@ // ============================================================ // §4.4 pending-confirmation(响应内带 Base64 MP3)+ 可选自动播报 + // §4.5 依赖本节的 confirmation_id;避免并发拉取竞态与「已无待确认仍保留旧 id」 // ============================================================ let pollTimer = null; - /** 仅在一次成功播出音频/TTS 后更新,避免未播成功却跳过 */ + /** + * 自动播报去重:在「开始排程」时同步写入当前 confirmation_id,避免串行 GET 在首段尚未播完时 + * 再次命中 !== 判断而启动第二遍播报(_pendingSyncSeq 重拉、开录后立即拉取等会紧挨着第二次 fetch)。 + * 仅当播放失败时在 catch 中清除,便于轮询/手动重试。 + */ let lastSpokenConfirmationId = null; let lastPendingPayload = null; + /** + * 与「进行中的 GET pending」比对:§4.5 resolve 成功、手术 id 变更、重新开录时递增。 + * 解决竞态:GET 已发出后用户才提交 resolve,晚到的响应会带着旧 confirmation_id,不得写回 UI。 + */ + let _pendingSyncSeq = 0; + /** 串行化 GET pending:链式 Promise,支持 await fetchPendingOnce()(§4.5 成功后拉取下一条) */ + let _pendingFetchChain = Promise.resolve(); /** 方案1:首次用户手势内播放极短静音,解锁自动播放;之后待确认 MP3 复用同一 Audio */ const SILENT_UNLOCK_DATA_URL = @@ -852,23 +872,27 @@ $("surgery-id").addEventListener("input", () => { lastSpokenConfirmationId = null; lastPendingPayload = null; + _pendingSyncSeq++; }); async function playLastPendingManually() { const p = lastPendingPayload; if (!p || !p.confirmation_id) return; const pt = (p.prompt_text || "").trim(); + const cid = p.confirmation_id; try { + lastSpokenConfirmationId = cid; await playPromptAudioBase64(p.prompt_audio_mp3_base64, pt); - lastSpokenConfirmationId = p.confirmation_id; } catch (e) { console.warn("[demo-client] 手动播放失败", e); + if (lastSpokenConfirmationId === cid) lastSpokenConfirmationId = null; } } - async function fetchPendingOnce() { + async function runFetchPendingOnce() { const sid = surgeryId(); if (!/^\d{6}$/.test(sid)) return; + const startSeq = _pendingSyncSeq; const path = `/client/surgeries/${sid}/pending-confirmation`; const url = baseUrl() + path; let res; @@ -878,6 +902,10 @@ addLog("GET", url, "NETWORK", String(e), { error: true }); return; } + if (startSeq !== _pendingSyncSeq) { + fetchPendingOnce(); + return; + } const raw = await res.text(); let body; try { @@ -885,6 +913,10 @@ } catch { body = raw; } + if (startSeq !== _pendingSyncSeq) { + fetchPendingOnce(); + return; + } if (res.status === 404) { // 无待确认为常态,不写入右侧「响应日志」,减少刷屏 } else { @@ -892,6 +924,15 @@ } const box = $("pending-render"); if (res.status === 200 && body && body.confirmation_id) { + const prevId = lastPendingPayload && lastPendingPayload.confirmation_id; + if (prevId && prevId !== body.confirmation_id) { + recordingWav = null; + $("btn-resolve").disabled = true; + $("audio-preview").hidden = true; + $("btn-download").style.display = "none"; + $("rec-info").textContent = "新待确认已入队,请重新录音后上传"; + $("rec-info").className = "warn small"; + } box.hidden = false; lastPendingPayload = body; $("confirmation-id").value = body.confirmation_id; @@ -911,26 +952,48 @@ if (btnPlay) btnPlay.onclick = () => void playLastPendingManually(); const pt = (body.prompt_text || "").trim(); const ttsOn = $("tts-pending") && $("tts-pending").checked; - if (ttsOn && pt && body.confirmation_id !== lastSpokenConfirmationId) { + const cidForTts = body.confirmation_id; + if (ttsOn && pt && cidForTts !== lastSpokenConfirmationId) { + lastSpokenConfirmationId = cidForTts; void (async () => { try { await playPromptAudioBase64(body.prompt_audio_mp3_base64, pt); - lastSpokenConfirmationId = body.confirmation_id; } catch (e) { console.warn("[demo-client] 自动播报未完成(可点「播放话术」)", e); + if (lastSpokenConfirmationId === cidForTts) lastSpokenConfirmationId = null; } })(); } } else if (res.status === 404) { lastPendingPayload = null; + lastSpokenConfirmationId = null; box.hidden = false; - box.innerHTML = '暂无待确认项。'; + box.innerHTML = '暂无待确认项。请先 §4.4 拉取到待确认后再 §4.5 录音上传。'; + $("confirmation-id").value = ""; + $("btn-resolve").disabled = true; + recordingWav = null; + $("audio-preview").hidden = true; + $("btn-download").style.display = "none"; + $("rec-info").textContent = "无待确认:无需录音;有新任务时会自动填入 confirmation_id"; + $("rec-info").className = "muted small"; } else { box.hidden = false; box.innerHTML = `HTTP ${res.status}`; } } + /** 对外入口:重叠调用自动排队;返回的 Promise 在整段链完成后 settle(便于 §4.5 await) */ + function fetchPendingOnce() { + const run = _pendingFetchChain.then( + () => runFetchPendingOnce(), + () => runFetchPendingOnce(), + ); + _pendingFetchChain = run.catch((e) => { + console.warn("[demo-client] pending fetch", e); + }); + return run; + } + $("btn-pending").onclick = fetchPendingOnce; function applyAutoPoll() { if (pollTimer) { clearInterval(pollTimer); pollTimer = null; } @@ -1110,6 +1173,7 @@ $("audio-preview").hidden = true; $("btn-download").style.display = "none"; lastSpokenConfirmationId = null; + _pendingSyncSeq++; $("rec-info").textContent = "已提交,正在拉取下一条待确认…"; $("rec-info").className = "ok small"; await fetchPendingOnce(); diff --git a/tests/conftest.py b/tests/conftest.py index b825ab3..b302517 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,9 @@ -"""Shared test fixtures (SQLite memory DB, AsyncSessionLocal monkeypatch).""" +"""Shared test fixtures (SQLite memory DB session factory).""" from __future__ import annotations -import asyncio -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator -import pytest import pytest_asyncio from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -15,7 +13,7 @@ from app.db.base import Base @pytest_asyncio.fixture async def sqlite_session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]: - """In-memory SQLite + create_all; yields async_sessionmaker.""" + """In-memory SQLite + create_all; yields async_sessionmaker suitable for injection.""" engine = create_async_engine("sqlite+aiosqlite:///:memory:") async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -28,47 +26,3 @@ async def sqlite_session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSes ) yield factory await engine.dispose() - - -@pytest.fixture -def patched_async_session_local( - monkeypatch: pytest.MonkeyPatch, -) -> Generator[async_sessionmaker[AsyncSession], None, None]: - """ - Replace AsyncSessionLocal in modules that open DB sessions, for sync tests - (e.g. TestClient) that use asyncio.run internally. - """ - engine = create_async_engine("sqlite+aiosqlite:///:memory:") - - async def _init() -> None: - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - asyncio.run(_init()) - factory = async_sessionmaker( - engine, - class_=AsyncSession, - expire_on_commit=False, - autoflush=False, - autobegin=False, - ) - - monkeypatch.setattr( - "app.services.video.session_manager.AsyncSessionLocal", - factory, - ) - monkeypatch.setattr( - "app.services.surgery_pipeline.AsyncSessionLocal", - factory, - ) - monkeypatch.setattr( - "app.services.voice_resolution.AsyncSessionLocal", - factory, - ) - - yield factory - - async def _dispose() -> None: - await engine.dispose() - - asyncio.run(_dispose()) diff --git a/tests/faces/图片_20260423162025_32_52.png b/tests/faces/图片_20260423162025_32_52.png new file mode 100644 index 0000000..30c3325 Binary files /dev/null and b/tests/faces/图片_20260423162025_32_52.png differ diff --git a/tests/faces/图片_20260423162213_11_62.png b/tests/faces/图片_20260423162213_11_62.png new file mode 100644 index 0000000..7ce3ed9 Binary files /dev/null and b/tests/faces/图片_20260423162213_11_62.png differ diff --git a/tests/faces/图片_20260423172228_12_62.png b/tests/faces/图片_20260423172228_12_62.png new file mode 100644 index 0000000..6768213 Binary files /dev/null and b/tests/faces/图片_20260423172228_12_62.png differ diff --git a/tests/faces/图片_20260423184622_13_62.png b/tests/faces/图片_20260423184622_13_62.png new file mode 100644 index 0000000..acda2ba Binary files /dev/null and b/tests/faces/图片_20260423184622_13_62.png differ diff --git a/tests/test_api_contract.py b/tests/test_api_contract.py index 00388b8..52857ba 100644 --- a/tests/test_api_contract.py +++ b/tests/test_api_contract.py @@ -148,6 +148,16 @@ def test_get_result_503_not_ready(api_app: FastAPI) -> None: assert r.json()["detail"]["code"] == "RESULT_NOT_READY" +def test_get_result_503_empty_details(api_app: FastAPI) -> None: + pipeline = MagicMock() + pipeline.get_consumption_details_for_client = AsyncMock(return_value=[]) + api_app.dependency_overrides[get_surgery_pipeline] = lambda: pipeline + client = TestClient(api_app) + r = client.get("/client/surgeries/123456/result") + assert r.status_code == 503 + assert r.json()["detail"]["code"] == "RESULT_NOT_READY" + + def test_pending_confirmation_200_and_404(api_app: FastAPI) -> None: ts = datetime(2026, 4, 21, 12, 0, tzinfo=timezone.utc) payload = SurgeryPendingConfirmationResponse( diff --git a/tests/test_app_integration.py b/tests/test_app_integration.py new file mode 100644 index 0000000..9d29d5a --- /dev/null +++ b/tests/test_app_integration.py @@ -0,0 +1,305 @@ +"""集成测试:通过真实的 ``create_app()`` + ``TestClient`` 走通 start/pending/resolve/end/result 全链路。 + +与 ``tests/test_api_contract.py`` 不同,这里不用 ``dependency_overrides`` 替换整个 +``SurgeryPipeline``;而是通过 ``app.state.container`` 注入一个「stubbed session manager」, +其余 pipeline/voice/repository 组件都保持真实实现,并使用 in-memory SQLite 作为会话工厂。 + +这样可以覆盖: +1. `create_app()` 的 CORS / demo_orchestrator 路径挂载、lifespan 启动/关闭流程。 +2. API → Pipeline → Session Registry → Repository 的真实调用链。 +3. durable fallback 目录被 ``ArchivePersister`` 写入/清理的真实路径。 + +所有会与外界交互的边界(RTSP/海康/MinIO/百度)通过容器中的 stub 对象隔离。 +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator +from datetime import datetime, timezone +from typing import Any + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +import app.db.models # noqa: F401 # register ORM tables on Base.metadata +import main as main_module +from app.db.base import Base +from app.dependencies import AppContainer, build_container +from app.domain.consumption import SurgeryConsumptionStored +from app.services.video.session_registry import ( + PendingConsumableConfirmation, + RunningSurgery, + SurgerySessionState, +) + + +class _StubCameraSessionManager: + """仅实现测试链路必需方法;其余方法委托 ``_registry`` / ``_archive``(由真实组件提供)。 + + 这样 ``SurgeryPipeline`` / ``VoiceConfirmationService`` 读到的接口与真实 CameraSessionManager + 等价,不需要真实 RTSP 或推理流水线。 + """ + + def __init__(self, real: Any) -> None: + self._real = real + self._registry = real._registry + self._archive = real._archive + + async def start_surgery( + self, + surgery_id: str, + camera_ids: list[str], + candidate_consumables: list[str], + ) -> None: + if self._registry.has_active(surgery_id): + from app.surgery_errors import SurgeryPipelineError + + raise SurgeryPipelineError( + "RECORDING_CANNOT_START", + "该手术已在录制中,请勿重复开始。", + ) + state = SurgerySessionState( + candidate_consumables=list(candidate_consumables), + name_to_code={}, + ) + state.ready.set() + run = RunningSurgery( + stop_event=asyncio.Event(), state=state, tasks=[] + ) + await self._registry.register(surgery_id, run) + + async def stop_surgery( + self, surgery_id: str, *, require_active: bool = True + ) -> None: + run = await self._registry.unregister(surgery_id) + if run is None: + if require_active: + from app.surgery_errors import SurgeryPipelineError + + raise SurgeryPipelineError( + "RECORDING_NOT_STOPPED", + "停录未能完成:当前没有该手术的活跃录制会话。", + ) + return + details = list(run.state.details) + await self._archive.persist_or_archive(surgery_id, details) + + def __getattr__(self, name: str) -> Any: + return getattr(self._real, name) + + +class _StubVoiceService: + """屏蔽 MinIO/百度调用;保留 ``synthesize_prompt_to_mp3`` 与 ``resolve_from_wav`` 的最小语义。""" + + def __init__(self, real: Any) -> None: + self._real = real + self._sessions = real._sessions + + def synthesize_prompt_to_mp3(self, prompt_text: str) -> bytes: + return b"MP3-FAKE-" + prompt_text.encode("utf-8", errors="replace") + + async def resolve_from_wav( + self, + *, + surgery_id: str, + confirmation_id: str, + wav_bytes: bytes, + filename: str, + content_type: str | None, + ) -> Any: + from app.services.voice_resolution import VoiceResolveResult + from app.surgery_errors import SurgeryPipelineError + + pending = self._sessions.get_pending_confirmation_by_id( + surgery_id, confirmation_id + ) + if pending is None: + raise SurgeryPipelineError( + "CONFIRMATION_NOT_FOUND", + "未找到该待确认项或已处理。", + ) + label = (pending.options[0][0] if pending.options else None) or ( + pending.model_top1_label + ) + await self._sessions.resolve_pending_confirmation( + surgery_id, + confirmation_id, + chosen_label=label, + rejected=False, + ) + return VoiceResolveResult( + resolved_label=label, + rejected=False, + asr_text="第一个", + audio_object_key=f"stub/{surgery_id}/{confirmation_id}.wav", + message="确认成功(stub)", + ) + + +@pytest_asyncio.fixture +async def sqlite_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + factory = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autobegin=False, + ) + yield factory + await engine.dispose() + + +@pytest.fixture +def integration_client( + monkeypatch: pytest.MonkeyPatch, + sqlite_factory: async_sessionmaker[AsyncSession], + tmp_path, +) -> TestClient: + async def _noop() -> None: + return None + + monkeypatch.setattr(main_module, "check_database", _noop) + monkeypatch.setattr(main_module, "init_db_schema", _noop) + + class _FakeEngine: + async def dispose(self) -> None: + return None + + monkeypatch.setattr(main_module, "engine", _FakeEngine()) + + from app.config import settings as real_settings + + monkeypatch.setattr( + real_settings, + "archive_persist_durable_fallback_dir", + str(tmp_path / "pending_archive"), + ) + monkeypatch.setattr(real_settings, "auto_create_schema", False) + + def _stubbed_build_container(*args, **kwargs) -> AppContainer: + container = build_container(real_settings, session_factory=sqlite_factory) + container.camera_session_manager = _StubCameraSessionManager( + container.camera_session_manager + ) + container.surgery_pipeline._sessions = container.camera_session_manager + container.voice_confirmation_service._sessions = ( + container.camera_session_manager._registry + ) + container.surgery_pipeline._voice = _StubVoiceService( + container.surgery_pipeline + ) + return container + + monkeypatch.setattr(main_module, "build_container", _stubbed_build_container) + + async def _instant_sleep(_d: float) -> None: + return None + + monkeypatch.setattr("app.api.asyncio.sleep", _instant_sleep) + + app = main_module.create_app() + with TestClient(app) as client: + yield client + + +def _enqueue_pending( + client: TestClient, *, surgery_id: str +) -> str: + container: AppContainer = client.app.state.container + run = container.camera_session_manager._registry.get_running(surgery_id) + assert run is not None + cid = "cid-integration-1" + pending = PendingConsumableConfirmation( + id=cid, + status="pending", + options=[("纱布", 0.42)], + prompt_text="请确认:是否为纱布", + created_at=datetime.now(timezone.utc), + model_top1_label="纱布", + model_top1_confidence=0.42, + ) + run.state.pending_fifo.append(cid) + run.state.pending_by_id[cid] = pending + return cid + + +def test_full_flow_start_pending_resolve_end_result( + integration_client: TestClient, +) -> None: + client = integration_client + surgery_id = "100001" + + r = client.post( + "/client/surgeries/start", + json={ + "surgery_id": surgery_id, + "camera_ids": ["cam1"], + "candidate_consumables": ["纱布"], + }, + ) + assert r.status_code == 200, r.text + assert r.json()["status"] == "accepted" + + r2 = client.get(f"/client/surgeries/{surgery_id}/pending-confirmation") + assert r2.status_code == 404, r2.text + + cid = _enqueue_pending(client, surgery_id=surgery_id) + + r3 = client.get(f"/client/surgeries/{surgery_id}/pending-confirmation") + assert r3.status_code == 200, r3.text + body3 = r3.json() + assert body3["confirmation_id"] == cid + import base64 + + decoded = base64.b64decode(body3["prompt_audio_mp3_base64"].encode("ascii")) + assert decoded.startswith(b"MP3-FAKE-") + + r4 = client.post( + f"/client/surgeries/{surgery_id}/pending-confirmation/{cid}/resolve", + files={"audio": ("voice.wav", b"RIFFxxxx", "audio/wav")}, + ) + assert r4.status_code == 200, r4.text + body4 = r4.json() + assert body4["resolved_label"] == "纱布" + assert body4["rejected"] is False + + r5 = client.get(f"/client/surgeries/{surgery_id}/result") + assert r5.status_code == 200, r5.text + body5 = r5.json() + assert body5["surgery_id"] == surgery_id + assert len(body5["details"]) == 1 + row = body5["details"][0] + assert row["item_name"] == "纱布" + assert row["qty"] == 1 + assert row["doctor_id"] == "voice" + + r6 = client.post( + "/client/surgeries/end", json={"surgery_id": surgery_id} + ) + assert r6.status_code == 200, r6.text + + r7 = client.get(f"/client/surgeries/{surgery_id}/result") + assert r7.status_code == 200, r7.text + body7 = r7.json() + assert len(body7["details"]) == 1 + assert body7["details"][0]["item_name"] == "纱布" + + +def test_result_not_ready_before_start(integration_client: TestClient) -> None: + r = integration_client.get("/client/surgeries/999999/result") + assert r.status_code == 503 + assert r.json()["detail"]["code"] == "RESULT_NOT_READY" + + +def test_health_endpoint_ok_via_real_app(integration_client: TestClient) -> None: + r = integration_client.get("/health") + assert r.status_code == 200 + body = r.json() + assert body["status"] == "ok" + assert body["database"] == "connected" diff --git a/tests/test_archive_persister.py b/tests/test_archive_persister.py new file mode 100644 index 0000000..8b1adca --- /dev/null +++ b/tests/test_archive_persister.py @@ -0,0 +1,101 @@ +"""ArchivePersister:指数退避、重试上限与 durable fallback 恢复。""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.config import Settings +from app.domain.consumption import SurgeryConsumptionStored +from app.repositories.surgery_results import SurgeryResultRepository +from app.services.video.archive_persister import ArchivePersister + + +class _AlwaysFailRepo(SurgeryResultRepository): + def __init__(self) -> None: + super().__init__() + self.calls = 0 + + async def save_final_result(self, session: AsyncSession, **kwargs: object) -> None: + self.calls += 1 + raise RuntimeError("db down") + + +def _detail(item_id: str = "纱布") -> SurgeryConsumptionStored: + return SurgeryConsumptionStored( + item_id=item_id, + item_name=item_id, + qty=1, + doctor_id="vision", + timestamp=datetime(2026, 4, 23, 12, 0, tzinfo=timezone.utc), + source="vision", + ) + + +@pytest.mark.asyncio +async def test_persist_or_archive_writes_durable_fallback( + tmp_path, + sqlite_session_factory: async_sessionmaker[AsyncSession], +) -> None: + fallback_dir = tmp_path / "pending_archive" + settings = Settings(archive_persist_durable_fallback_dir=str(fallback_dir)) + repo = _AlwaysFailRepo() + persister = ArchivePersister( + settings=settings, + repository=repo, + session_factory=sqlite_session_factory, + ) + ok = await persister.persist_or_archive("abc123", [_detail("纱布")]) + assert ok is False + path = fallback_dir / "abc123.json" + assert path.exists() + payload = json.loads(path.read_text(encoding="utf-8")) + assert payload["surgery_id"] == "abc123" + assert payload["details"][0]["item_id"] == "纱布" + assert persister.archived_details("abc123") is not None + + +@pytest.mark.asyncio +async def test_recover_from_durable_fallback_reloads_pending_archive( + tmp_path, + sqlite_session_factory: async_sessionmaker[AsyncSession], +) -> None: + fallback_dir = tmp_path / "pending_archive" + fallback_dir.mkdir() + payload = { + "surgery_id": "recov01", + "saved_at": "2026-04-23T08:00:00+00:00", + "details": [ + { + "item_id": "缝线", + "item_name": "缝线", + "qty": 1, + "doctor_id": "vision", + "timestamp": "2026-04-23T08:00:00+00:00", + "source": "vision", + } + ], + } + (fallback_dir / "recov01.json").write_text( + json.dumps(payload, ensure_ascii=False), encoding="utf-8" + ) + settings = Settings(archive_persist_durable_fallback_dir=str(fallback_dir)) + persister = ArchivePersister( + settings=settings, + repository=SurgeryResultRepository(), + session_factory=sqlite_session_factory, + ) + loaded = await persister.recover_from_durable_fallback() + assert loaded == 1 + details = persister.archived_details("recov01") + assert details is not None + assert details[0].item_id == "缝线" + + # 下一次 retry 应成功落库并清理内存 + durable 文件。 + ok = await persister.try_persist_archive("recov01") + assert ok is True + assert persister.archived_details("recov01") is None + assert not (fallback_dir / "recov01.json").exists() diff --git a/tests/test_archive_restart_recovery.py b/tests/test_archive_restart_recovery.py new file mode 100644 index 0000000..026d7d6 --- /dev/null +++ b/tests/test_archive_restart_recovery.py @@ -0,0 +1,121 @@ +"""进程重启后的归档恢复集成测试。 + +场景:某次手术结束后写库失败 → ArchivePersister 将明细写入 durable fallback 目录。 +之后 API 进程重启(相当于重新 ``create_app()``)时,``AppContainer.start()`` 会调用 +``camera_session_manager.start_archive_retry_loop()`` → ``recover_from_durable_fallback()``, +把磁盘上的待落库归档读回内存;随后走真实 DB 写入路径将其成功持久化。 +""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import AsyncGenerator +from datetime import datetime, timezone + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +import app.db.models # noqa: F401 register ORM tables +import main as main_module +from app.db.base import Base +from app.dependencies import AppContainer, build_container +from app.domain.consumption import SurgeryConsumptionStored +from app.services.video.archive_persister import _serialize_details + + +@pytest_asyncio.fixture +async def sqlite_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + factory = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autobegin=False, + ) + yield factory + await engine.dispose() + + +def _seed_durable_fallback(directory, surgery_id: str) -> None: + directory.mkdir(parents=True, exist_ok=True) + details = [ + SurgeryConsumptionStored( + item_id="item-1", + item_name="纱布", + qty=2, + doctor_id="voice", + timestamp=datetime(2026, 4, 23, 12, 0, tzinfo=timezone.utc), + source="voice", + ), + ] + payload = { + "surgery_id": surgery_id, + "saved_at": datetime.now(timezone.utc).isoformat(), + "details": _serialize_details(details), + } + (directory / f"{surgery_id}.json").write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + +def test_durable_fallback_recovers_on_startup_and_persists( + monkeypatch: pytest.MonkeyPatch, + sqlite_factory: async_sessionmaker[AsyncSession], + tmp_path, +) -> None: + durable_dir = tmp_path / "pending_archive" + surgery_id = "200001" + _seed_durable_fallback(durable_dir, surgery_id) + assert (durable_dir / f"{surgery_id}.json").exists() + + async def _noop() -> None: + return None + + monkeypatch.setattr(main_module, "check_database", _noop) + monkeypatch.setattr(main_module, "init_db_schema", _noop) + + class _FakeEngine: + async def dispose(self) -> None: + return None + + monkeypatch.setattr(main_module, "engine", _FakeEngine()) + + from app.config import settings as real_settings + + monkeypatch.setattr( + real_settings, "archive_persist_durable_fallback_dir", str(durable_dir) + ) + monkeypatch.setattr(real_settings, "auto_create_schema", False) + monkeypatch.setattr(real_settings, "archive_persist_retry_interval_seconds", 5.0) + + def _build(*_a, **_kw) -> AppContainer: + return build_container(real_settings, session_factory=sqlite_factory) + + monkeypatch.setattr(main_module, "build_container", _build) + + app = main_module.create_app() + with TestClient(app) as client: + container: AppContainer = client.app.state.container + archive = container.camera_session_manager._archive + assert archive.archived_details(surgery_id) is not None + + ok = asyncio.get_event_loop().run_until_complete( + archive.try_persist_archive(surgery_id) + ) + assert ok, "Expected immediate retry to persist against sqlite" + assert archive.archived_details(surgery_id) is None + assert not (durable_dir / f"{surgery_id}.json").exists() + + r = client.get(f"/client/surgeries/{surgery_id}/result") + assert r.status_code == 200, r.text + body = r.json() + assert len(body["details"]) == 1 + assert body["details"][0]["item_name"] == "纱布" + assert body["details"][0]["qty"] == 2 diff --git a/tests/test_consumption_tsv_log.py b/tests/test_consumption_tsv_log.py index 25b0988..2f73c09 100644 --- a/tests/test_consumption_tsv_log.py +++ b/tests/test_consumption_tsv_log.py @@ -47,7 +47,7 @@ def test_build_tsv_line_matches_sample_shape(monkeypatch: pytest.MonkeyPatch) -> wall_end_epoch=w0 + 45.0, ) parts = line.rstrip("\n").split("\t") - assert len(parts) == 5 + assert len(parts) == 9 assert parts[0] == "2237844" assert parts[1] == "一次性医用灭菌棉签" assert parts[2] == "1" @@ -58,6 +58,10 @@ def test_build_tsv_line_matches_sample_shape(monkeypatch: pytest.MonkeyPatch) -> + _RANGE_SEP + "2024-01-01T00:00:45.000+00:00" ) + assert parts[5] == "cls2" + assert parts[6] == "0.0003" + assert parts[7] == "cls3" + assert parts[8] == "0.0002" def test_resolve_consumption_item_id_uses_normalized_catalog_key() -> None: @@ -67,7 +71,17 @@ def test_resolve_consumption_item_id_uses_normalized_catalog_key() -> None: def test_header_columns() -> None: cols = HEADER.strip().split("\t") - assert cols == ["item_id", "item_name", "qty", "doctor_id", "timestamp"] + assert cols == [ + "item_id", + "item_name", + "qty", + "doctor_id", + "timestamp", + "top2_name", + "top2_conf", + "top3_name", + "top3_conf", + ] def test_per_surgery_file_init_and_append( @@ -137,8 +151,10 @@ def test_build_consumption_markdown_top123_columns(monkeypatch: pytest.MonkeyPat wall_end_epoch=w0 + 45.0, ) assert "| item_id |" in md and "| item_name |" in md and "| qty |" in md + assert "| top2 |" in md and "| top3 |" in md assert "2237844" in md assert "一次性医用灭菌棉签" in md + assert "cls2" in md and "cls3" in md assert "DOCTOR_PLACEHOLDER" in md assert "| 1 |" in md # 终端为可读时间戳,非落盘用 ISO@cam diff --git a/tests/test_effective_candidate_consumables.py b/tests/test_effective_candidate_consumables.py new file mode 100644 index 0000000..324d172 --- /dev/null +++ b/tests/test_effective_candidate_consumables.py @@ -0,0 +1,51 @@ +"""effective_candidate_consumables:空请求时回退到目录或模型类名。""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from openpyxl import Workbook + +from app.config import Settings +from app.services.consumable_vision_algorithm import ConsumableVisionAlgorithmService + + +def test_effective_preserves_non_empty_request() -> None: + svc = ConsumableVisionAlgorithmService(Settings()) + got = svc.effective_candidate_consumables([" 纱布 ", "缝线", "纱布"]) + assert got == ["纱布", "缝线"] + + +def test_effective_empty_uses_model_class_names(monkeypatch: pytest.MonkeyPatch) -> None: + svc = ConsumableVisionAlgorithmService(Settings()) + mock_cls = MagicMock() + mock_cls.names = {0: "ban", 1: "apple"} + monkeypatch.setattr(svc, "_get_cls", lambda: mock_cls) + assert svc.effective_candidate_consumables([]) == ["apple", "ban"] + + +def test_effective_empty_prefers_catalog_xlsx(tmp_path: Path) -> None: + xlsx = tmp_path / "cat.xlsx" + wb = Workbook() + ws = wb.active + ws.append(["产品编码", "商品名称"]) + ws.append(["C1", "商品乙"]) + ws.append(["C2", "商品甲"]) + wb.save(xlsx) + + settings = Settings(consumable_catalog_xlsx_path=str(xlsx)) + svc = ConsumableVisionAlgorithmService(settings) + got = svc.effective_candidate_consumables([]) + assert got == ["商品乙", "商品甲"] + + +def test_effective_whitespace_only_treated_as_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + svc = ConsumableVisionAlgorithmService(Settings()) + mock_cls = MagicMock() + mock_cls.names = {0: "x"} + monkeypatch.setattr(svc, "_get_cls", lambda: mock_cls) + assert svc.effective_candidate_consumables(["", " "]) == ["x"] diff --git a/tests/test_probs_numpy_device.py b/tests/test_probs_numpy_device.py new file mode 100644 index 0000000..699f5c2 --- /dev/null +++ b/tests/test_probs_numpy_device.py @@ -0,0 +1,36 @@ +"""_probs_data_to_numpy1d:CPU / CUDA / MPS 上均能离设备再转 NumPy。""" + +from __future__ import annotations + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from app.services.consumable_vision_algorithm import _probs_data_to_numpy1d + + +def test_probs_numpy_cpu_tensor() -> None: + t = torch.tensor([0.1, 0.3, 0.6], dtype=torch.float32) + arr = _probs_data_to_numpy1d(t) + assert arr.dtype == np.float64 + np.testing.assert_allclose(arr, [0.1, 0.3, 0.6], rtol=1e-5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA 不可用,跳过设备张量用例") +def test_probs_numpy_cuda_tensor() -> None: + t = torch.tensor([0.0, 1.0], dtype=torch.float32, device="cuda") + arr = _probs_data_to_numpy1d(t) + assert arr.dtype == np.float64 + np.testing.assert_allclose(arr, [0.0, 1.0], rtol=1e-5) + + +@pytest.mark.skipif( + not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), + reason="MPS 不可用,跳过设备张量用例", +) +def test_probs_numpy_mps_tensor() -> None: + t = torch.tensor([0.25, 0.75], dtype=torch.float32, device="mps") + arr = _probs_data_to_numpy1d(t) + assert arr.dtype == np.float64 + np.testing.assert_allclose(arr, [0.25, 0.75], rtol=1e-5) diff --git a/tests/test_session_manager_unit.py b/tests/test_session_manager_unit.py index d6312cc..b32b097 100644 --- a/tests/test_session_manager_unit.py +++ b/tests/test_session_manager_unit.py @@ -30,7 +30,7 @@ def test_live_consumption_requires_non_empty_details() -> None: ) st = SurgerySessionState(candidate_consumables=["纱布"]) run = RunningSurgery(stop_event=asyncio.Event(), state=st, tasks=[]) - mgr._active["123456"] = run + mgr._registry._active["123456"] = run st.ready.set() assert mgr.live_consumption_if_active("123456") is None @@ -59,7 +59,7 @@ async def test_resolve_voice_accepts_label_on_surgery_list_not_in_topk_options() model_top1_confidence=0.41, ) st.pending_fifo.append(pid) - mgr._active["123456"] = RunningSurgery( + mgr._registry._active["123456"] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) @@ -95,7 +95,7 @@ async def test_resolve_pending_appends_voice_detail() -> None: ) st.pending_fifo.append(pid) run = RunningSurgery(stop_event=asyncio.Event(), state=st, tasks=[]) - mgr._active["123456"] = run + mgr._registry._active["123456"] = run await mgr.resolve_pending_confirmation( "123456", pid, chosen_label="纱布", rejected=False @@ -129,7 +129,7 @@ async def test_resolve_reject_closes_without_detail() -> None: model_top1_confidence=0.4, ) st.pending_fifo.append(pid) - mgr._active["123456"] = RunningSurgery( + mgr._registry._active["123456"] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) @@ -171,14 +171,10 @@ async def test_archive_retry_loop_starts() -> None: result_repository=None, ) await mgr.start_archive_retry_loop() - assert mgr._retry_task is not None - mgr._retry_stop.set() - mgr._retry_task.cancel() - try: - await mgr._retry_task - except asyncio.CancelledError: - pass - mgr._retry_task = None + persister = mgr._archive + assert persister._retry_task is not None + await mgr.shutdown() + assert persister._retry_task is None @pytest.mark.asyncio @@ -242,9 +238,12 @@ async def test_handle_high_conf_top1_not_in_candidates_enqueues_pending() -> Non ], ) await mgr._handle_classification_result(state=state, cls_res=res) - assert state.details == [] + assert len(state.details) == 1 + assert state.details[0].item_name == "待确认" + assert state.details[0].source == "pending_confirmation" assert len(state.pending_fifo) == 1 pid = state.pending_fifo[0] + assert state.details[0].pending_confirmation_id == pid assert "缝线" in state.pending_by_id[pid].prompt_text @@ -270,6 +269,8 @@ async def test_handle_mid_confidence_enqueues_pending() -> None: ) await mgr._handle_classification_result(state=state, cls_res=res) assert len(state.pending_fifo) == 1 + assert len(state.details) == 1 + assert state.details[0].item_name == "待确认" @pytest.mark.asyncio @@ -360,7 +361,7 @@ async def test_resolve_invalid_chosen_label() -> None: model_top1_confidence=0.4, ) st.pending_fifo.append(pid) - mgr._active["123456"] = RunningSurgery( + mgr._registry._active["123456"] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) with pytest.raises(SurgeryPipelineError) as excinfo: @@ -407,7 +408,7 @@ async def test_resolve_second_time_not_found() -> None: model_top1_confidence=0.4, ) st.pending_fifo.append(pid) - mgr._active["123456"] = RunningSurgery( + mgr._registry._active["123456"] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) await mgr.resolve_pending_confirmation( @@ -442,7 +443,7 @@ async def test_resolve_already_resolved_status() -> None: ) st.pending_by_id[pid] = pending st.pending_fifo.append(pid) - mgr._active["123456"] = RunningSurgery( + mgr._registry._active["123456"] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) pending.status = "confirmed" diff --git a/tests/test_session_rank.py b/tests/test_session_rank.py index 3cc103e..65efda1 100644 --- a/tests/test_session_rank.py +++ b/tests/test_session_rank.py @@ -1,5 +1,7 @@ from app.services.consumable_vision_algorithm import PredictionCandidate -from app.services.video.session_manager import _rank_topk_for_candidates +from app.services.video.classification_handler import ( + rank_topk_for_candidates as _rank_topk_for_candidates, +) def test_rank_respects_candidate_order() -> None: diff --git a/tests/test_stream_worker_redaction.py b/tests/test_stream_worker_redaction.py new file mode 100644 index 0000000..233109e --- /dev/null +++ b/tests/test_stream_worker_redaction.py @@ -0,0 +1,22 @@ +"""RTSP URL 日志脱敏。""" + +from __future__ import annotations + +from app.services.video.stream_worker import redact_rtsp_url + + +def test_redact_rtsp_url_with_credentials() -> None: + assert ( + redact_rtsp_url("rtsp://admin:secret@10.0.0.1:554/Streaming/Channels/101") + == "rtsp://***@10.0.0.1:554/Streaming/Channels/101" + ) + + +def test_redact_rtsp_url_without_credentials_unchanged() -> None: + url = "rtsp://10.0.0.1:554/Streaming/Channels/101" + assert redact_rtsp_url(url) == url + + +def test_redact_rtsp_url_empty() -> None: + assert redact_rtsp_url(None) == "" + assert redact_rtsp_url("") == "" diff --git a/tests/test_surgery_pipeline_persistence.py b/tests/test_surgery_pipeline_persistence.py index c70dc30..41ec575 100644 --- a/tests/test_surgery_pipeline_persistence.py +++ b/tests/test_surgery_pipeline_persistence.py @@ -10,11 +10,10 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from app.config import Settings +from app.domain.consumption import SurgeryConsumptionStored from app.repositories.surgery_results import SurgeryResultRepository -from app.schemas import SurgeryConsumptionStored from app.services.surgery_pipeline import SurgeryPipeline from app.services.video.session_manager import ( - ArchivedSurgery, CameraSessionManager, RunningSurgery, SurgerySessionState, @@ -22,26 +21,17 @@ from app.services.video.session_manager import ( from app.services.voice_resolution import VoiceConfirmationService -def _patch_db_sessions( - sqlite_session_factory: async_sessionmaker[AsyncSession], - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr( - "app.services.video.session_manager.AsyncSessionLocal", - sqlite_session_factory, - ) - monkeypatch.setattr( - "app.services.surgery_pipeline.AsyncSessionLocal", - sqlite_session_factory, +def _install_active(mgr: CameraSessionManager, surgery_id: str, state: SurgerySessionState) -> None: + """测试辅助:直接把一条 RunningSurgery 塞进注册表,跳过真实 camera worker。""" + mgr._registry._active[surgery_id] = RunningSurgery( + stop_event=asyncio.Event(), state=state, tasks=[] ) @pytest.mark.asyncio async def test_stop_surgery_persists_final_result( sqlite_session_factory: async_sessionmaker[AsyncSession], - monkeypatch: pytest.MonkeyPatch, ) -> None: - _patch_db_sessions(sqlite_session_factory, monkeypatch) repo = SurgeryResultRepository() settings = Settings() mgr = CameraSessionManager( @@ -49,6 +39,7 @@ async def test_stop_surgery_persists_final_result( vision_algorithm=MagicMock(), hikvision_runtime=None, result_repository=repo, + session_factory=sqlite_session_factory, ) ts = datetime(2026, 4, 21, 12, 0, tzinfo=timezone.utc) st = SurgerySessionState(candidate_consumables=["纱布"]) @@ -63,9 +54,7 @@ async def test_stop_surgery_persists_final_result( ) ) st.ready.set() - mgr._active["123456"] = RunningSurgery( - stop_event=asyncio.Event(), state=st, tasks=[] - ) + _install_active(mgr, "123456", st) await mgr.stop_surgery("123456", require_active=True) @@ -75,7 +64,7 @@ async def test_stop_surgery_persists_final_result( assert loaded is not None assert len(loaded) == 1 assert loaded[0].item_id == "纱布" - assert mgr._archive.get("123456") is None + assert mgr.archived_consumption_fallback("123456") is None class _FlakyResultRepo(SurgeryResultRepository): @@ -93,16 +82,19 @@ class _FlakyResultRepo(SurgeryResultRepository): @pytest.mark.asyncio async def test_stop_surgery_failed_persist_goes_to_archive_then_retry_persists( sqlite_session_factory: async_sessionmaker[AsyncSession], + tmp_path, monkeypatch: pytest.MonkeyPatch, ) -> None: - _patch_db_sessions(sqlite_session_factory, monkeypatch) repo = _FlakyResultRepo() - settings = Settings() + settings = Settings( + archive_persist_durable_fallback_dir=str(tmp_path / "pending_archive") + ) mgr = CameraSessionManager( settings=settings, vision_algorithm=MagicMock(), hikvision_runtime=None, result_repository=repo, + session_factory=sqlite_session_factory, ) ts = datetime(2026, 4, 21, 12, 0, tzinfo=timezone.utc) st = SurgerySessionState(candidate_consumables=[]) @@ -116,18 +108,21 @@ async def test_stop_surgery_failed_persist_goes_to_archive_then_retry_persists( source="vision", ) ) - mgr._active["654321"] = RunningSurgery( - stop_event=asyncio.Event(), state=st, tasks=[] - ) + _install_active(mgr, "654321", st) await mgr.stop_surgery("654321", require_active=True) - assert "654321" in mgr._archive + assert mgr.archived_consumption_fallback("654321") is not None assert repo.calls == 1 - ok = await mgr._try_persist_archive("654321") + # durable fallback 文件应已写入 + durable = tmp_path / "pending_archive" / "654321.json" + assert durable.exists() + + ok = await mgr._archive.try_persist_archive("654321") assert ok is True - assert "654321" not in mgr._archive + assert mgr.archived_consumption_fallback("654321") is None assert repo.calls == 2 + assert not durable.exists(), "成功落库后 durable 文件应被清理" async with sqlite_session_factory() as session: async with session.begin(): @@ -140,9 +135,7 @@ async def test_stop_surgery_failed_persist_goes_to_archive_then_retry_persists( @pytest.mark.asyncio async def test_pipeline_prefers_live_then_db_then_archive( sqlite_session_factory: async_sessionmaker[AsyncSession], - monkeypatch: pytest.MonkeyPatch, ) -> None: - _patch_db_sessions(sqlite_session_factory, monkeypatch) repo = SurgeryResultRepository() settings = Settings() mgr = CameraSessionManager( @@ -150,12 +143,14 @@ async def test_pipeline_prefers_live_then_db_then_archive( vision_algorithm=MagicMock(), hikvision_runtime=None, result_repository=repo, + session_factory=sqlite_session_factory, ) voice = MagicMock(spec=VoiceConfirmationService) pipeline = SurgeryPipeline( mgr, result_repository=repo, voice_confirmation=voice, + session_factory=sqlite_session_factory, ) ts = datetime(2026, 4, 21, 12, 0, tzinfo=timezone.utc) @@ -171,9 +166,7 @@ async def test_pipeline_prefers_live_then_db_then_archive( ) ) st.ready.set() - mgr._active["111111"] = RunningSurgery( - stop_event=asyncio.Event(), state=st, tasks=[] - ) + _install_active(mgr, "111111", st) live = await pipeline.get_consumption_details_for_client("111111") assert live is not None @@ -186,8 +179,9 @@ async def test_pipeline_prefers_live_then_db_then_archive( assert len(from_db) == 1 assert from_db[0].item_id == "纱布" - mgr._archive["333333"] = ArchivedSurgery( - details=[ + await mgr._archive.restore( + "333333", + [ SurgeryConsumptionStored( item_id="归档项", item_name="归档项", @@ -196,7 +190,7 @@ async def test_pipeline_prefers_live_then_db_then_archive( timestamp=ts, source="vision", ) - ] + ], ) only_archive = await pipeline.get_consumption_details_for_client("333333") assert only_archive is not None diff --git a/tests/test_surgery_repository.py b/tests/test_surgery_repository.py index f34dac1..135528b 100644 --- a/tests/test_surgery_repository.py +++ b/tests/test_surgery_repository.py @@ -9,8 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn import app.db.models # noqa: F401 from app.db.base import Base from app.db.models import SurgeryResultDetailRow +from app.domain.consumption import SurgeryConsumptionStored from app.repositories.surgery_results import SurgeryResultRepository -from app.schemas import SurgeryConsumptionStored @pytest.fixture diff --git a/tests/test_voice_pending_store_protocol.py b/tests/test_voice_pending_store_protocol.py new file mode 100644 index 0000000..09256da --- /dev/null +++ b/tests/test_voice_pending_store_protocol.py @@ -0,0 +1,135 @@ +"""解耦测试:用 fake PendingConfirmationStore 验证 VoiceConfirmationService 对端口的依赖。 + +该用例不构造完整的 CameraSessionManager,验证 Phase 5 引入的协议可替换性。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.config import Settings +from app.repositories.voice_audits import VoiceAuditRepository +from app.services.pending_confirmation_port import PendingConfirmationStore +from app.services.video.session_manager import PendingConsumableConfirmation +from app.services.voice_resolution import VoiceConfirmationService +from app.surgery_errors import SurgeryPipelineError + + +@dataclass +class _FakePendingStore: + """与 PendingConfirmationStore 协议等价的可控 fake;不依赖 CameraSessionManager。""" + + pendings: dict[tuple[str, str], PendingConsumableConfirmation] = field( + default_factory=dict + ) + candidates: dict[str, list[str]] = field(default_factory=dict) + resolved: list[tuple[str, str, str | None, bool]] = field(default_factory=list) + traces: list[dict[str, str | None]] = field(default_factory=list) + parse_failures: dict[str, int] = field(default_factory=dict) + + def get_pending_confirmation_by_id( + self, surgery_id: str, confirmation_id: str + ) -> PendingConsumableConfirmation | None: + return self.pendings.get((surgery_id, confirmation_id)) + + def get_surgery_candidate_consumables(self, surgery_id: str) -> list[str]: + return list(self.candidates.get(surgery_id, [])) + + async def record_voice_parse_failure( + self, surgery_id: str, confirmation_id: str + ) -> tuple[int, int]: + key = f"{surgery_id}:{confirmation_id}" + self.parse_failures[key] = self.parse_failures.get(key, 0) + 1 + remaining = max(0, 2 - self.parse_failures[key]) + return self.parse_failures[key], remaining + + async def resolve_pending_confirmation( + self, + surgery_id: str, + confirmation_id: str, + *, + chosen_label: str | None, + rejected: bool, + ) -> None: + self.resolved.append((surgery_id, confirmation_id, chosen_label, rejected)) + + def record_voice_trace( + self, + surgery_id: str, + *, + asr_text: str | None, + error: str | None, + ) -> None: + self.traces.append( + {"surgery_id": surgery_id, "asr_text": asr_text, "error": error} + ) + + +def test_fake_store_satisfies_protocol() -> None: + """_FakePendingStore 必须符合 PendingConfirmationStore 协议(静态/运行时同时验证)。""" + store = _FakePendingStore() + assert isinstance(store, PendingConfirmationStore) + + +@pytest.mark.asyncio +async def test_resolve_from_recognized_text_with_fake_store( + sqlite_session_factory: async_sessionmaker[AsyncSession], +) -> None: + store = _FakePendingStore() + surgery_id = "123456" + confirmation_id = "cid-a" + store.pendings[(surgery_id, confirmation_id)] = PendingConsumableConfirmation( + id=confirmation_id, + status="pending", + options=[("纱布", 0.4), ("缝线", 0.3)], + prompt_text="请确认", + created_at=datetime.now(timezone.utc), + model_top1_label="纱布", + model_top1_confidence=0.4, + ) + store.candidates[surgery_id] = ["纱布", "缝线"] + + svc = VoiceConfirmationService( + settings=Settings(), + sessions=store, + baidu=MagicMock(), + minio=MagicMock(), + audits=VoiceAuditRepository(), + session_factory=sqlite_session_factory, + ) + + result = await svc.resolve_from_recognized_text( + surgery_id=surgery_id, + confirmation_id=confirmation_id, + recognized_text="第一个", + ) + assert result.resolved_label == "纱布" + assert result.rejected is False + assert store.resolved == [(surgery_id, confirmation_id, "纱布", False)] + + +@pytest.mark.asyncio +async def test_resolve_from_recognized_text_not_found_branch( + sqlite_session_factory: async_sessionmaker[AsyncSession], +) -> None: + store = _FakePendingStore() + svc = VoiceConfirmationService( + settings=Settings(), + sessions=store, + baidu=MagicMock(), + minio=MagicMock(), + audits=VoiceAuditRepository(), + session_factory=sqlite_session_factory, + ) + with pytest.raises(SurgeryPipelineError) as excinfo: + await svc.resolve_from_recognized_text( + surgery_id="000000", + confirmation_id="missing", + recognized_text="第一个", + ) + assert excinfo.value.code == "CONFIRMATION_NOT_FOUND" diff --git a/tests/test_voice_resolution_service.py b/tests/test_voice_resolution_service.py index e6bbd65..0f8ce5c 100644 --- a/tests/test_voice_resolution_service.py +++ b/tests/test_voice_resolution_service.py @@ -42,12 +42,7 @@ def _make_service( minio: MagicMock, baidu: MagicMock, sqlite_factory, - monkeypatch: pytest.MonkeyPatch, ) -> VoiceConfirmationService: - monkeypatch.setattr( - "app.services.voice_resolution.AsyncSessionLocal", - sqlite_factory, - ) audits = VoiceAuditRepository() return VoiceConfirmationService( settings=settings, @@ -55,6 +50,7 @@ def _make_service( baidu=baidu, minio=minio, audits=audits, + session_factory=sqlite_factory, ) @@ -86,7 +82,7 @@ def _active_session_with_pending( ) st.pending_fifo.append(confirmation_id) - mgr._active[surgery_id] = RunningSurgery( + mgr._registry._active[surgery_id] = RunningSurgery( stop_event=asyncio.Event(), state=st, tasks=[] ) return mgr, confirmation_id @@ -131,7 +127,6 @@ async def test_resolve_recognized_appends_voice_detail_and_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) wav = _minimal_wav_16k_mono() result = await svc.resolve_from_wav( @@ -145,7 +140,7 @@ async def test_resolve_recognized_appends_voice_detail_and_audit( assert result.resolved_label == "纱布" assert result.asr_text == "第一个" assert result.audio_object_key is not None - st = sessions._active["123456"].state + st = sessions._registry._active["123456"].state assert len(st.details) == 1 assert st.details[0].source == "voice" assert await _audit_count(sqlite_session_factory, surgery_id="123456") == 1 @@ -178,7 +173,6 @@ async def test_resolve_rejected_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) result = await svc.resolve_from_wav( surgery_id="123456", @@ -189,7 +183,7 @@ async def test_resolve_rejected_audit( ) assert result.rejected is True assert result.resolved_label is None - assert len(sessions._active["123456"].state.details) == 0 + assert len(sessions._registry._active["123456"].state.details) == 0 async with sqlite_session_factory() as session: async with session.begin(): res = await session.execute(select(VoiceConfirmationAudit)) @@ -229,7 +223,6 @@ async def test_resolve_recognizes_label_not_in_topk_but_in_surgery_candidates( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) result = await svc.resolve_from_wav( surgery_id="123456", @@ -240,7 +233,7 @@ async def test_resolve_recognizes_label_not_in_topk_but_in_surgery_candidates( ) assert result.rejected is False assert result.resolved_label == "止血钳" - st = sessions._active["123456"].state + st = sessions._registry._active["123456"].state assert len(st.details) == 1 assert st.details[0].item_name == "止血钳" assert st.details[0].source == "voice" @@ -263,7 +256,6 @@ async def test_audio_too_large_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( @@ -298,7 +290,6 @@ async def test_minio_not_configured_no_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( @@ -331,7 +322,6 @@ async def test_upload_failed_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( @@ -371,7 +361,6 @@ async def test_asr_failed_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( @@ -412,7 +401,6 @@ async def test_parse_failed_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( @@ -451,7 +439,6 @@ async def test_invalid_wav_decode_audit( minio=minio, baidu=baidu, sqlite_factory=sqlite_session_factory, - monkeypatch=monkeypatch, ) with pytest.raises(SurgeryPipelineError) as ei: await svc.resolve_from_wav( diff --git a/uv.lock b/uv.lock index 73c7dc2..f441e58 100644 --- a/uv.lock +++ b/uv.lock @@ -19,6 +19,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -636,6 +650,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" }, ] +[[package]] +name = "mako" +version = "1.3.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -866,6 +892,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "aiosqlite" }, + { name = "alembic" }, { name = "httpx" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -893,6 +920,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "alembic", specifier = ">=1.14.0" }, { name = "httpx", specifier = ">=0.28.0" }, { name = "pytest", specifier = ">=8.3.0" }, { name = "pytest-asyncio", specifier = ">=0.25.0" },