diff --git a/.gitignore b/.gitignore index 1abbc7d..4fc92f5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,31 @@ -# ---------- 输入(视频、Excel/表、转码产物,本地放置)---------- +# ---------- 本地输入(视频、Excel、HEVC 转 H.264,勿提交)---------- /input/* !/input/.gitkeep +/input/remuxed/ +/input/**/*.mp4 +/input/**/*.xlsx +/input/**/*.xls -# ---------- 推理结果(保留空目录占位)---------- +# ---------- 推理 / 推流产物(TSV、日志、ROI、可视化视频)---------- /output/* !/output/.gitkeep +/output/**/*.mp4 +/output/**/*.log +/output/**/*.txt +/output/**/*.tsv +/output/**/basket_roi*.json +/output/smoke/ -# ---------- 模型权重(setup.sh 检查,需本地放置)---------- +# ---------- 模型权重与 MediaPipe 任务文件(setup.sh 本地放置)---------- /weights/* !/weights/.gitkeep +/weights/**/*.pt +/weights/**/*.task doctor_identity_package/*.pth doctor_identity_package/.mediapipe_models/ -# 其他数据目录(若从 vendor 拷贝 Excel/权重) +# 其他数据目录(若从 vendor 拷贝) /data/ # ---------- 运行期临时 / 缓存 ---------- @@ -26,6 +38,7 @@ __pycache__/ .venv/ venv/ .env +.env.* *.egg-info/ dist/ build/ @@ -34,6 +47,17 @@ build/ .ruff_cache/ htmlcov/ .coverage +*.cover + +# ---------- 临时 / 调试输出(误写在项目根或脚本旁)---------- +*.mp4 +*.avi +*.mkv +*.mov +*.log +*.tmp +*.temp +*.bak # ---------- Jupyter / 编辑器临时 ---------- .ipynb_checkpoints/ @@ -44,3 +68,4 @@ htmlcov/ .DS_Store .idea/ .vscode/ +.cursor/ diff --git a/README.md b/README.md index 31b73f9..cd4cb59 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ 段内流程:**手检(≥2 手 union)→ 好坏帧门控 → 耗材分类**;**推流**在每个接触段内与耗材**并行**医生识别;**离线**在全片结束后追加一行医生信息。 -与 `configs/default_config.yaml` 当前参数一致(`imgsz_det: 1920`、`contact+1~+6` 等)。 +与 `configs/default_config.yaml` 当前参数一致(`imgsz_det: 1920`、`contact+3~+10` 等)。 ## 环境要求 @@ -75,8 +75,8 @@ bash scripts/remux_hevc.sh /path/to/source.mp4 | 段 | 关键参数 | |----|----------| | `phase2` | `imgsz_det: 1920`,`pad_bottom_ratio: 0.5`,`det_conf: 0.6` | -| `classification` | 好帧 0.8,耗材 0.8,重试 0.6 / 0.5 | -| `basket` | `iou_on: 0.03`,`confirm: 0.1`,`cooldown: 3`,窗口 contact+1~+6 | +| `classification` | 好帧 0.9,耗材 0.8,重试 0.7 / 0.5 | +| 
`basket` | `iou_on: 0.03`,`confirm: 0.1`,`cooldown: 10`,窗口 contact+3~+10 | | `stream` | 段窗口与 basket 一致;`infer_source: file` | | `io` | `use_whitelist: false`(全 41 类) | | `doctor_identity` | `enabled` / `stream_enabled`;推流用 `segment_sample_fps` 在段 `[start,end]` 内采样 | @@ -90,9 +90,51 @@ bash scripts/remux_hevc.sh /path/to/source.mp4 ## 输出格式 -- **推流**:15 列 TSV(12 列耗材 + `doctor_id` / `doctor_name` / `doctor_conf`),无末尾汇总行 +- **推流**:12 列耗材 TSV;各段内在 **居中 3s** 窗跑医生 ReID 后 **投票取 Top1**,末尾一行 `医生信息:...`(`doctor_identity.segment_window_sec`)。末段时长不足 7s 判为误触丢弃 - **离线**:12 列 TSV + 末尾一行 `医生信息:...`(全片中间窗口识别) +## 可视化演示(标注 MP4) + +跑完 `main_basket.py` 后,可用独立脚本把 **视频 + TSV** 合成带框演示片(段内复现手检框,Top3/医生文字直接读 TSV,不重跑耗材模型): + +```bash +conda activate yolo +cd /path/to/6.3 + +python scripts/visualize_pipeline.py \ + --video /path/to/your.mp4 \ + --tsv output/result_offline.txt \ + --basket-roi output/basket_roi.json \ + --config configs/default_config.yaml \ + --out output/demo_vis.mp4 +``` + +| 叠加层 | 说明 | +|--------|------| +| 青色虚线框 | 篮子 ROI(需 `--basket-roi`,与 `--save-basket-roi` 配套) | +| 绿色框 | 段内手部检测(`hand_detect.pt`) | +| 黄色粗框 | 双手 union ROI(与 Phase2 一致) | +| 顶部信息条 | TSV 该段时间段的 rank、Top3 或失败原因 | +| 片头 | 视频/TSV 路径 + 离线 `医生信息:` 汇总 | + +**性能建议**(长片/4K):`--preview-width 1280 --det-stride 5`。依赖系统已安装 `ffmpeg`。 + +**中文显示**:叠加文字使用 Pillow + 系统 CJK 字体(默认 `NotoSansCJK-Regular.ttc`)。若出现方框/乱码,请安装 `fonts-noto-cjk`,或通过 `--font /path/to/font.ttc` / 环境变量 `VIS_CJK_FONT` 指定字体。 + +**篮筐附近手框与 ROI**:提供 `--basket-roi` 时,默认只绘制靠近篮子的手(篮子框外扩 20% 后 IoU > `contact_iou_on`),**黄色 ROI** 由其中与篮子 IoU 最高的两只手合并。背景手不再绘制。关闭过滤用 `--no-hand-basket-filter`;贴边漏检可试 `--basket-expand-frac 0.3` 或略降 `--hand-basket-min-iou 0.02`。 + +**本地 smoke**(无真实手术视频时): + +```bash +conda run -n yolo python scripts/make_smoke_fixture.py +conda run -n yolo python scripts/visualize_pipeline.py \ + --video output/smoke/smoke.mp4 \ + --tsv output/smoke/smoke_result.txt \ + --basket-roi output/smoke/smoke_basket_roi.json \ + --out output/smoke/demo_vis.mp4 \ + 
def _detect_hands_on_frame(
    det: Any,
    fr: np.ndarray,
    det_conf: float,
    imgsz_det: int,
    predict_kw: dict[str, Any] | None,
) -> list[list[float]]:
    """Return hand boxes for one frame, preferring the ``hand_detector`` backend.

    The optional ``hand_detector.detect_hands_xyxy`` helper is tried first.
    If importing it raises ``ImportError`` -- or an ``ImportError`` escapes
    from the helper call itself, since the call sits inside the same ``try``
    -- detection falls back to calling ``det.predict`` directly and filtering
    the result through ``collect_hand_boxes``.

    Args:
        det: Detector object; the fallback path requires a ``predict`` method.
        fr: Frame to run detection on.
        det_conf: Confidence threshold forwarded to the detector.
        imgsz_det: Inference image size forwarded to the detector.
        predict_kw: Extra keyword arguments for ``predict``; ``None`` or an
            empty mapping means no extras.

    Returns:
        The detected hand boxes, or an empty list when the fallback
        prediction has no boxes.
    """
    try:
        from hand_detector import detect_hands_xyxy as backend_detect

        return backend_detect(
            det,
            fr,
            det_conf=det_conf,
            imgsz_det=imgsz_det,
            predict_kw=predict_kw,
        )
    except ImportError:
        # Fallback: plain detector prediction on the frame.
        extra_kw = dict(predict_kw) if predict_kw else {}
        first_result = det.predict(
            fr, conf=det_conf, imgsz=imgsz_det, verbose=False, **extra_kw
        )[0]
        if not first_result.boxes:
            return []
        return collect_hand_boxes(det, first_result.boxes)
