From 0d27ae415e0b9ca847099987d94de60fa062c40c Mon Sep 17 00:00:00 2001 From: hsz <2091085305@qq.com> Date: Wed, 3 Jun 2026 14:46:16 +0800 Subject: [PATCH] doctor_identify --- README.md | 12 +- configs/default_config.yaml | 2 + .../infer_doctor_from_video.py | 185 +++++-- .../train_reid_contrastive.py | 458 ++++++++++++++++++ setup.sh | 1 + src/config.py | 4 + src/doctor_identity.py | 202 ++++++++ src/orchestrator.py | 63 +-- src/stream_orchestrator.py | 232 ++++++--- 9 files changed, 991 insertions(+), 168 deletions(-) create mode 100644 doctor_identity_package/train_reid_contrastive.py create mode 100644 src/doctor_identity.py diff --git a/README.md b/README.md index ce8510e..31b73f9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # 手术室耗材篮子识别包(离线 + 推流) -段内流程:**手检(≥2 手 union)→ 好坏帧门控 → 耗材分类**;离线另含**医生识别**。 +段内流程:**手检(≥2 手 union)→ 好坏帧门控 → 耗材分类**;**推流**在每个接触段内与耗材**并行**医生识别;**离线**在全片结束后追加一行医生信息。 与 `configs/default_config.yaml` 当前参数一致(`imgsz_det: 1920`、`contact+1~+6` 等)。 @@ -27,7 +27,7 @@ pip install -r requirements.txt | 脚本 | 用途 | |------|------| | `main_basket.py` | **离线**:全片篮子接触分段 → Phase2 → gap 合并 → 医生识别 | -| `main_basket_stream.py` | **推流/本地 MP4 模拟推流**:逐帧触发 → 段内识别 → 实时写 TSV | +| `main_basket_stream.py` | **推流/本地 MP4 模拟推流**:逐帧触发 → 段内耗材+医生并行 → 实时写 15 列 TSV | | `main_segments_offline.py` | 按 TSV 时间段对离线 MP4 重跑段内识别(校验用) | ## 1. 离线跑视频 @@ -59,6 +59,7 @@ python main_basket_stream.py \ - 本地 MP4:`stream.infer_source: file` → 段内**回源 4K**(与离线一致) - 真 RTSP:无法 seek 时回退 JPEG 缓存(`cache_max_width: 1920`) +- 医生识别:`doctor_identity.stream_enabled: true` 时,每段并行识别,TSV 追加 `doctor_id` / `doctor_name` / `doctor_conf` ## 3. HEVC 视频 @@ -78,16 +79,19 @@ bash scripts/remux_hevc.sh /path/to/source.mp4 | `basket` | `iou_on: 0.03`,`confirm: 0.1`,`cooldown: 3`,窗口 contact+1~+6 | | `stream` | 段窗口与 basket 一致;`infer_source: file` | | `io` | `use_whitelist: false`(全 41 类) | +| `doctor_identity` | `enabled` / `stream_enabled`;推流用 `segment_sample_fps` 在段 `[start,end]` 内采样 | ## 模型文件(`weights/`) - `hand_detect.pt` — 手部检测 - `goodbad_frame.pt` — 好坏帧门控 - `haocai_classify.pt` — 耗材分类 +- `doctor_identity_package/doctor_info.pth` — 医生 ReID(需同目录 `train_reid_contrastive.py`) ## 输出格式 -12 列 TSV + 离线末尾一行 `医生信息:...`(推流无医生行)。 +- **推流**:15 列 TSV(12 列耗材 + `doctor_id` / `doctor_name` / `doctor_conf`),无末尾汇总行 +- **离线**:12 列 TSV + 末尾一行 `医生信息:...`(全片中间窗口识别) ## 目录结构 @@ -99,7 +103,7 @@ bash scripts/remux_hevc.sh /path/to/source.mp4 ├── configs/default_config.yaml ├── weights/ # 3 个 YOLO 权重 ├── input/视频中的商品信息表.xlsx -├── doctor_identity_package/ # 医生识别(仅离线) +├── doctor_identity_package/ # 医生识别(离线整片 + 推流段内) ├── src/ code/ # 编排与算法 ├── output/ # 结果输出目录 ├── setup.sh requirements.txt diff --git a/configs/default_config.yaml b/configs/default_config.yaml index d775bb6..4d7e1ec 100644 --- a/configs/default_config.yaml +++ b/configs/default_config.yaml @@ -56,12 +56,14 @@ output: doctor_identity: enabled: true + stream_enabled: true checkpoint: doctor_identity_package/doctor_info.pth labels_csv: doctor_identity_package/labels.csv pose_min_detection_confidence: 0.30 min_identity_confidence: 0.00 middle_seconds: 10.0 sample_fps: 3.0 + segment_sample_fps: 3.0 pad_frac: 0.15 # 篮子接触分段(main_basket.py / main_basket_stream.py) diff --git a/doctor_identity_package/infer_doctor_from_video.py b/doctor_identity_package/infer_doctor_from_video.py index 22f9b3b..8edcf60 100644 --- a/doctor_identity_package/infer_doctor_from_video.py +++ b/doctor_identity_package/infer_doctor_from_video.py @@ -139,6 +139,19 @@ def expand_bbox_with_padding( return nx1, ny1, nx2, ny2 +def sample_window_timestamps(t0: float, t1: float, sample_fps: float) -> list[float]: + """在 [t0, t1) 内按 sample_fps 均匀采样时间戳。""" + if t1 <= t0 or sample_fps <= 0: + return [] + step = 1.0 / sample_fps + ts: list[float] = [] + t = float(t0) + while t < t1 - 1e-6: + ts.append(t) + t += step + return ts + + def sample_middle_timestamps(duration_sec: float, middle_seconds: float, sample_fps: float) -> list[float]: if duration_sec <= 0 or middle_seconds <= 0 or sample_fps <= 0: return [] @@ -146,13 +159,37 @@ def sample_middle_timestamps(duration_sec: float, middle_seconds: float, sample_ half = middle_seconds / 2.0 t0 = max(0.0, center - half) t1 = min(duration_sec, center + half) - step = 1.0 / sample_fps - ts = [] - t = t0 - while t < t1 - 1e-6: - ts.append(t) - t += step - return ts + return sample_window_timestamps(t0, t1, sample_fps) + + +def _update_best_crop_from_frame( + frame: np.ndarray, + landmarker: PoseLandmarker, + pad_frac: float, + *, + best_area: int, + best_crop: np.ndarray | None, +) -> tuple[int, np.ndarray | None]: + h, w = frame.shape[:2] + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + mp_img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb) + res = landmarker.detect(mp_img) + if not res.pose_landmarks: + return best_area, best_crop + + for lmk in res.pose_landmarks: + box = bbox_from_normalized_pose_landmarks(w, h, lmk) + if box is None: + continue + ex1, ey1, ex2, ey2 = expand_bbox_with_padding(*box, w, h, pad_frac=pad_frac) + crop = frame[ey1:ey2, ex1:ex2] + if crop.size == 0: + continue + area = int((ex2 - ex1) * (ey2 - ey1)) + if area > best_area: + best_area = area + best_crop = crop.copy() + return best_area, best_crop def pick_best_person_crop( @@ -182,25 +219,9 @@ def pick_best_person_crop( ok, frame = cap.read() if not ok or frame is None: continue - h, w = frame.shape[:2] - rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - mp_img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb) - res = landmarker.detect(mp_img) - if not res.pose_landmarks: - continue - - for lmk in res.pose_landmarks: - box = bbox_from_normalized_pose_landmarks(w, h, lmk) - if box is None: - continue - ex1, ey1, ex2, ey2 = expand_bbox_with_padding(*box, w, h, pad_frac=pad_frac) - crop = frame[ey1:ey2, ex1:ex2] - if crop.size == 0: - continue - area = int((ex2 - ex1) * (ey2 - ey1)) - if area > best_area: - best_area = area - best_crop = crop.copy() + best_area, best_crop = _update_best_crop_from_frame( + frame, landmarker, pad_frac, best_area=best_area, best_crop=best_crop + ) cap.release() if best_crop is None: @@ -208,6 +229,63 @@ def pick_best_person_crop( return best_crop +def pick_best_person_crop_in_window( + video_path: Path, + landmarker: PoseLandmarker, + start_sec: float, + end_sec: float, + sample_fps: float, + pad_frac: float, +) -> np.ndarray: + """在视频 [start_sec, end_sec] 窗口内采样并取最大人体 crop。""" + timestamps = sample_window_timestamps(start_sec, end_sec, sample_fps) + if not timestamps: + raise RuntimeError("No valid timestamps in segment window.") + + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Cannot open video: {video_path}") + + best_area = -1 + best_crop: np.ndarray | None = None + try: + for ts in timestamps: + cap.set(cv2.CAP_PROP_POS_MSEC, ts * 1000.0) + ok, frame = cap.read() + if not ok or frame is None: + continue + best_area, best_crop = _update_best_crop_from_frame( + frame, landmarker, pad_frac, best_area=best_area, best_crop=best_crop + ) + finally: + cap.release() + + if best_crop is None: + raise RuntimeError("No person detected in segment window.") + return best_crop + + +def pick_best_person_crop_from_frames( + frames: list[tuple[float, np.ndarray]], + landmarker: PoseLandmarker, + start_sec: float, + end_sec: float, + pad_frac: float, +) -> np.ndarray: + """在缓存帧列表中按时间段筛选并取最大人体 crop(推流 RTSP / 缓存路径)。""" + best_area = -1 + best_crop: np.ndarray | None = None + for t_sec, frame in frames: + if t_sec < start_sec - 1e-6 or t_sec > end_sec + 1e-6: + continue + best_area, best_crop = _update_best_crop_from_frame( + frame, landmarker, pad_frac, best_area=best_area, best_crop=best_crop + ) + if best_crop is None: + raise RuntimeError("No person detected in cached segment frames.") + return best_crop + + def build_label_to_pid(pid_to_label: dict) -> dict[int, str]: label_to_pid: dict[int, str] = {} for raw_pid, label in pid_to_label.items(): @@ -233,22 +311,8 @@ def load_name_mapping(labels_csv: Path) -> dict[str, str]: return mapping -def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, float]: - if not checkpoint_path.is_file(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) - num_classes = int(ckpt["num_classes"]) - pid_to_label = ckpt.get("pid_to_label", {}) - if not isinstance(pid_to_label, dict): - raise RuntimeError("Checkpoint missing valid pid_to_label dict.") - - model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device) - model.load_state_dict(ckpt["model_state"]) - model.eval() - - transform = transforms.Compose( +def _default_reid_transform() -> transforms.Compose: + return transforms.Compose( [ transforms.Resize((256, 128)), transforms.ToTensor(), @@ -258,6 +322,37 @@ def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, flo ), ] ) + + +def load_reid_model( + checkpoint_path: Path, + device: torch.device | None = None, +) -> tuple[ReIDEmbedModel, torch.device, dict[int, str], transforms.Compose]: + """加载 ReID 模型(推流多段复用,避免重复 torch.load)。""" + if not checkpoint_path.is_file(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + dev = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt = torch.load(checkpoint_path, map_location=dev, weights_only=False) + num_classes = int(ckpt["num_classes"]) + pid_to_label = ckpt.get("pid_to_label", {}) + if not isinstance(pid_to_label, dict): + raise RuntimeError("Checkpoint missing valid pid_to_label dict.") + + model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(dev) + model.load_state_dict(ckpt["model_state"]) + model.eval() + label_to_pid = build_label_to_pid(pid_to_label) + return model, dev, label_to_pid, _default_reid_transform() + + +def run_inference_preloaded( + crop_bgr: np.ndarray, + model: ReIDEmbedModel, + device: torch.device, + label_to_pid: dict[int, str], + transform: transforms.Compose, +) -> tuple[str, float]: crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB) inp = transform(Image.fromarray(crop_rgb)).unsqueeze(0).to(device) @@ -267,13 +362,17 @@ def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, flo pred_label = int(torch.argmax(probs, dim=1).item()) conf = float(probs[0, pred_label].item()) - label_to_pid = build_label_to_pid(pid_to_label) raw_pid = label_to_pid.get(pred_label) if raw_pid is None: raise RuntimeError(f"Predicted label {pred_label} not found in pid mapping.") return raw_pid, conf +def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, float]: + model, device, label_to_pid, transform = load_reid_model(checkpoint_path) + return run_inference_preloaded(crop_bgr, model, device, label_to_pid, transform) + + def main() -> int: args = parse_args() if not args.video.is_file(): diff --git a/doctor_identity_package/train_reid_contrastive.py b/doctor_identity_package/train_reid_contrastive.py new file mode 100644 index 0000000..3c92da4 --- /dev/null +++ b/doctor_identity_package/train_reid_contrastive.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +人体身份重识别(Person ReID)对比学习训练脚本(单机单文件)。 +- Dataset:Market-1501 风格文件名解析 pid/cam; +- PK 采样器:每个 batch 内 P 个身份 × 每张 K 样本; +- 模型:ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头; +- 损失:Batch-Hard Triplet + ID(交叉熵)联合; +""" +from __future__ import annotations + +import argparse +import random +import re +from pathlib import Path +from typing import Iterator + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.utils.data import DataLoader, Dataset, Sampler +from torchvision import models, transforms + +# --------------------------------------------------------------------------- +# 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502,cam_id=1 +# --------------------------------------------------------------------------- + +_NAME_RE = re.compile( + r"^(?P\d+)_c(?P\d+)_", + flags=re.I, +) + + +def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]: + """从「不含后缀」的文件名 stem 中提取身份 ID、机位 ID。""" + m = _NAME_RE.match(stem) + if not m: + return None, None + return int(m.group("pid")), int(m.group("cam")) + + +class DoctorReIDDataset(Dataset): + """医生 ReID:解析 pid/cam,将 pid 重映射到 0..num_classes-1。""" + + def __init__(self, image_root: Path, augment: bool) -> None: + self.image_root = Path(image_root).resolve() + exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + paths = sorted( + p + for p in self.image_root.rglob("*") + if p.is_file() and p.suffix.lower() in exts + ) + pid_raw_list: list[int] = [] + cam_raw_list: list[int] = [] + valid_paths: list[Path] = [] + for p in paths: + pid_raw, cam_raw = parse_market1501_style_name(p.stem) + if pid_raw is None: + continue + valid_paths.append(p) + pid_raw_list.append(pid_raw) + cam_raw_list.append(int(cam_raw)) + + unique_pids = sorted(set(pid_raw_list)) + self.pid_to_label = {pid: i for i, pid in enumerate(unique_pids)} + self.labels: list[int] = [self.pid_to_label[r] for r in pid_raw_list] + self.cam_ids: list[int] = cam_raw_list + self.paths = valid_paths + + if len(self.paths) == 0: + raise RuntimeError(f"目录下未发现有效图像: {self.image_root}") + + # Resize(128,256) 在 torchvision 中为 (height, width) → (256,128) 常见于 ReID。 + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + if augment: + self.transform = transforms.Compose( + [ + transforms.Resize((256, 128)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter( + brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02 + ), + transforms.ToTensor(), + normalize, + # RandomErasure 需作用在 Tensor 上,故放在 ToTensor 之后; + transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)), + ] + ) + else: + self.transform = transforms.Compose( + [ + transforms.Resize((256, 128)), + transforms.ToTensor(), + normalize, + ] + ) + + def __len__(self) -> int: + return len(self.paths) + + def __getitem__(self, idx: int): + img = Image.open(self.paths[idx]).convert("RGB") + x = self.transform(img) + y = torch.tensor(self.labels[idx], dtype=torch.long) + cam = torch.tensor(self.cam_ids[idx], dtype=torch.long) + return x, y, cam + + +class RandomIdentityPKSampler(Sampler[int]): + """ + PK / Random Identity Sampler(Random Identity Sampling) + + 【原理简述】对比学习三元组需在 batch 内出现「同类多图」才能把 anchor / positive 找出来。 + PK 策略:在每个 mini-batch 中固定结构为: + - 先随机选 **P** 个不同身份; + - 每个身份若无放回地抽 **K** 张图像(样本不足则从该身份放回随机抽满 K); + 则 batch_size = P * K。 + 这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet, + Batch-Hard 才能选「最难的正对 / 最难的负样本」。 + 【注意】P 不能超过数据集中可用身份总数;若 K > 该类张数则用放回采样。 + """ + + def __init__( + self, + labels: list[int], + p: int, + k: int, + *, + seed: int = 0, + length: int | None = None, + ) -> None: + super().__init__() + self.labels_np = np.array(labels) + self.p = int(p) + self.k = int(k) + self.seed = seed + self.epoch = 0 + + self.identities = sorted(np.unique(self.labels_np).tolist()) + self.num_identities = len(self.identities) + + idx_by_label: dict[int, list[int]] = {} + for i, lbl in enumerate(self.labels_np.tolist()): + idx_by_label.setdefault(int(lbl), []).append(i) + self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()} + + if self.num_identities == 0: + raise RuntimeError("PKSampler: 无任何身份类别") + + if self.p > self.num_identities: + raise ValueError( + f"P={self.p} 大于数据集身份数 {self.num_identities}。" + " 请减小 --p。" + ) + + # 每个 epoch 内「迭代步数」:batch 数(每次 batch 含 P×K 张图) + self.labels_flat = labels + if length is None: + # 约按「扫过一遍身份组合」的量级设一个稳定值 + self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k))) + else: + self.num_batches_est = int(length) + + self.batch_size = self.p * self.k + + def __len__(self) -> int: + return self.num_batches_est * self.batch_size + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + def __iter__(self) -> Iterator[int]: + rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF) + ids_arr = np.asarray(self.identities, dtype=np.int64) + + for _batch_id in range(self.num_batches_est): + # 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回) + if ids_arr.size >= self.p: + chosen = rng.choice(ids_arr, size=self.p, replace=False) + else: + chosen = rng.choice(ids_arr, size=self.p, replace=True) + + out: list[int] = [] + for pid_pick in chosen.tolist(): + pool = self.idx_by_label[int(pid_pick)] + if pool.size >= self.k: + idx_pick = rng.choice(pool, size=self.k, replace=False) + else: + idx_pick = rng.choice(pool, size=self.k, replace=True) + out.extend(int(t) for t in idx_pick) + + yield from out + + +def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor: + """ + Batch-Hard Triplet Loss(Hermans ECCV17 一类的标准形式) + + 对每个 anchor(batch 里的每个样本 i)在同一 batch 中选: + - Positive:与它同身份的样本 j 里面,距离**最大**的那一个( hardest positive ); + - Negative:与它不同身份的样本里,距离**最小**的那一个( hardest negative ); + 单项损失通常为 relu(d_pos_hard - d_neg_hard + margin) 对每个有效 anchor 求平均。 + 【数学含义】拉大「最难正对」相对「最易负样本」的间隔。 + 【实现】使用欧氏距离;若某样本在 batch 内无「异类」(极少见)或无「同类第二张」则不参与均值。 + """ + dist = torch.cdist(embeddings.float(), embeddings.float(), p=2.0).clamp(min=1e-8) + + bs = dist.size(0) + lbl = labels.long() + same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0)) + losses: list[torch.Tensor] = [] + + for i in range(bs): + same_pos = same[i].clone() + same_pos[i] = False + if not same_pos.any(): + continue + hardest_pos = dist[i][same_pos].max() + + neg_mask = ~same[i] + neg_mask[i] = False + if not neg_mask.any(): + continue + hardest_neg = dist[i][neg_mask].min() + + losses.append(F.relu(hardest_pos - hardest_neg + margin)) + + if not losses: + return embeddings.sum() * 0.0 + return torch.stack(losses).mean() + + +class ReIDEmbedModel(nn.Module): + """ResNet50 预训练 backbone(去掉 fc)→ BN 512 嵌入 → logits(num_classes)。""" + + def __init__(self, num_classes: int, feat_dim: int = 512) -> None: + super().__init__() + backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) + self.backbone = nn.Sequential(*list(backbone.children())[:-2]) # 到 GAP 前,输出 [B,2048,7,7] + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.bottleneck = nn.Sequential( + nn.Linear(2048, feat_dim), + nn.BatchNorm1d(feat_dim), + nn.ReLU(inplace=True), + ) + self.classifier = nn.Linear(feat_dim, num_classes) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + h = self.backbone(x) + h = self.gap(h) + h = h.view(h.size(0), -1) + emb = self.bottleneck(h) + logits = self.classifier(emb) + return emb, logits + + +def collate_fn_pk(batch): + xs, ys, cams = zip(*batch, strict=True) + return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0) + + +def train_one_epoch( + model: nn.Module, + loader: DataLoader, + optim: torch.optim.Optimizer, + device: torch.device, + margin: float, + triplet_w: float, + id_w: float, +) -> tuple[float, float, float]: + model.train() + sum_t = 0.0 + sum_id = 0.0 + n = 0 + ce = nn.CrossEntropyLoss() + + for x, y, _cam in loader: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + emb, logits = model(x) + emb = F.normalize(emb, p=2, dim=1) + loss_t = batch_hard_triplet_loss(emb, y, margin=margin) + loss_id = ce(logits, y) + loss = triplet_w * loss_t + id_w * loss_id + + optim.zero_grad() + loss.backward() + optim.step() + + bs = x.size(0) + sum_t += loss_t.detach().item() * bs + sum_id += loss_id.detach().item() * bs + n += bs + + return sum_t / max(n, 1), sum_id / max(n, 1) + + +@torch.no_grad() +def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float: + model.eval() + correct = 0 + total = 0 + for x, y, _ in loader: + x = x.to(device) + y = y.to(device) + _, logits = model(x) + pred = logits.argmax(dim=1) + correct += int((pred == y).sum().item()) + total += y.numel() + return correct / max(total, 1) + + +def resolve_image_root(cli: str | Path) -> Path: + p = Path(cli).resolve() + sub = p / "doctor_picture" + if sub.is_dir(): + return sub + if p.is_dir(): + jpgs = list(p.glob("*.jpg")) + if len(jpgs) > 0: + return p + raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture): {cli}") + + +def parse_args(): + ap = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练") + ap.add_argument( + "--data-root", + type=Path, + default=Path(__file__).resolve().parent, + help="含 doctor_picture 或图片根目录的路径", + ) + ap.add_argument("--epochs", type=int, default=50, help="epoch 建议在 40–60(默认 50)") + ap.add_argument( + "--batch-p", + type=int, + default=5, + help="PK 采样:每 batch 采样的身份数 P(将自动不大于身份总数,如默认 5 类)", + ) + ap.add_argument( + "--batch-k", + type=int, + default=8, + help="PK 采样:每位身份抽样张数 K;batch_size=P×K", + ) + ap.add_argument("--lr", type=float, default=3e-4) + ap.add_argument("--triplet-margin", type=float, default=0.3) + ap.add_argument("--triplet-weight", type=float, default=1.0) + ap.add_argument("--id-weight", type=float, default=2.0, help="ID Loss(CrossEntropy)权重") + ap.add_argument("--workers", type=int, default=4) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument( + "--save", + type=Path, + default=Path(__file__).resolve().parent / "doctor_reid_best.pth", + help="最佳权重路径", + ) + return ap.parse_args() + + +def main() -> None: + args = parse_args() + rng = random.Random(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + img_root = resolve_image_root(args.data_root) + + ds_train_aug = DoctorReIDDataset(img_root, augment=True) + ds_eval = DoctorReIDDataset(img_root, augment=False) + + n = len(ds_train_aug) + perm = list(range(n)) + rng.shuffle(perm) + n_val = max(32, int(0.1 * n)) + val_ix = sorted(perm[:n_val]) + train_ix = perm[n_val:] + + labels_tr = [ds_train_aug.labels[i] for i in train_ix] + + num_classes = len(ds_train_aug.pid_to_label) + p_eff = min(args.batch_p, len(set(labels_tr))) + + from torch.utils.data import Subset + + sampler = RandomIdentityPKSampler( + labels_tr, + p=p_eff, + k=args.batch_k, + seed=args.seed, + ) + train_loader = DataLoader( + Subset(ds_train_aug, train_ix), + batch_size=sampler.batch_size, + sampler=sampler, + num_workers=args.workers, + pin_memory=True, + drop_last=True, + collate_fn=collate_fn_pk, + ) + val_loader = DataLoader( + Subset(ds_eval, val_ix), + batch_size=64, + shuffle=False, + num_workers=args.workers, + collate_fn=collate_fn_pk, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device) + + optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4) + scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6) + + best_acc = -1.0 + for epoch in range(1, args.epochs + 1): + sampler.set_epoch(epoch) + tr_t, tr_id = train_one_epoch( + model, + train_loader, + optim, + device, + margin=args.triplet_margin, + triplet_w=args.triplet_weight, + id_w=args.id_weight, + ) + scheduler.step() + lr_now = optim.param_groups[0]["lr"] + val_acc = estimate_id_accuracy(model, val_loader, device) + + print( + f"epoch {epoch:03d}/{args.epochs} | " + f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | " + f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}" + ) + + if val_acc >= best_acc: + best_acc = val_acc + torch.save( + { + "epoch": epoch, + "model_state": model.state_dict(), + "num_classes": num_classes, + "pid_to_label": ds_train_aug.pid_to_label, + "args": vars(args), + }, + args.save, + ) + print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})") + + print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}") + + +if __name__ == "__main__": + main() diff --git a/setup.sh b/setup.sh index 3ea6a85..02d13b6 100755 --- a/setup.sh +++ b/setup.sh @@ -25,6 +25,7 @@ for w in hand_detect.pt goodbad_frame.pt haocai_classify.pt; do test -f "weights/$w" && echo " OK weights/$w" || echo " 缺失 weights/$w" done test -f doctor_identity_package/doctor_info.pth && echo " OK doctor_info.pth" || echo " 缺失 doctor_info.pth" +test -f doctor_identity_package/train_reid_contrastive.py && echo " OK train_reid_contrastive.py" || echo " 缺失 train_reid_contrastive.py" test -f input/视频中的商品信息表.xlsx && echo " OK Excel" || echo " 缺失 Excel" echo "" diff --git a/src/config.py b/src/config.py index 9b80ced..8296fb6 100644 --- a/src/config.py +++ b/src/config.py @@ -129,6 +129,7 @@ def load_run_config(pack_root: Path, config_path: Path) -> Namespace: gap_merge_enabled=bool(gm.get("enabled", False)), gap_merge_max_gap_sec=float(gm.get("max_gap_sec", 2.0)), doctor_identity_enabled=bool(did.get("enabled", True)), + doctor_identity_stream_enabled=bool(did.get("stream_enabled", True)), doctor_identity_checkpoint=_rel(pack_root, doctor_ckpt_raw), doctor_identity_labels_csv=_rel(pack_root, doctor_labels_raw), doctor_identity_pose_min_detection_confidence=float( @@ -137,6 +138,9 @@ def load_run_config(pack_root: Path, config_path: Path) -> Namespace: doctor_identity_min_identity_confidence=float(did.get("min_identity_confidence", 0.0)), doctor_identity_middle_seconds=float(did.get("middle_seconds", 10.0)), doctor_identity_sample_fps=float(did.get("sample_fps", 3.0)), + doctor_identity_segment_sample_fps=float( + did.get("segment_sample_fps", did.get("sample_fps", 3.0)) + ), doctor_identity_pad_frac=float(did.get("pad_frac", 0.15)), basket_det_conf=float(bk.get("det_conf", p2["det_conf"])), basket_contact_iou_threshold=legacy_contact_iou, diff --git a/src/doctor_identity.py b/src/doctor_identity.py new file mode 100644 index 0000000..b951eb8 --- /dev/null +++ b/src/doctor_identity.py @@ -0,0 +1,202 @@ +"""医生身份识别:离线整片 + 推流段内(与耗材并行)。""" +from __future__ import annotations + +import importlib.util +import sys +from argparse import Namespace +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +PACK_ROOT = Path(__file__).resolve().parent.parent + + +def _load_doctor_module(script_path: Path) -> Any: + spec = importlib.util.spec_from_file_location("doctor_identity_runtime", script_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"无法加载医生识别脚本: {script_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _doctor_script_path() -> Path: + return PACK_ROOT / "doctor_identity_package" / "infer_doctor_from_video.py" + + +class DoctorIdentityService: + """一次性加载 Pose + ReID,供推流每段复用。""" + + def __init__(self, args: Namespace) -> None: + self.args = args + self._mod: Any | None = None + self._landmarker: Any | None = None + self._reid_model: Any | None = None + self._reid_device: torch.device | None = None + self._label_to_pid: dict[int, str] | None = None + self._transform: Any | None = None + self._name_map: dict[str, str] = {} + + def _ensure_loaded(self) -> Any: + if self._mod is not None: + return self._mod + + script_path = _doctor_script_path() + if not script_path.is_file(): + raise FileNotFoundError(f"缺少脚本: {script_path}") + + pack_dir = script_path.parent + if str(pack_dir) not in sys.path: + sys.path.insert(0, str(pack_dir)) + + mod = _load_doctor_module(script_path) + checkpoint = Path(self.args.doctor_identity_checkpoint).resolve() + labels_csv = Path(self.args.doctor_identity_labels_csv).resolve() + if not checkpoint.is_file(): + raise FileNotFoundError(f"缺少权重: {checkpoint}") + if not labels_csv.is_file(): + raise FileNotFoundError(f"缺少标签映射: {labels_csv}") + + model_path = mod._ensure_pose_lite_model(pack_dir / ".mediapipe_models") + opts = mod.PoseLandmarkerOptions( + base_options=mod.BaseOptions(model_asset_path=str(model_path)), + running_mode=mod.VisionRunningMode.IMAGE, + min_pose_detection_confidence=float( + self.args.doctor_identity_pose_min_detection_confidence + ), + ) + self._landmarker = mod.PoseLandmarker.create_from_options(opts) + self._reid_model, self._reid_device, self._label_to_pid, self._transform = ( + mod.load_reid_model(checkpoint) + ) + self._name_map = mod.load_name_mapping(labels_csv) + self._mod = mod + return mod + + def close(self) -> None: + if self._landmarker is not None: + self._landmarker.close() + self._landmarker = None + + def _format_result(self, raw_pid: str, conf: float) -> dict[str, Any]: + min_conf = float(self.args.doctor_identity_min_identity_confidence) + name = self._name_map.get(str(raw_pid), "") + low = conf < min_conf + return { + "ok": True, + "doctor_id": str(raw_pid), + "doctor_name": name, + "doctor_conf": conf, + "low_confidence": low, + } + + def infer_segment( + self, + *, + start_sec: float, + end_sec: float, + video_path: Path | None = None, + use_file_source: bool = False, + frames: list[tuple[float, np.ndarray]] | None = None, + ) -> dict[str, Any]: + """段内医生识别。失败返回 ok=False 与 reason。""" + try: + mod = self._ensure_loaded() + pad_frac = float(self.args.doctor_identity_pad_frac) + sample_fps = float( + getattr( + self.args, + "doctor_identity_segment_sample_fps", + getattr(self.args, "doctor_identity_sample_fps", 3.0), + ) + ) + + if use_file_source and video_path is not None and video_path.is_file(): + best_crop = mod.pick_best_person_crop_in_window( + video_path, + self._landmarker, + start_sec, + end_sec, + sample_fps, + pad_frac, + ) + elif frames: + best_crop = mod.pick_best_person_crop_from_frames( + frames, + self._landmarker, + start_sec, + end_sec, + pad_frac, + ) + else: + return {"ok": False, "reason": "无可用视频源或缓存帧"} + + raw_pid, conf = mod.run_inference_preloaded( + best_crop, + self._reid_model, + self._reid_device, + self._label_to_pid, + self._transform, + ) + return self._format_result(raw_pid, conf) + except Exception as exc: # noqa: BLE001 + return {"ok": False, "reason": str(exc)} + + def infer_whole_video(self, video_path: Path) -> str: + """离线全片:取视频中间窗口识别,返回展示用文本。""" + if not bool(getattr(self.args, "doctor_identity_enabled", True)): + return "未启用" + + try: + mod = self._ensure_loaded() + best_crop = mod.pick_best_person_crop( + video_path=video_path, + landmarker=self._landmarker, + middle_seconds=float(self.args.doctor_identity_middle_seconds), + sample_fps=float(self.args.doctor_identity_sample_fps), + pad_frac=float(self.args.doctor_identity_pad_frac), + ) + raw_pid, conf = mod.run_inference_preloaded( + best_crop, + self._reid_model, + self._reid_device, + self._label_to_pid, + self._transform, + ) + res = self._format_result(raw_pid, conf) + suffix = " [低置信度]" if res.get("low_confidence") else "" + name = res.get("doctor_name") or "" + if name: + return f"{name} (id={raw_pid}, conf={conf:.4f}){suffix}" + return f"doctor_id={raw_pid} (conf={conf:.4f}){suffix}" + except Exception as exc: # noqa: BLE001 + return f"识别失败({exc})" + + +def stream_doctor_enabled(args: Namespace) -> bool: + return bool(getattr(args, "doctor_identity_enabled", True)) and bool( + getattr(args, "doctor_identity_stream_enabled", True) + ) + + +def infer_doctor_text_offline(args: Namespace, video_path: Path) -> str: + """离线入口:校验资源后返回医生信息文本。""" + if not bool(getattr(args, "doctor_identity_enabled", True)): + return "未启用" + + checkpoint = Path(args.doctor_identity_checkpoint).resolve() + labels_csv = Path(args.doctor_identity_labels_csv).resolve() + if not checkpoint.is_file(): + return f"识别失败(缺少权重: {checkpoint})" + if not labels_csv.is_file(): + return f"识别失败(缺少标签映射: {labels_csv})" + if not _doctor_script_path().is_file(): + return f"识别失败(缺少脚本: {_doctor_script_path()})" + + svc = DoctorIdentityService(args) + try: + return svc.infer_whole_video(video_path.resolve()) + finally: + svc.close() diff --git a/src/orchestrator.py b/src/orchestrator.py index dc25b8b..8460f89 100644 --- a/src/orchestrator.py +++ b/src/orchestrator.py @@ -1,7 +1,6 @@ """主流程编排:与仓库 main_pipeline.PipelineManager 逻辑一致,参数来自 YAML(SimpleNamespace)。""" from __future__ import annotations -import importlib.util import tempfile from argparse import Namespace from pathlib import Path @@ -26,69 +25,11 @@ from run_segments_consumable_vote import pad_box_bottom_only as _pad_box from ultralytics import YOLO from basket_segmenter import build_segments_from_basket +from doctor_identity import infer_doctor_text_offline from pack_utils import load_allowed_names_from_excel, log, resolve_allowed_class_idx from stream_orchestrator import _haocai_infer_kwargs -def _load_doctor_module(script_path: Path) -> Any: - spec = importlib.util.spec_from_file_location("doctor_identity_runtime", script_path) - if spec is None or spec.loader is None: - raise RuntimeError(f"无法加载医生识别脚本: {script_path}") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _infer_doctor_text(args: Namespace, video_path: Path) -> str: - if not bool(getattr(args, "doctor_identity_enabled", True)): - return "未启用" - - checkpoint = Path(args.doctor_identity_checkpoint).resolve() - labels_csv = Path(args.doctor_identity_labels_csv).resolve() - if not checkpoint.is_file(): - return f"识别失败(缺少权重: {checkpoint})" - if not labels_csv.is_file(): - return f"识别失败(缺少标签映射: {labels_csv})" - - pack_root = Path(__file__).resolve().parent.parent - script_path = pack_root / "doctor_identity_package" / "infer_doctor_from_video.py" - if not script_path.is_file(): - return f"识别失败(缺少脚本: {script_path})" - - try: - doctor_mod = _load_doctor_module(script_path) - model_path = doctor_mod._ensure_pose_lite_model(script_path.parent / ".mediapipe_models") - opts = doctor_mod.PoseLandmarkerOptions( - base_options=doctor_mod.BaseOptions(model_asset_path=str(model_path)), - running_mode=doctor_mod.VisionRunningMode.IMAGE, - min_pose_detection_confidence=float( - args.doctor_identity_pose_min_detection_confidence - ), - ) - landmarker = doctor_mod.PoseLandmarker.create_from_options(opts) - try: - best_crop = doctor_mod.pick_best_person_crop( - video_path=video_path, - landmarker=landmarker, - middle_seconds=float(args.doctor_identity_middle_seconds), - sample_fps=float(args.doctor_identity_sample_fps), - pad_frac=float(args.doctor_identity_pad_frac), - ) - finally: - landmarker.close() - - raw_pid, conf = doctor_mod.run_inference(best_crop, checkpoint) - min_conf = float(args.doctor_identity_min_identity_confidence) - name_map = doctor_mod.load_name_mapping(labels_csv) - doctor_name = name_map.get(str(raw_pid), "") - suffix = " [低置信度]" if conf < min_conf else "" - if doctor_name: - return f"{doctor_name} (id={raw_pid}, conf={conf:.4f}){suffix}" - return f"doctor_id={raw_pid} (conf={conf:.4f}){suffix}" - except Exception as exc: # noqa: BLE001 - return f"识别失败({exc})" - - def _resolve_allowed_names(args: Namespace, excel_path: Path) -> list[str] | None: if not getattr(args, "use_whitelist", True): return [] @@ -394,7 +335,7 @@ class PipelineManager: cap.release() log("医生识别:开始执行…") - doctor_text = _infer_doctor_text(args, video_path) + doctor_text = infer_doctor_text_offline(args, video_path) log(f"医生识别:{doctor_text}") lines_out.append(f"医生信息:{doctor_text}") diff --git a/src/stream_orchestrator.py b/src/stream_orchestrator.py index 1424b26..f0ee797 100644 --- a/src/stream_orchestrator.py +++ b/src/stream_orchestrator.py @@ -4,6 +4,7 @@ from __future__ import annotations import gc import time from argparse import Namespace +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any @@ -23,9 +24,26 @@ from basket_segmenter import ( _select_basket_roi_tkinter, save_basket_roi_json, ) +from doctor_identity import DoctorIdentityService, stream_doctor_enabled from pack_utils import log, resolve_allowed_class_idx from stream_basket_session import CachedClip, StreamBasketSession +_HAOCAI_COLS = [ + "rank", + "start_sec", + "end_sec", + "product_id_top1", + "top1_name", + "top1_conf", + "product_id_top2", + "top2_name", + "top2_conf", + "product_id_top3", + "top3_name", + "top3_conf", +] +_DOCTOR_COLS = ["doctor_id", "doctor_name", "doctor_conf"] + def _validate_stream_weights(args: Namespace) -> bool: for p, lab in ( @@ -65,6 +83,14 @@ def _resolve_basket_roi( return roi +def _doctor_row_cells(doc: dict[str, Any] | None) -> list[str]: + if doc is None or not doc.get("ok"): + return ["", "", ""] + conf = doc.get("doctor_conf") + conf_s = f"{float(conf):.6f}" if conf is not None else "" + return [str(doc.get("doctor_id", "")), str(doc.get("doctor_name", "")), conf_s] + + def _format_result_row( rank: int, t0: float, @@ -73,6 +99,8 @@ def _format_result_row( product_map: dict[str, str], *, legacy_12_col: bool, + include_doctor_cols: bool = False, + doctor: dict[str, Any] | None = None, ) -> str: sep = "\t" if not info.get("ok"): @@ -93,6 +121,8 @@ def _format_result_row( ] if not legacy_12_col: row.extend(["", ""]) + if include_doctor_cols: + row.extend(_doctor_row_cells(doctor)) return sep.join(row) n1, n2, n3 = info["top_names"] @@ -120,9 +150,105 @@ def _format_result_row( ] if not legacy_12_col: row.extend(["", ""]) + if include_doctor_cols: + row.extend(_doctor_row_cells(doctor)) return sep.join(row) +def _log_doctor_result(rank: int, doc: dict[str, Any] | None) -> None: + if doc is None: + return + if doc.get("ok"): + name = doc.get("doctor_name") or doc.get("doctor_id", "") + conf = doc.get("doctor_conf", 0.0) + low = " [低置信度]" if doc.get("low_confidence") else "" + log(f"[stream] rank={rank} 医生: {name} (id={doc.get('doctor_id')}, conf={conf:.4f}){low}") + else: + log(f"[stream] rank={rank} 医生识别失败: {doc.get('reason', '')}") + + +def _process_one_clip( + rank: int, + clip: CachedClip, + *, + det: YOLO, + hc: HaocaiOnlyClassifier, + infer_cap: cv2.VideoCapture | None, + use_file_infer: bool, + is_file: bool, + source: str, + args: Namespace, + cls_names: dict, + allowed_idx: frozenset[int] | None, + predict_kw: dict[str, Any], + product_map: dict[str, str], + out_path: Path, + doctor_svc: DoctorIdentityService | None, + include_doctor_cols: bool, +) -> None: + log( + f"[stream] 识别 rank={rank} [{clip.start_sec:.3f},{clip.end_sec:.3f}] " + f"({len(clip.frames)} 帧)…" + ) + frames_copy = list(clip.frames) if doctor_svc is not None and not use_file_infer else None + video_path = Path(source).resolve() if is_file else None + + doc: dict[str, Any] | None = None + if doctor_svc is None: + info = _infer_clip( + clip, + det=det, + hc=hc, + cap=infer_cap, + use_file_infer=use_file_infer, + args=args, + cls_names=cls_names, + allowed_idx=allowed_idx, + predict_kw=predict_kw, + rank=rank, + ) + else: + with ThreadPoolExecutor(max_workers=2) as pool: + fut_h = pool.submit( + _infer_clip, + clip, + det=det, + hc=hc, + cap=infer_cap, + use_file_infer=use_file_infer, + args=args, + cls_names=cls_names, + allowed_idx=allowed_idx, + predict_kw=predict_kw, + rank=rank, + ) + fut_d = pool.submit( + doctor_svc.infer_segment, + start_sec=clip.start_sec, + end_sec=clip.end_sec, + video_path=video_path, + use_file_source=use_file_infer and is_file, + frames=frames_copy, + ) + info = fut_h.result() + doc = fut_d.result() + + line = _format_result_row( + rank, + clip.start_sec, + clip.end_sec, + info, + product_map, + legacy_12_col=bool(args.legacy_12_col_only), + include_doctor_cols=include_doctor_cols, + doctor=doc, + ) + with out_path.open("a", encoding="utf-8") as f: + f.write(line + "\n") + _log_doctor_result(rank, doc) + log(f"[stream] rank={rank} 已写入") + + def _maybe_free_gpu() -> None: gc.collect() try: @@ -267,6 +393,17 @@ class StreamBasketOrchestrator: else: log("[stream] 白名单已关闭,使用全 41 类") + include_doctor_cols = stream_doctor_enabled(args) + doctor_svc: DoctorIdentityService | None = None + if include_doctor_cols: + try: + doctor_svc = DoctorIdentityService(args) + log("[stream] 医生身份识别已启用(段内与耗材并行)") + except Exception as exc: # noqa: BLE001 + log(f"[stream] 医生识别初始化失败,本 run 不写入医生列: {exc}") + include_doctor_cols = False + doctor_svc = None + cap = cv2.VideoCapture(source) if not cap.isOpened(): log(f"[stream] 无法打开流: {source}") @@ -347,23 +484,10 @@ class StreamBasketOrchestrator: fallback = str(getattr(args, "stream_infer_fallback", "cache")) log(f"[stream] 段内识别: JPEG 缓存帧(infer_fallback={fallback})") - header = "\t".join( - [ - "rank", - "start_sec", - "end_sec", - "product_id_top1", - "top1_name", - "top1_conf", - "product_id_top2", - "top2_name", - "top2_conf", - "product_id_top3", - "top3_name", - "top3_conf", - ] - ) - out_path.write_text(header + "\n", encoding="utf-8") + header_cols = list(_HAOCAI_COLS) + if include_doctor_cols: + header_cols.extend(_DOCTOR_COLS) + out_path.write_text("\t".join(header_cols) + "\n", encoding="utf-8") rank = 0 frame_idx = 0 @@ -371,33 +495,24 @@ class StreamBasketOrchestrator: nonlocal rank for clip in session.poll_ready_clips(): rank += 1 - log( - f"[stream] 识别 rank={rank} [{clip.start_sec:.3f},{clip.end_sec:.3f}] " - f"({len(clip.frames)} 帧)…" - ) - info = _infer_clip( + _process_one_clip( + rank, clip, det=det, hc=hc, - cap=infer_cap, + infer_cap=infer_cap, use_file_infer=use_file_infer, + is_file=is_file, + source=source, args=args, cls_names=cls_names, allowed_idx=allowed_idx, predict_kw=predict_kw, - rank=rank, + product_map=product_map, + out_path=out_path, + doctor_svc=doctor_svc, + include_doctor_cols=include_doctor_cols, ) - line = _format_result_row( - rank, - clip.start_sec, - clip.end_sec, - info, - product_map, - legacy_12_col=bool(args.legacy_12_col_only), - ) - with out_path.open("a", encoding="utf-8") as f: - f.write(line + "\n") - log(f"[stream] rank={rank} 已写入") session.push_frame(t0, first) process_ready() @@ -425,34 +540,31 @@ class StreamBasketOrchestrator: except KeyboardInterrupt: log("[stream] 用户中断") finally: + for clip in session.poll_ready_clips(): + rank += 1 + _process_one_clip( + rank, + clip, + det=det, + hc=hc, + infer_cap=infer_cap, + use_file_infer=use_file_infer, + is_file=is_file, + source=source, + args=args, + cls_names=cls_names, + allowed_idx=allowed_idx, + predict_kw=predict_kw, + product_map=product_map, + out_path=out_path, + doctor_svc=doctor_svc, + include_doctor_cols=include_doctor_cols, + ) cap.release() if infer_cap is not None: infer_cap.release() - - for clip in session.poll_ready_clips(): - rank += 1 - info = _infer_clip( - clip, - det=det, - hc=hc, - cap=infer_cap, - use_file_infer=use_file_infer, - args=args, - cls_names=cls_names, - allowed_idx=allowed_idx, - predict_kw=predict_kw, - rank=rank, - ) - line = _format_result_row( - rank, - clip.start_sec, - clip.end_sec, - info, - product_map, - legacy_12_col=bool(args.legacy_12_col_only), - ) - with out_path.open("a", encoding="utf-8") as f: - f.write(line + "\n") + if doctor_svc is not None: + doctor_svc.close() log(f"[stream] 结束,共 {rank} 段,结果: {out_path}") return 0 if rank > 0 or is_file else 0