if frame_stride > 1 and (frames_in_segment - 1) % frame_stride != 0: return - r0 = det.predict(fr, conf=det_conf, imgsz=imgsz_det, verbose=False, **pred_kw)[0] - hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else [] + hands = _detect_hands_on_frame(det, fr, det_conf, imgsz_det, pred_kw) crop = _crop_two_hands_union(fr, hands, pad_ratio) if crop is None: return @@ -702,8 +728,7 @@ def process_segment_haocai_from_cap( return img_h, img_w = fr.shape[:2] - r0 = det.predict(fr, conf=det_conf, imgsz=imgsz_det, verbose=False, **pred_kw)[0] - hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else [] + hands = _detect_hands_on_frame(det, fr, det_conf, imgsz_det, pred_kw) crop = _crop_two_hands_union(fr, hands, pad_ratio) if crop is None: return diff --git a/configs/default_config.yaml b/configs/default_config.yaml index 4d7e1ec..40e65e4 100644 --- a/configs/default_config.yaml +++ b/configs/default_config.yaml @@ -15,6 +15,16 @@ weights: goodbad: weights/goodbad_frame.pt haocai: weights/haocai_classify.pt +# 手部检测:yolo=hand_detect.pt;mediapipe=Hand Landmarker(段内 ROI / 推流接触判定) +hand: + backend: mediapipe + mediapipe_task: weights/hand_landmarker.task + mediapipe_num_hands: 2 + mediapipe_min_detection_confidence: 0.3 + mediapipe_min_presence_confidence: 0.3 + mediapipe_min_tracking_confidence: 0.3 + mediapipe_bbox_margin: 0.05 + runtime: work_dir: null keep_work_dir: false @@ -41,8 +51,8 @@ phase2: classification: imgsz_cls: 224 - good_top1_conf_threshold: 0.8 - good_top1_retry_threshold: 0.6 + good_top1_conf_threshold: 0.9 + good_top1_retry_threshold: 0.7 haocai_min_conf: 0.8 haocai_min_conf_retry: 0.5 empty_cache_every: 0 @@ -64,6 +74,8 @@ doctor_identity: middle_seconds: 10.0 sample_fps: 3.0 segment_sample_fps: 3.0 + # 推流/段内医生 ReID:仅在每段窗口内居中取最多 3s 采样(耗材段窗口仍为 7s) + segment_window_sec: 3.0 pad_frac: 0.15 # 篮子接触分段(main_basket.py / main_basket_stream.py) @@ -73,9 +85,9 @@ basket: contact_iou_on: 0.03 contact_iou_off: 0.01 confirm_seconds: 0.1 - cooldown_seconds: 3.0 
- segment_start_offset_sec: 1.0 - segment_end_offset_sec: 6.0 + cooldown_seconds: 10.0 + segment_start_offset_sec: 3.0 + segment_end_offset_sec: 10.0 min_segment_sec: 4.0 scan_frame_stride: 1 roi_frame: first @@ -86,16 +98,16 @@ basket: # 推流实时识别(main_basket_stream.py) # 接触判定 / 手检 imgsz / 好坏帧 / 耗材阈值:与离线共用 basket + phase2 + classification -# 段内推理:本地 MP4 回源 4K + phase2.imgsz_det=1920(与离线一致);RTSP/缓存 fallback 时 JPEG 宽≤1920 +# 段内推理:本地 MP4 回源 4K + phase2.imgsz_det(与离线一致);RTSP/缓存 fallback 时 JPEG 宽≤cache_max_width stream: rtsp: null - ring_buffer_sec: 10.0 + ring_buffer_sec: 12.0 cache_max_width: 1920 jpeg_quality: 85 fps: 25.0 - # 段窗口与 basket 一致:[contact+1, contact+6],时长 5s - segment_start_offset_sec: 1.0 - segment_end_offset_sec: 6.0 + # 段窗口与 basket 一致:[contact+3, contact+10],时长 7s + segment_start_offset_sec: 3.0 + segment_end_offset_sec: 10.0 min_segment_sec: 4.0 infer_source: file infer_fallback: cache diff --git a/main_segments_offline.py b/main_segments_offline.py deleted file mode 100644 index c6c44f8..0000000 --- a/main_segments_offline.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -"""按结果 TSV 时间段对离线视频做手检 → 耗材分类(跳过分段与撕膜,无好坏帧门控)。""" -from __future__ import annotations - -import argparse -import os -import sys -from pathlib import Path - -PACK_ROOT = Path(__file__).resolve().parent -sys.path.insert(0, str(PACK_ROOT / "src")) - -from paths import ensure_code_on_path - -ensure_code_on_path(PACK_ROOT) - -from config import load_run_config -from segments_offline_orchestrator import run_segments_offline_pipeline - - -def main() -> int: - os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8") - ap = argparse.ArgumentParser( - description="TSV 时间段 → 离线视频段内耗材识别(无 ActionFormer / 无篮子分段 / 无撕膜)" - ) - ap.add_argument("--video", type=Path, required=True, help="输入 MP4") - ap.add_argument( - "--segments-tsv", - type=Path, - required=True, - help="含 start_sec/end_sec 的结果 TSV(如推流输出)", - ) - ap.add_argument( - "--excel", - type=Path, - required=True, - help="商品表 Excel(C 列白名单 
def main() -> None:
    """Generate the smoke-test fixture under ``OUT``.

    Writes three files used to exercise ``visualize_pipeline`` without real
    footage:

    * ``smoke.mp4`` -- 50 synthetic 640x480 frames at 10 fps, each with a
      grey rectangle and a ``frame N`` caption;
    * ``smoke_result.txt`` -- a 12-column result TSV with one data row plus a
      trailing doctor-summary line;
    * ``smoke_basket_roi.json`` -- a basket ROI with the video path.

    Raises:
        RuntimeError: if OpenCV cannot open the output video for writing.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    video = OUT / "smoke.mp4"
    tsv = OUT / "smoke_result.txt"
    roi = OUT / "smoke_basket_roi.json"

    w, h, fps, n = 640, 480, 10.0, 50
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    wr = cv2.VideoWriter(str(video), fourcc, fps, (w, h))
    # A writer that failed to open silently drops every frame; fail loudly
    # instead of leaving an empty/corrupt fixture behind.
    if not wr.isOpened():
        wr.release()
        raise RuntimeError(f"[smoke] cannot open VideoWriter for {video}")
    try:
        for i in range(n):
            frame = np.zeros((h, w, 3), dtype=np.uint8)
            frame[:, :] = (30 + i % 40, 50, 80)
            cv2.rectangle(frame, (80, 60), (280, 360), (200, 200, 200), 2)
            cv2.putText(
                frame, f"frame {i}", (20, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2,
            )
            wr.write(frame)
    finally:
        wr.release()

    # The data row must carry exactly 12 tab-separated fields to match the
    # header. The top-3 slot is intentionally empty: blank id, blank name,
    # zero confidence -- hence three consecutive tabs after top2_conf.
    tsv.write_text(
        "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\t"
        "product_id_top2\ttop2_name\ttop2_conf\tproduct_id_top3\ttop3_name\ttop3_conf\n"
        "1\t1.000000\t4.500000\tP001\t演示耗材A\t0.850000\tP002\t演示耗材B\t"
        "0.120000\t\t\t0.000000\n"
        "医生信息:演示医生 (id=SMOKE01, conf=0.9900)\n",
        encoding="utf-8",
    )
    roi.write_text(
        json.dumps({"basket_xyxy": [70.0, 50.0, 290.0, 370.0], "video": str(video)}, indent=2)
        + "\n",
        encoding="utf-8",
    )
    print(f"[smoke] {video}")
    print(f"[smoke] {tsv}")
    print(f"[smoke] {roi}")
segment_start_offset_sec=cfg.basket_segment_start_offset_sec, + segment_end_offset_sec=cfg.basket_segment_end_offset_sec, + min_segment_sec=cfg.basket_min_segment_sec, + det_conf=cfg.basket_det_conf, + imgsz_det=cfg.imgsz_det, + device=str(cfg.device), + ) + print(f"total_segments={len(segs)}") + print( + f"offsets: start={cfg.basket_segment_start_offset_sec} " + f"end={cfg.basket_segment_end_offset_sec} cooldown={cfg.basket_cooldown_seconds}" + ) + for c, t0, t1 in segs: + if 215 <= c <= 245: + print(f" contact={c:.3f} window=[{t0:.3f},{t1:.3f}] dur={t1-t0:.3f}") + near = [c for c, _, _ in segs if abs(c - 226.2) < 1.0 or abs(c - 230.76) < 1.0] + print(f"contacts_near_old_11_12={len(near)} (expect 1)") + return 0 if len(near) <= 1 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/vis_text.py b/scripts/vis_text.py new file mode 100644 index 0000000..95dcb3e --- /dev/null +++ b/scripts/vis_text.py @@ -0,0 +1,203 @@ +"""OpenCV 画面上绘制中文/混合文本(Pillow + 系统 CJK 字体)。""" +from __future__ import annotations + +import os +from pathlib import Path + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +# Linux 常见中文字体(按优先级) +_FONT_CANDIDATES: tuple[str, ...] 
class CjkTextRenderer:
    """Draw Unicode (CJK-capable) text onto BGR ndarrays using Pillow.

    A single font file is resolved at construction time; ``FreeTypeFont``
    objects are cached per pixel size so repeated draws do not reload it.
    """

    def __init__(self, font_path: Path | None = None) -> None:
        """Resolve the font file and start with an empty per-size cache.

        Args:
            font_path: Explicit font file, or ``None`` to let
                ``resolve_cjk_font`` pick one.

        Raises:
            FileNotFoundError: propagated from ``resolve_cjk_font``.
        """
        self._font_path = resolve_cjk_font(font_path)
        self._cache: dict[int, ImageFont.FreeTypeFont] = {}
        print(f"[vis] 使用字体: {self._font_path}")

    def font_size_for_frame(self, h: int, w: int, *, kind: str = "hud") -> int:
        """Pick a font size in pixels proportional to the shorter frame side.

        Args:
            h: Frame height in pixels.
            w: Frame width in pixels.
            kind: ``"title"``, ``"label"``, ``"small"``; any other value
                (including the default ``"hud"``) uses the HUD sizing.

        Returns:
            ``max(floor, int(min(h, w) / divisor))`` for the selected kind.
        """
        # kind -> (minimum size, divisor applied to the shorter side)
        floor_px, divisor = {
            "title": (22, 38),
            "label": (16, 52),
            "small": (14, 60),
        }.get(kind, (18, 48))
        return max(floor_px, int(min(h, w) / divisor))

    def _font(self, size_px: int) -> ImageFont.FreeTypeFont:
        """Return the cached font for ``size_px``, loading it on first use."""
        if size_px not in self._cache:
            self._cache[size_px] = ImageFont.truetype(str(self._font_path), size_px)
        return self._cache[size_px]

    def measure(self, text: str, size_px: int) -> tuple[int, int]:
        """Return the (width, height) of the text bounding box; (0, 0) if empty."""
        if not text:
            return 0, 0
        x0, y0, x1, y1 = self._font(size_px).getbbox(text)
        return int(x1 - x0), int(y1 - y0)

    def _blit_texts(
        self,
        img_bgr: np.ndarray,
        items: list[tuple[str, int, int, tuple[int, int, int], tuple[int, int, int] | None, int]],
        *,
        size_px: int,
    ) -> None:
        """Draw several texts with a single BGR -> PIL -> BGR round trip.

        Each item is ``(text, x, y, color_bgr, bg_bgr | None, padding)``.
        ``img_bgr`` is modified in place. Empty texts are skipped.
        """
        if not items:
            return
        rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        pil = Image.fromarray(rgb)
        draw = ImageDraw.Draw(pil)
        font = self._font(size_px)
        for text, x, y, color_bgr, bg_bgr, padding in items:
            if not text:
                continue
            if bg_bgr is not None:
                # getbbox() is relative to the text anchor, and its left/top
                # are usually non-zero (ink starts below the anchor line).
                # Shift the anchor by that offset so the ink occupies exactly
                # the padded interior of the background rectangle instead of
                # sitting low and spilling past its bottom edge.
                x0, y0, x1, y1 = font.getbbox(text)
                tw, th = int(x1 - x0), int(y1 - y0)
                draw.rectangle(
                    (x, y, x + tw + padding * 2, y + th + padding * 2),
                    fill=_bgr_to_rgb(bg_bgr),
                )
                tx, ty = x + padding - int(x0), y + padding - int(y0)
            else:
                # No background box to align with: keep anchor positioning.
                tx, ty = x, y
            draw.text((tx, ty), text, font=font, fill=_bgr_to_rgb(color_bgr))
        img_bgr[:] = cv2.cvtColor(np.asarray(pil), cv2.COLOR_RGB2BGR)

    def draw(
        self,
        img_bgr: np.ndarray,
        text: str,
        x: int,
        y: int,
        *,
        size_px: int,
        color_bgr: tuple[int, int, int] = (255, 255, 255),
        bg_bgr: tuple[int, int, int] | None = None,
        padding: int = 4,
    ) -> tuple[int, int]:
        """Draw ``text`` with its top-left corner at ``(x, y)``.

        Returns:
            The drawn box size ``(w, h)``: the text size, plus ``2 * padding``
            on each axis when a background colour is given; ``(0, 0)`` for
            empty text (nothing is drawn).
        """
        if not text:
            return 0, 0
        tw, th = self.measure(text, size_px)
        self._blit_texts(
            img_bgr,
            [(text, x, y, color_bgr, bg_bgr, padding)],
            size_px=size_px,
        )
        # Same "has background" test as _blit_texts, so the reported size
        # always matches what was actually painted.
        extra = padding * 2 if bg_bgr is not None else 0
        return tw + extra, th + extra

    def draw_lines_top(
        self,
        img_bgr: np.ndarray,
        lines: list[str],
        *,
        size_px: int,
        color_bgr: tuple[int, int, int] = (255, 255, 255),
        bar_alpha: float = 0.55,
        pad: int = 8,
        line_gap: float = 1.35,
    ) -> None:
        """Draw a semi-transparent black bar at the top-left with the given lines.

        The bar is as wide as the widest line plus padding and tall enough for
        all lines at ``size_px * line_gap`` spacing. No-op for an empty list.
        """
        if not lines:
            return
        lh = int(size_px * line_gap)
        max_tw = max((self.measure(ln, size_px)[0] for ln in lines), default=0)
        bar_h = pad * 2 + lh * len(lines)
        # Blend a copy carrying the black bar back onto the frame; pixels
        # outside the bar are identical in both images and stay unchanged.
        overlay = img_bgr.copy()
        cv2.rectangle(overlay, (0, 0), (max_tw + pad * 2, bar_h), (0, 0, 0), -1)
        cv2.addWeighted(overlay, bar_alpha, img_bgr, 1.0 - bar_alpha, 0, img_bgr)
        items = [
            (ln, pad, pad + i * lh, color_bgr, None, 0) for i, ln in enumerate(lines)
        ]
        self._blit_texts(img_bgr, items, size_px=size_px)

    def draw_lines_block(
        self,
        img_bgr: np.ndarray,
        lines: list[str],
        x: int,
        y0: int,
        *,
        size_px: int,
        color_bgr: tuple[int, int, int] = (255, 255, 255),
        bg_bgr: tuple[int, int, int] = (0, 0, 0),
        padding: int = 6,
        line_gap: float = 1.45,
    ) -> None:
        """Draw lines stacked from ``(x, y0)``, each on its own background box.

        Used for the title card. No-op for an empty list.
        """
        if not lines:
            return
        lh = int(size_px * line_gap)
        items = [
            (ln, x, y0 + i * lh, color_bgr, bg_bgr, padding)
            for i, ln in enumerate(lines)
        ]
        self._blit_texts(img_bgr, items, size_px=size_px)

    def draw_label_on_box(
        self,
        img_bgr: np.ndarray,
        x1: int,
        y1: int,
        label: str,
        *,
        size_px: int,
        color_bgr: tuple[int, int, int],
        bg_bgr: tuple[int, int, int],
    ) -> None:
        """Draw ``label`` just above a box's top-left corner ``(x1, y1)``.

        The label is always rendered in black on ``bg_bgr``. ``color_bgr`` is
        accepted for call-site compatibility but is not used for the text.
        The label is clamped so it never starts above the top of the image.
        """
        _, th = self.measure(label, size_px)
        ty = max(0, y1 - th - 8)
        self.draw(
            img_bgr,
            label,
            x1,
            ty,
            size_px=size_px,
            color_bgr=(0, 0, 0),
            bg_bgr=bg_bgr,
            padding=3,
        )
def _scale_xyxy(
    xyxy: list[float], scale_x: float, scale_y: float
) -> tuple[int, int, int, int]:
    """Scale an ``[x1, y1, x2, y2]`` box per axis and round to integer pixels.

    X coordinates are multiplied by ``scale_x`` and Y coordinates by
    ``scale_y``; each product is rounded with the built-in ``round``.
    """
    left, top, right, bottom = xyxy  # exactly four values are required

    def _to_px(value: float, factor: float) -> int:
        return int(round(value * factor))

    return (
        _to_px(left, scale_x),
        _to_px(top, scale_y),
        _to_px(right, scale_x),
        _to_px(bottom, scale_y),
    )
def filter_hands_by_basket(
    hand_confs: list[tuple[list[float], float]],
    basket_xyxy: list[float],
    min_iou: float,
) -> list[tuple[list[float], float]]:
    """Keep only the hands whose IoU with the basket ROI exceeds ``min_iou``.

    Each entry of ``hand_confs`` is ``(box_xyxy, confidence)``. A hand is kept
    when ``bbox_iou_xyxy(box, basket)`` is strictly greater than
    ``min_iou + 1e-12``; input order is preserved.
    """
    roi = [float(coord) for coord in basket_xyxy]
    return [
        (box, score)
        for box, score in hand_confs
        if bbox_iou_xyxy(box, roi) > float(min_iou) + 1e-12
    ]
def union_roi_from_basket_hands(
    near_hands: list[tuple[list[float], float]],
    basket_xyxy: list[float],
    img_w: int,
    img_h: int,
    pad_bottom_ratio: float,
) -> tuple[tuple[int, int, int, int] | None, list[tuple[list[float], float]]]:
    """Build the union ROI from the two near-basket hands with the highest IoU.

    Args:
        near_hands: ``(box_xyxy, confidence)`` entries already filtered to the
            basket neighbourhood.
        basket_xyxy: Basket box used to rank the hands by IoU.
        img_w: Image width passed to ``pad_box_bottom_only``.
        img_h: Image height passed to ``pad_box_bottom_only``.
        pad_bottom_ratio: Padding ratio passed to ``pad_box_bottom_only``.

    Returns:
        ``(roi, near_hands)``. ``roi`` is ``None`` when fewer than two hands
        are available; otherwise it is the padded union of the two best hands.
        ``near_hands`` is returned unchanged in both cases.
    """
    if len(near_hands) < 2:
        return None, near_hands

    target = [float(coord) for coord in basket_xyxy]

    def _overlap(entry: tuple[list[float], float]) -> float:
        return bbox_iou_xyxy(entry[0], target)

    by_overlap = sorted(near_hands, key=_overlap, reverse=True)
    best_box = by_overlap[0][0]
    second_box = by_overlap[1][0]
    merged = union_xyxy(best_box, second_box)
    return pad_box_bottom_only(merged, img_w, img_h, pad_bottom_ratio), near_hands
def open_ffmpeg_writer(
    out_path: Path, width: int, height: int, fps: float
) -> subprocess.Popen[bytes]:
    """Spawn ffmpeg to encode raw BGR frames from stdin into an H.264 MP4.

    The caller writes ``width * height * 3`` bytes per frame (``bgr24``) to
    the returned process's ``stdin``, then closes it and waits for exit.
    The parent directory of ``out_path`` is created if missing.

    Args:
        out_path: Destination video file (overwritten if present, ``-y``).
        width: Width in pixels of the frames that will be piped in.
        height: Height in pixels of the frames that will be piped in.
        fps: Frame rate declared for the raw input stream.

    Returns:
        The running ``ffmpeg`` process with ``stdin`` opened as a pipe.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable is not on ``PATH``
            (raised by ``subprocess.Popen``).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    input_args = [
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-pix_fmt", "bgr24",
        "-s", f"{width}x{height}",
        "-r", f"{fps:.6f}",
        "-i", "-",
    ]

    # libx264 with yuv420p (2x2 chroma subsampling) refuses odd dimensions.
    # Since stderr is discarded below, that failure would only surface as a
    # broken pipe / non-zero exit. Pad by at most one pixel on the right and
    # bottom in that case; even-sized input keeps the original command line.
    filter_args: list[str] = []
    if width % 2 or height % 2:
        filter_args = ["-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2"]

    encode_args = [
        "-an",
        *filter_args,
        "-c:v", "libx264",
        "-preset", "ultrafast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
    ]

    cmd = ["ffmpeg", "-y", *input_args, *encode_args, str(out_path)]
    return subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
print(f"[vis] TSV 无有效数据段: {tsv_path}", file=sys.stderr) + return 1 + print(f"[vis] 已加载 {len(segments)} 段; 医生汇总: {doctor_summary or '(无)'}") + + try: + cjk = CjkTextRenderer( + args.font.resolve() if getattr(args, "font", None) else None + ) + except FileNotFoundError as ex: + print(f"[vis] {ex}", file=sys.stderr) + return 1 + + basket_roi: list[float] | None = None + if args.basket_roi is not None: + basket_roi = load_basket_roi_json(args.basket_roi.resolve()) + + use_basket_near = not args.no_hand_basket_filter + hand_basket_min_iou: float | None = None + basket_expand_frac = float(args.basket_expand_frac) + if basket_roi is not None and use_basket_near: + hand_basket_min_iou = float( + args.hand_basket_min_iou + if args.hand_basket_min_iou is not None + else getattr(cfg, "basket_contact_iou_on", 0.03) + ) + print( + f"[vis] 篮筐附近手检: 外扩篮子 {basket_expand_frac:.0%} 后 IoU > " + f"{hand_basket_min_iou:.4f};绿框与黄 ROI 均仅用附近手" + ) + elif basket_roi is None and use_basket_near: + print( + "[vis] 未提供 --basket-roi,无法按篮子过滤;" + "将绘制全图手检结果", + file=sys.stderr, + ) + elif args.no_hand_basket_filter: + print("[vis] 已关闭篮筐过滤(--no-hand-basket-filter)") + + predict_kw: dict[str, Any] = {"device": cfg.device} + if cfg.half: + predict_kw["half"] = True + + det_model = YOLO(str(cfg.hand_model)) + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + print(f"[vis] 无法打开视频: {video_path}", file=sys.stderr) + return 1 + + fps = float(cap.get(cv2.CAP_PROP_FPS) or 25.0) + if fps <= 0: + fps = 25.0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) + + ret, frame0 = cap.read() + if not ret or frame0 is None: + print("[vis] 无法读取首帧", file=sys.stderr) + cap.release() + return 1 + + frame0, sx0, sy0 = resize_frame(frame0, int(args.preview_width)) + out_h, out_w = frame0.shape[:2] + proc = open_ffmpeg_writer(out_path, out_w, out_h, fps) + + def write_frame(img: np.ndarray) -> None: + if proc.stdin is None: + raise RuntimeError("ffmpeg stdin 不可用") + if img.shape[1] != out_w or 
img.shape[0] != out_h: + img = cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_AREA) + proc.stdin.write(img.tobytes()) + + title_frames = max(1, int(round(float(args.title_sec) * fps))) + video_name = video_path.name + tsv_name = tsv_path.name + + for _ in range(title_frames): + title_img = frame0.copy() + draw_hud( + title_img, + None, + t_sec=0.0, + doctor_summary=doctor_summary, + video_name=video_name, + tsv_name=tsv_name, + title_mode=True, + text=cjk, + ) + write_frame(title_img) + + cap.set(cv2.CAP_PROP_POS_FRAMES, 0) + lw = _line_w(out_h, out_w) + cached_union: tuple[int, int, int, int] | None = None + cached_hand_confs: list[tuple[list[float], float]] = [] + det_calls = 0 + frame_idx = 0 + + while True: + ret, frame = cap.read() + if not ret or frame is None: + break + frame, sx, sy = resize_frame(frame, int(args.preview_width)) + t_sec = frame_idx / fps + active = find_active_segment(segments, t_sec) + vis = frame.copy() + + if basket_roi is not None: + bx1, by1, bx2, by2 = _scale_xyxy(basket_roi, sx, sy) + draw_labeled_box( + vis, bx1, by1, bx2, by2, (255, 200, 0), "篮子", + thickness=lw, dashed=True, text=cjk, + ) + + in_segment = active is not None + if in_segment and (frame_idx % int(args.det_stride) == 0): + basket_for_det: list[float] | None = None + if basket_roi is not None: + basket_for_det = _scale_basket_xyxy(basket_roi, sx, sy) + cached_union, cached_hand_confs = detect_hands_and_union( + det_model, + frame, + det_conf=float(cfg.det_conf), + imgsz_det=int(cfg.imgsz_det), + pad_bottom_ratio=float(cfg.pad_bottom_ratio), + predict_kw=predict_kw, + basket_xyxy=basket_for_det, + hand_basket_min_iou=hand_basket_min_iou, + basket_expand_frac=basket_expand_frac, + use_basket_near_hands=use_basket_near and basket_roi is not None, + ) + det_calls += 1 + + if in_segment: + for hxyxy, conf in cached_hand_confs: + x1, y1, x2, y2 = (int(round(v)) for v in hxyxy[:4]) + draw_labeled_box( + vis, x1, y1, x2, y2, (0, 220, 0), f"手 {conf:.2f}", + 
thickness=lw, + text=cjk, + ) + if cached_union is not None: + ux1, uy1, ux2, uy2 = cached_union + draw_labeled_box( + vis, ux1, uy1, ux2, uy2, (0, 220, 255), "ROI", + thickness=max(lw + 1, 2), + text=cjk, + ) + draw_hud( + vis, + active, + t_sec=t_sec, + doctor_summary=doctor_summary, + video_name=video_name, + tsv_name=tsv_name, + text=cjk, + ) + else: + cached_union = None + cached_hand_confs = [] + if args.draw_outside_segments: + fs = cjk.font_size_for_frame(out_h, out_w, kind="small") + cjk.draw( + vis, + "非识别段", + 10, + out_h - fs - 12, + size_px=fs, + color_bgr=(180, 180, 180), + ) + + write_frame(vis) + frame_idx += 1 + if frame_idx % 500 == 0: + print(f"[vis] 进度 {frame_idx}/{total_frames or '?'} 帧, 手检次数={det_calls}") + + cap.release() + if proc.stdin: + proc.stdin.close() + rc = proc.wait() + if rc != 0: + print(f"[vis] ffmpeg 退出码 {rc}", file=sys.stderr) + return 1 + + print(f"[vis] 完成: {out_path} ({frame_idx} 帧 + {title_frames} 片头, 段内手检 {det_calls} 次)") + return 0 + + +def main() -> int: + os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8") + ap = argparse.ArgumentParser(description="MP4 + TSV → 带框标注演示视频") + ap.add_argument("--video", type=Path, required=True, help="原始 MP4") + ap.add_argument("--tsv", type=Path, required=True, help="main_basket 输出的 TSV/txt") + ap.add_argument("--out", type=Path, required=True, help="输出 MP4") + ap.add_argument( + "--config", + type=Path, + default=PACK_ROOT / "configs" / "default_config.yaml", + ) + ap.add_argument( + "--basket-roi", + type=Path, + default=None, + help="篮子 ROI JSON(main_basket --save-basket-roi)", + ) + ap.add_argument("--det-stride", type=int, default=3, help="段内每 N 帧手检一次") + ap.add_argument("--preview-width", type=int, default=1920, help="输出宽度上限") + ap.add_argument( + "--draw-outside-segments", + action="store_true", + help="非 TSV 时间段角标「非识别段」", + ) + ap.add_argument("--title-sec", type=float, default=3.0, help="片头时长(秒)") + ap.add_argument( + "--font", + type=Path, + default=None, + 
@dataclass
class SegmentVis:
    """Visualization metadata for a single segment.

    Wraps one parsed TSV row together with the optional per-segment doctor
    columns that a result file may carry.
    """

    row: E2eRow
    doctor_id: str = ""
    doctor_name: str = ""
    doctor_conf: str = ""

    @property
    def rank(self) -> int:
        """Rank of the wrapped row."""
        return self.row.rank

    @property
    def start_sec(self) -> float:
        """Segment start time, taken from the wrapped row."""
        return self.row.start_sec

    @property
    def end_sec(self) -> float:
        """Segment end time, taken from the wrapped row."""
        return self.row.end_sec

    def is_failure(self) -> bool:
        """Return True when the top-1 name is a bracketed failure note."""
        # startswith() is already False for an empty string, so a blank
        # top-1 name is never reported as a failure.
        return self.row.n1.strip().startswith("(")

    def doctor_line(self) -> str:
        """Return the per-segment doctor caption, or "" when none is known.

        The doctor id stands in for the display name when the name is empty;
        the confidence part is omitted when the confidence text is blank.
        """
        if not (self.doctor_name or self.doctor_id):
            return ""
        shown_name = self.doctor_name or self.doctor_id
        conf_text = self.doctor_conf.strip()
        conf_part = f", conf={conf_text}" if conf_text else ""
        return f"段医生: {shown_name} (id={self.doctor_id}{conf_part})"
def find_active_segment(segments: list[SegmentVis], t_sec: float) -> SegmentVis | None:
    """Return the segment covering ``t_sec``, or None when there is none.

    A segment covers ``t_sec`` when ``start_sec <= t_sec < end_sec``. When
    several segments overlap, the one with the smallest rank wins; among
    equal ranks the earliest one in ``segments`` is kept.
    """
    chosen: SegmentVis | None = None
    for candidate in segments:
        if not (candidate.start_sec <= t_sec < candidate.end_sec):
            continue
        # Strict "<" keeps the first candidate on a rank tie.
        if chosen is None or candidate.rank < chosen.rank:
            chosen = candidate
    return chosen
hand_mp_task_raw), + hand_mediapipe_num_hands=int(hd.get("mediapipe_num_hands", 2)), + hand_mediapipe_min_detection_confidence=float( + hd.get("mediapipe_min_detection_confidence", 0.3) + ), + hand_mediapipe_min_presence_confidence=float( + hd.get("mediapipe_min_presence_confidence", 0.3) + ), + hand_mediapipe_min_tracking_confidence=float( + hd.get("mediapipe_min_tracking_confidence", 0.3) + ), + hand_mediapipe_bbox_margin=float(hd.get("mediapipe_bbox_margin", 0.05)), goodbad_model=_rel(pack_root, w["goodbad"]), haocai_model=_rel(pack_root, w["haocai"]), tear_model=_rel(pack_root, tear_raw) if tear_raw else None, @@ -141,6 +159,7 @@ def load_run_config(pack_root: Path, config_path: Path) -> Namespace: doctor_identity_segment_sample_fps=float( did.get("segment_sample_fps", did.get("sample_fps", 3.0)) ), + doctor_identity_segment_window_sec=float(did.get("segment_window_sec", 3.0)), doctor_identity_pad_frac=float(did.get("pad_frac", 0.15)), basket_det_conf=float(bk.get("det_conf", p2["det_conf"])), basket_contact_iou_threshold=legacy_contact_iou, diff --git a/src/doctor_identity.py b/src/doctor_identity.py index b951eb8..8098ef0 100644 --- a/src/doctor_identity.py +++ b/src/doctor_identity.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib.util import sys from argparse import Namespace +from collections import Counter from pathlib import Path from typing import Any @@ -112,13 +113,15 @@ class DoctorIdentityService: getattr(self.args, "doctor_identity_sample_fps", 3.0), ) ) + win_sec = float(getattr(self.args, "doctor_identity_segment_window_sec", 3.0)) + doc_t0, doc_t1 = segment_doctor_infer_window(start_sec, end_sec, win_sec) if use_file_source and video_path is not None and video_path.is_file(): best_crop = mod.pick_best_person_crop_in_window( video_path, self._landmarker, - start_sec, - end_sec, + doc_t0, + doc_t1, sample_fps, pad_frac, ) @@ -126,8 +129,8 @@ class DoctorIdentityService: best_crop = mod.pick_best_person_crop_from_frames( 
frames, self._landmarker, - start_sec, - end_sec, + doc_t0, + doc_t1, pad_frac, ) else: @@ -175,12 +178,57 @@ class DoctorIdentityService: return f"识别失败({exc})" +def segment_doctor_infer_window( + start_sec: float, + end_sec: float, + window_sec: float, +) -> tuple[float, float]: + """段内医生 ReID 采样窗:全长不超过 window_sec,相对段窗口居中。""" + t0 = float(start_sec) + t1 = float(end_sec) + dur = t1 - t0 + w = max(0.1, float(window_sec)) + if dur <= w + 1e-9: + return t0, t1 + mid = (t0 + t1) * 0.5 + half = w * 0.5 + return mid - half, mid + half + + def stream_doctor_enabled(args: Namespace) -> bool: return bool(getattr(args, "doctor_identity_enabled", True)) and bool( getattr(args, "doctor_identity_stream_enabled", True) ) +def vote_doctor_from_segment_results( + segment_results: list[dict[str, Any] | None], +) -> str: + """ + 对各段医生识别结果投票:按 doctor_id 众数取 Top1,展示名与置信度取该 id 下最高 conf 的一段。 + """ + ok = [r for r in segment_results if r is not None and r.get("ok")] + if not ok: + reasons = [ + str(r.get("reason", "")) + for r in segment_results + if r is not None and not r.get("ok") + ] + hint = reasons[0] if reasons else "无有效段" + return f"识别失败({hint})" + + counts = Counter(str(r["doctor_id"]) for r in ok) + top_id, _ = counts.most_common(1)[0] + candidates = [r for r in ok if str(r["doctor_id"]) == top_id] + best = max(candidates, key=lambda r: float(r.get("doctor_conf") or 0.0)) + conf = float(best.get("doctor_conf") or 0.0) + suffix = " [低置信度]" if best.get("low_confidence") else "" + name = best.get("doctor_name") or "" + if name: + return f"{name} (id={top_id}, conf={conf:.4f}){suffix}" + return f"doctor_id={top_id} (conf={conf:.4f}){suffix}" + + def infer_doctor_text_offline(args: Namespace, video_path: Path) -> str: """离线入口:校验资源后返回医生信息文本。""" if not bool(getattr(args, "doctor_identity_enabled", True)): diff --git a/src/hand_detector.py b/src/hand_detector.py new file mode 100644 index 0000000..97c136e --- /dev/null +++ b/src/hand_detector.py @@ -0,0 +1,68 @@ +"""手部检测统一入口:YOLO 
def validate_hand_assets(args: Namespace) -> tuple[bool, str]:
    """Check that the model file for the configured hand backend exists.

    The backend is read from ``args.hand_backend`` (default ``"yolo"``,
    compared case-insensitively after stripping). ``"mediapipe"`` checks
    ``args.hand_mediapipe_task``; any other value checks ``args.hand_model``.

    Args:
        args: Run configuration namespace.

    Returns:
        ``(True, backend_label)`` when the file exists, otherwise
        ``(False, message)`` describing what is missing. A path attribute
        that is absent, ``None``, blank or not path-like is reported as a
        failure instead of raising, so callers can log it and stop cleanly.
    """
    backend = str(getattr(args, "hand_backend", "yolo")).strip().lower()
    if backend == "mediapipe":
        attr = "hand_mediapipe_task"
        missing_msg = "缺少 MediaPipe 手部模型"
        label = "MediaPipe Hands"
    else:
        attr = "hand_model"
        missing_msg = "缺少手部检测权重"
        label = "YOLO hand_detect"

    raw = getattr(args, attr, None)
    # Path("") silently becomes ".", so treat blank values as unconfigured.
    if raw is None or not str(raw).strip():
        return False, f"{missing_msg}: (未配置 {attr})"
    try:
        asset = Path(raw)
    except TypeError:
        return False, f"{missing_msg}: (无效路径 {raw!r})"
    if not asset.is_file():
        return False, f"{missing_msg}: {asset}"
    return True, label
class MediapipeHandDetector:
    """Hand detector backed by a MediaPipe Hand Landmarker.

    Emits one bounding box per detected hand as ``[x1, y1, x2, y2]`` and also
    offers a ``predict`` method returning a YOLO-like result wrapper.
    """

    names: dict[int, str] = {0: "hand"}

    def __init__(
        self,
        task_path: Path,
        *,
        num_hands: int = 2,
        min_detection_confidence: float = 0.3,
        min_presence_confidence: float = 0.3,
        min_tracking_confidence: float = 0.3,
        bbox_margin: float = 0.05,
    ) -> None:
        """Create the landmarker in IMAGE running mode.

        Raises:
            FileNotFoundError: if the resolved ``task_path`` is not a file.
        """
        resolved = Path(task_path).resolve()
        if not resolved.is_file():
            raise FileNotFoundError(f"MediaPipe 手部模型不存在: {resolved}")

        self.task_path = resolved
        self.num_hands = int(num_hands)
        self.bbox_margin = float(bbox_margin)
        landmarker_options = vision.HandLandmarkerOptions(
            base_options=python.BaseOptions(model_asset_path=str(resolved)),
            running_mode=vision.RunningMode.IMAGE,
            num_hands=self.num_hands,
            min_hand_detection_confidence=float(min_detection_confidence),
            min_hand_presence_confidence=float(min_presence_confidence),
            min_tracking_confidence=float(min_tracking_confidence),
        )
        self._landmarker = vision.HandLandmarker.create_from_options(
            landmarker_options
        )

    def close(self) -> None:
        """Close the landmarker; calling it again is a no-op."""
        landmarker = self._landmarker
        if landmarker is None:
            return
        landmarker.close()
        self._landmarker = None  # type: ignore[assignment]

    def __del__(self) -> None:
        # Best-effort cleanup: errors here (e.g. __init__ failed before the
        # landmarker attribute was set) are deliberately ignored.
        try:
            self.close()
        except Exception:
            pass

    def detect_xyxy(self, frame_bgr: np.ndarray) -> list[list[float]]:
        """Detect hands in a BGR frame and return their boxes.

        Landmark x/y values are scaled by the frame width/height; each box is
        the landmark extent widened by ``bbox_margin`` on every side and
        clamped to the frame. Boxes not wider and taller than one pixel are
        dropped. A ``None`` or empty frame yields an empty list.
        """
        if frame_bgr is None or frame_bgr.size == 0:
            return []
        height, width = frame_bgr.shape[:2]
        image = mp.Image(
            image_format=mp.ImageFormat.SRGB,
            data=cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB),
        )
        detection = self._landmarker.detect(image)
        pad = self.bbox_margin
        found: list[list[float]] = []
        for landmarks in detection.hand_landmarks:
            xs = [point.x for point in landmarks]
            ys = [point.y for point in landmarks]
            left = max(0.0, (min(xs) - pad) * width)
            top = max(0.0, (min(ys) - pad) * height)
            right = min(float(width), (max(xs) + pad) * width)
            bottom = min(float(height), (max(ys) + pad) * height)
            if not (right > left + 1.0 and bottom > top + 1.0):
                continue
            found.append([left, top, right, bottom])
        return found

    def predict(
        self,
        frame: np.ndarray,
        conf: float | None = None,
        imgsz: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> list[_YoloLikeResult]:
        """Return a one-element list wrapping the boxes detected in ``frame``.

        ``conf``, ``imgsz``, ``verbose`` and any extra keyword arguments are
        accepted and ignored.
        """
        hands = self.detect_xyxy(frame)
        return [_YoloLikeResult(hands)]
YOLO(str(args.hand_model)) + from hand_detector import create_hand_detector + + det = create_hand_detector(args) gb = YOLO(str(args.goodbad_model)) cls_m = YOLO(str(args.haocai_model)) diff --git a/src/segments_offline_orchestrator.py b/src/segments_offline_orchestrator.py index a6a58e7..5fdc2c0 100644 --- a/src/segments_offline_orchestrator.py +++ b/src/segments_offline_orchestrator.py @@ -71,7 +71,9 @@ def run_segments_offline_pipeline(args: Namespace) -> int: predict_kw["half"] = True log("[segments-offline] 加载 YOLO(手 / 好坏帧 / 耗材)…") - det = YOLO(str(args.hand_model)) + from hand_detector import create_hand_detector + + det = create_hand_detector(args) gb = YOLO(str(args.goodbad_model)) cls_m = YOLO(str(args.haocai_model)) hc = HaocaiOnlyClassifier( diff --git a/src/stream_basket_session.py b/src/stream_basket_session.py index e7cac85..e014a22 100644 --- a/src/stream_basket_session.py +++ b/src/stream_basket_session.py @@ -5,10 +5,8 @@ from dataclasses import dataclass from typing import Any, Callable import numpy as np -from ultralytics import YOLO - from action_trigger_logic import ActionTriggerLogic -from run_segments_consumable_vote import collect_hand_boxes +from hand_detector import detect_hands_xyxy from stream_frame_buffer import FrameRingBuffer @@ -44,7 +42,7 @@ class StreamBasketSession: def __init__( self, basket_xyxy: list[float], - hand_model: YOLO, + hand_model: Any, trigger: ActionTriggerLogic, *, segment_start_offset_sec: float = 1.0, @@ -84,14 +82,13 @@ class StreamBasketSession: t = float(t_sec) self._current_t = t - r0 = self.hand_model.predict( + hands = detect_hands_xyxy( + self.hand_model, frame, - conf=self.det_conf, - imgsz=self.imgsz_det, - verbose=False, - **self.predict_kw, - )[0] - hands = collect_hand_boxes(self.hand_model, r0.boxes) if r0.boxes else [] + det_conf=self.det_conf, + imgsz_det=self.imgsz_det, + predict_kw=self.predict_kw, + ) start_t = self.trigger.process_frame(t, hands, self.basket_xyxy) self.buffer.append(t, frame) 
diff --git a/src/stream_orchestrator.py b/src/stream_orchestrator.py index f0ee797..58da2b5 100644 --- a/src/stream_orchestrator.py +++ b/src/stream_orchestrator.py @@ -24,7 +24,11 @@ from basket_segmenter import ( _select_basket_roi_tkinter, save_basket_roi_json, ) -from doctor_identity import DoctorIdentityService, stream_doctor_enabled +from doctor_identity import ( + DoctorIdentityService, + stream_doctor_enabled, + vote_doctor_from_segment_results, +) from pack_utils import log, resolve_allowed_class_idx from stream_basket_session import CachedClip, StreamBasketSession @@ -42,12 +46,71 @@ _HAOCAI_COLS = [ "top3_name", "top3_conf", ] -_DOCTOR_COLS = ["doctor_id", "doctor_name", "doctor_conf"] +def _stream_segment_window_sec(args: Namespace) -> float: + """段识别窗口时长 end_offset - start_offset(当前配置为 7s)。""" + end_off = float( + getattr( + args, + "stream_segment_end_offset_sec", + getattr(args, "basket_segment_end_offset_sec", 10.0), + ) + ) + start_off = float( + getattr( + args, + "stream_segment_start_offset_sec", + getattr(args, "basket_segment_start_offset_sec", 3.0), + ) + ) + return max(0.0, end_off - start_off) + + +def _finalize_stream_output( + out_path: Path, + body_lines: list[str], + doctor_votes: list[dict[str, Any] | None], + *, + args: Namespace, + collect_doctor: bool, +) -> int: + """写 TSV:末段时长不足窗口则丢弃;医生信息投票后追加末行。""" + min_dur = _stream_segment_window_sec(args) + dropped_last = False + if body_lines: + parts = body_lines[-1].split("\t") + if len(parts) >= 3: + t0, t1 = float(parts[1]), float(parts[2]) + if t1 - t0 < min_dur - 1e-9: + log( + f"[stream] 末段 [{t0:.3f},{t1:.3f}] 时长 {t1 - t0:.3f}s < {min_dur:g}s," + "判为误触已丢弃" + ) + body_lines = body_lines[:-1] + if doctor_votes: + doctor_votes = doctor_votes[:-1] + dropped_last = True + + lines: list[str] = ["\t".join(_HAOCAI_COLS)] + body_lines + if collect_doctor: + summary = vote_doctor_from_segment_results(doctor_votes) + lines.append(f"医生信息:{summary}") + log(f"[stream] 医生投票汇总:{summary}") + + 
out_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + n = len(body_lines) + if dropped_last: + log(f"[stream] 有效段数 {n}(已丢弃末段误触)") + return n def _validate_stream_weights(args: Namespace) -> bool: + from hand_detector import validate_hand_assets + + ok, hand_lab = validate_hand_assets(args) + if not ok: + log(hand_lab) + return False for p, lab in ( - (args.hand_model, "手部检测"), (args.goodbad_model, "好坏帧"), (args.haocai_model, "耗材分类"), ): @@ -171,7 +234,7 @@ def _process_one_clip( rank: int, clip: CachedClip, *, - det: YOLO, + det: Any, hc: HaocaiOnlyClassifier, infer_cap: cv2.VideoCapture | None, use_file_infer: bool, @@ -182,10 +245,9 @@ def _process_one_clip( allowed_idx: frozenset[int] | None, predict_kw: dict[str, Any], product_map: dict[str, str], - out_path: Path, doctor_svc: DoctorIdentityService | None, - include_doctor_cols: bool, -) -> None: + collect_doctor_vote: bool, +) -> tuple[str, dict[str, Any] | None]: log( f"[stream] 识别 rank={rank} [{clip.start_sec:.3f},{clip.end_sec:.3f}] " f"({len(clip.frames)} 帧)…" @@ -239,14 +301,14 @@ def _process_one_clip( clip.end_sec, info, product_map, - legacy_12_col=bool(args.legacy_12_col_only), - include_doctor_cols=include_doctor_cols, - doctor=doc, + legacy_12_col=True, + include_doctor_cols=False, + doctor=None, ) - with out_path.open("a", encoding="utf-8") as f: - f.write(line + "\n") - _log_doctor_result(rank, doc) + if collect_doctor_vote: + _log_doctor_result(rank, doc) log(f"[stream] rank={rank} 已写入") + return line, doc if collect_doctor_vote else None def _maybe_free_gpu() -> None: @@ -304,7 +366,7 @@ def _use_file_infer_for_stream(args: Namespace, *, is_file: bool) -> bool: def _infer_clip( clip: CachedClip, *, - det: YOLO, + det: Any, hc: HaocaiOnlyClassifier, cap: cv2.VideoCapture | None, use_file_infer: bool, @@ -370,8 +432,11 @@ class StreamBasketOrchestrator: if args.half: predict_kw["half"] = True - log("[stream] 加载 YOLO(手 / 好坏帧 / 耗材)…") - det = YOLO(str(args.hand_model)) + from 
hand_detector import create_hand_detector + + hand_lab = str(getattr(args, "hand_backend", "yolo")) + log(f"[stream] 加载手部检测({hand_lab})与 YOLO(好坏帧 / 耗材)…") + det = create_hand_detector(args) gb = YOLO(str(args.goodbad_model)) cls_m = YOLO(str(args.haocai_model)) cls_names = cls_m.names @@ -393,15 +458,18 @@ class StreamBasketOrchestrator: else: log("[stream] 白名单已关闭,使用全 41 类") - include_doctor_cols = stream_doctor_enabled(args) + collect_doctor_vote = stream_doctor_enabled(args) doctor_svc: DoctorIdentityService | None = None - if include_doctor_cols: + if collect_doctor_vote: try: doctor_svc = DoctorIdentityService(args) - log("[stream] 医生身份识别已启用(段内与耗材并行)") + win = float(getattr(args, "doctor_identity_segment_window_sec", 3.0)) + log( + f"[stream] 医生身份识别已启用(每段居中 {win:g}s 采样 + 末行投票汇总)" + ) except Exception as exc: # noqa: BLE001 - log(f"[stream] 医生识别初始化失败,本 run 不写入医生列: {exc}") - include_doctor_cols = False + log(f"[stream] 医生识别初始化失败,本 run 不写入医生汇总: {exc}") + collect_doctor_vote = False doctor_svc = None cap = cv2.VideoCapture(source) @@ -484,10 +552,8 @@ class StreamBasketOrchestrator: fallback = str(getattr(args, "stream_infer_fallback", "cache")) log(f"[stream] 段内识别: JPEG 缓存帧(infer_fallback={fallback})") - header_cols = list(_HAOCAI_COLS) - if include_doctor_cols: - header_cols.extend(_DOCTOR_COLS) - out_path.write_text("\t".join(header_cols) + "\n", encoding="utf-8") + body_lines: list[str] = [] + doctor_votes: list[dict[str, Any] | None] = [] rank = 0 frame_idx = 0 @@ -495,7 +561,7 @@ class StreamBasketOrchestrator: nonlocal rank for clip in session.poll_ready_clips(): rank += 1 - _process_one_clip( + line, doc = _process_one_clip( rank, clip, det=det, @@ -509,10 +575,12 @@ class StreamBasketOrchestrator: allowed_idx=allowed_idx, predict_kw=predict_kw, product_map=product_map, - out_path=out_path, doctor_svc=doctor_svc, - include_doctor_cols=include_doctor_cols, + collect_doctor_vote=collect_doctor_vote, ) + body_lines.append(line) + if collect_doctor_vote: + 
doctor_votes.append(doc) session.push_frame(t0, first) process_ready() @@ -542,7 +610,7 @@ class StreamBasketOrchestrator: finally: for clip in session.poll_ready_clips(): rank += 1 - _process_one_clip( + line, doc = _process_one_clip( rank, clip, det=det, @@ -556,18 +624,29 @@ class StreamBasketOrchestrator: allowed_idx=allowed_idx, predict_kw=predict_kw, product_map=product_map, - out_path=out_path, doctor_svc=doctor_svc, - include_doctor_cols=include_doctor_cols, + collect_doctor_vote=collect_doctor_vote, ) + body_lines.append(line) + if collect_doctor_vote: + doctor_votes.append(doc) cap.release() if infer_cap is not None: infer_cap.release() if doctor_svc is not None: doctor_svc.close() + if hasattr(det, "close"): + det.close() - log(f"[stream] 结束,共 {rank} 段,结果: {out_path}") - return 0 if rank > 0 or is_file else 0 + n_written = _finalize_stream_output( + out_path, + body_lines, + doctor_votes, + args=args, + collect_doctor=collect_doctor_vote, + ) + log(f"[stream] 结束,共 {n_written} 段,结果: {out_path}") + return 0 if n_written > 0 or is_file else 0 def run_stream_pipeline(args: Namespace) -> int: diff --git a/tests/test_doctor_vote.py b/tests/test_doctor_vote.py new file mode 100644 index 0000000..43ba42f --- /dev/null +++ b/tests/test_doctor_vote.py @@ -0,0 +1,46 @@ +"""医生段内投票汇总单元测试。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PACK_ROOT / "src")) + +from doctor_identity import ( # noqa: E402 + segment_doctor_infer_window, + vote_doctor_from_segment_results, +) + + +class TestDoctorWindow(unittest.TestCase): + def test_center_3s_in_7s_segment(self) -> None: + t0, t1 = segment_doctor_infer_window(13.0, 20.0, 3.0) + self.assertAlmostEqual(t0, 15.0, places=3) + self.assertAlmostEqual(t1, 18.0, places=3) + + def test_short_segment_uses_full(self) -> None: + t0, t1 = segment_doctor_infer_window(5.0, 7.0, 3.0) + self.assertAlmostEqual(t0, 5.0, 
places=3) + self.assertAlmostEqual(t1, 7.0, places=3) + + +class TestDoctorVote(unittest.TestCase): + def test_majority_top1(self) -> None: + votes = [ + {"ok": True, "doctor_id": "A", "doctor_name": "张三", "doctor_conf": 0.9}, + {"ok": True, "doctor_id": "B", "doctor_name": "李四", "doctor_conf": 0.95}, + {"ok": True, "doctor_id": "A", "doctor_name": "张三", "doctor_conf": 0.8}, + ] + text = vote_doctor_from_segment_results(votes) + self.assertIn("张三", text) + self.assertIn("id=A", text) + + def test_no_ok_segments(self) -> None: + text = vote_doctor_from_segment_results([{"ok": False, "reason": "无脸"}]) + self.assertIn("识别失败", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stream_basket.py b/tests/test_stream_basket.py index fc037b5..c7cb7ec 100644 --- a/tests/test_stream_basket.py +++ b/tests/test_stream_basket.py @@ -55,17 +55,17 @@ class TestStreamBasketSession(unittest.TestCase): [0, 0, 100, 100], hand_model, trigger, - segment_start_offset_sec=2.0, - segment_end_offset_sec=8.0, + segment_start_offset_sec=3.0, + segment_end_offset_sec=10.0, min_segment_sec=4.0, - ring_buffer_sec=10.0, + ring_buffer_sec=12.0, fps=25.0, cache_max_width=640, ) frame = np.zeros((64, 64, 3), dtype=np.uint8) contact_t = None - for i in range(260): + for i in range(350): t = i * 0.04 start = session.push_frame(t, frame) if start is not None and contact_t is None: @@ -75,8 +75,8 @@ class TestStreamBasketSession(unittest.TestCase): clips = session.poll_ready_clips() self.assertGreaterEqual(len(clips), 1) clip = clips[0] - self.assertAlmostEqual(clip.start_sec, 4.0, places=2) - self.assertAlmostEqual(clip.end_sec, 10.0, places=2) + self.assertAlmostEqual(clip.start_sec, 5.0, places=2) + self.assertAlmostEqual(clip.end_sec, 12.0, places=2) self.assertGreater(len(clip.frames), 0) diff --git a/tests/test_visualize_parse.py b/tests/test_visualize_parse.py new file mode 100644 index 0000000..f0a236c --- /dev/null +++ b/tests/test_visualize_parse.py @@ -0,0 +1,61 
@@ +"""visualize_pipeline TSV 解析单元测试(无需 GPU)。""" +from __future__ import annotations + +import sys +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PACK_ROOT / "scripts")) + +from visualize_tsv import parse_result_tsv # noqa: E402 + + +def test_parse_offline_12col_with_doctor_summary(tmp_path: Path) -> None: + tsv = tmp_path / "r.txt" + tsv.write_text( + "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\t" + "product_id_top2\ttop2_name\ttop2_conf\tproduct_id_top3\ttop3_name\ttop3_conf\n" + "1\t1.0\t5.0\tP1\t手套\t0.9\t\t\t\t\t\t\n" + "医生信息:张三 (id=D01, conf=0.91)\n", + encoding="utf-8", + ) + segs, doc = parse_result_tsv(tsv) + assert len(segs) == 1 + assert segs[0].row.n1 == "手套" + assert doc is not None and "张三" in doc + + +def test_parse_stream_15col(tmp_path: Path) -> None: + tsv = tmp_path / "s.txt" + header = "\t".join( + [ + "rank", "start_sec", "end_sec", "product_id_top1", "top1_name", "top1_conf", + "product_id_top2", "top2_name", "top2_conf", "product_id_top3", "top3_name", "top3_conf", + "doctor_id", "doctor_name", "doctor_conf", + ] + ) + row = [ + "1", "2.0", "8.0", "", + "(无有效耗材帧:好帧/白名单/耗材置信度未全部满足)", "", + "", "", "", "", "", "", "D2", "李四", "0.77", + ] + tsv.write_text(header + "\n" + "\t".join(row) + "\n", encoding="utf-8") + segs, doc = parse_result_tsv(tsv) + assert len(segs) == 1 + assert segs[0].is_failure() + assert segs[0].doctor_name == "李四" + assert doc is None + + +def test_parse_failure_hud_text(tmp_path: Path) -> None: + tsv = tmp_path / "f.txt" + tsv.write_text( + "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\t" + "product_id_top2\ttop2_name\ttop2_conf\tproduct_id_top3\ttop3_name\ttop3_conf\n" + "1\t0.5\t3.0\t\t(无有效耗材帧:好帧/白名单/耗材置信度未全部满足)\t\t" + "\t\t\t\t\t\t\n", + encoding="utf-8", + ) + segs, _ = parse_result_tsv(tsv) + assert segs[0].is_failure() + assert "无有效耗材帧" in segs[0].row.n1