doctor_identify

This commit is contained in:
hsz
2026-06-03 14:46:16 +08:00
parent 1569056904
commit 0d27ae415e
9 changed files with 991 additions and 168 deletions

View File

@@ -1,6 +1,6 @@
# 手术室耗材篮子识别包(离线 + 推流)
段内流程:**手检≥2 手 union→ 好坏帧门控 → 耗材分类**离线另含**医生识别**。
段内流程：**手检（≥2 手 union）→ 好坏帧门控 → 耗材分类**；**推流**在每个接触段内与耗材**并行**医生识别；**离线**在全片结束后追加一行医生信息。
`configs/default_config.yaml` 当前参数一致(`imgsz_det: 1920``contact+1~+6` 等)。
@@ -27,7 +27,7 @@ pip install -r requirements.txt
| 脚本 | 用途 |
|------|------|
| `main_basket.py` | **离线**:全片篮子接触分段 → Phase2 → gap 合并 → 医生识别 |
| `main_basket_stream.py` | **推流/本地 MP4 模拟推流**:逐帧触发 → 段内识别 → 实时写 TSV |
| `main_basket_stream.py` | **推流/本地 MP4 模拟推流**:逐帧触发 → 段内耗材+医生并行 → 实时写 15 列 TSV |
| `main_segments_offline.py` | 按 TSV 时间段对离线 MP4 重跑段内识别(校验用) |
## 1. 离线跑视频
@@ -59,6 +59,7 @@ python main_basket_stream.py \
- 本地 MP4`stream.infer_source: file` → 段内**回源 4K**(与离线一致)
- 真 RTSP无法 seek 时回退 JPEG 缓存(`cache_max_width: 1920`
- 医生识别：`doctor_identity.stream_enabled: true` 时每段并行识别，TSV 追加 `doctor_id` / `doctor_name` / `doctor_conf`
## 3. HEVC 视频
@@ -78,16 +79,19 @@ bash scripts/remux_hevc.sh /path/to/source.mp4
| `basket` | `iou_on: 0.03``confirm: 0.1``cooldown: 3`,窗口 contact+1~+6 |
| `stream` | 段窗口与 basket 一致;`infer_source: file` |
| `io` | `use_whitelist: false`(全 41 类) |
| `doctor_identity` | `enabled` / `stream_enabled`;推流用 `segment_sample_fps` 在段 `[start,end]` 内采样 |
## 模型文件（`weights/`）
- `hand_detect.pt` — 手部检测
- `goodbad_frame.pt` — 好坏帧门控
- `haocai_classify.pt` — 耗材分类
- `doctor_identity_package/doctor_info.pth` — 医生 ReID需同目录 `train_reid_contrastive.py`
## 输出格式
12 列 TSV + 离线末尾一行 `医生信息:...`(推流无医生行)。
- **推流**：15 列 TSV（12 列耗材 + `doctor_id` / `doctor_name` / `doctor_conf`），无末尾汇总行
- **离线**：12 列 TSV + 末尾一行 `医生信息:...`（全片中间窗口识别）
## 目录结构
@@ -99,7 +103,7 @@ bash scripts/remux_hevc.sh /path/to/source.mp4
├── configs/default_config.yaml
├── weights/ # 3 个 YOLO 权重
├── input/视频中的商品信息表.xlsx
├── doctor_identity_package/ # 医生识别(离线)
├── doctor_identity_package/ # 医生识别(离线整片 + 推流段内
├── src/ code/ # 编排与算法
├── output/ # 结果输出目录
├── setup.sh requirements.txt

View File

@@ -56,12 +56,14 @@ output:
doctor_identity:
enabled: true
stream_enabled: true
checkpoint: doctor_identity_package/doctor_info.pth
labels_csv: doctor_identity_package/labels.csv
pose_min_detection_confidence: 0.30
min_identity_confidence: 0.00
middle_seconds: 10.0
sample_fps: 3.0
segment_sample_fps: 3.0
pad_frac: 0.15
# 篮子接触分段main_basket.py / main_basket_stream.py

View File

@@ -139,6 +139,19 @@ def expand_bbox_with_padding(
return nx1, ny1, nx2, ny2
def sample_window_timestamps(t0: float, t1: float, sample_fps: float) -> list[float]:
    """Return evenly spaced timestamps covering the half-open window [t0, t1).

    An empty list is returned when the window is empty or reversed, or when
    ``sample_fps`` is not positive.
    """
    if sample_fps <= 0 or t1 <= t0:
        return []
    interval = 1.0 / sample_fps
    # Small epsilon keeps a sample that lands (numerically) on t1 out of the result.
    upper = t1 - 1e-6
    stamps: list[float] = []
    cursor = float(t0)
    while cursor < upper:
        stamps.append(cursor)
        cursor += interval
    return stamps
def sample_middle_timestamps(duration_sec: float, middle_seconds: float, sample_fps: float) -> list[float]:
if duration_sec <= 0 or middle_seconds <= 0 or sample_fps <= 0:
return []
@@ -146,13 +159,37 @@ def sample_middle_timestamps(duration_sec: float, middle_seconds: float, sample_
half = middle_seconds / 2.0
t0 = max(0.0, center - half)
t1 = min(duration_sec, center + half)
step = 1.0 / sample_fps
ts = []
t = t0
while t < t1 - 1e-6:
ts.append(t)
t += step
return ts
return sample_window_timestamps(t0, t1, sample_fps)
def _update_best_crop_from_frame(
    frame: np.ndarray,
    landmarker: PoseLandmarker,
    pad_frac: float,
    *,
    best_area: int,
    best_crop: np.ndarray | None,
) -> tuple[int, np.ndarray | None]:
    """Run pose detection on one frame and keep the largest padded person crop.

    Args:
        frame: Image passed through ``cv2.COLOR_BGR2RGB`` before detection.
        landmarker: Pose landmarker whose ``detect`` result exposes ``pose_landmarks``.
        pad_frac: Padding fraction forwarded to ``expand_bbox_with_padding``.
        best_area: Largest crop area seen so far.
        best_crop: Crop that produced ``best_area`` (``None`` if none yet).

    Returns:
        The (possibly updated) ``(best_area, best_crop)`` pair.
    """
    frame_h, frame_w = frame.shape[:2]
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    detection = landmarker.detect(
        mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
    )
    # Nothing detected -> the running best is returned untouched.
    for landmarks in detection.pose_landmarks or ():
        person_box = bbox_from_normalized_pose_landmarks(frame_w, frame_h, landmarks)
        if person_box is None:
            continue
        left, top, right, bottom = expand_bbox_with_padding(
            *person_box, frame_w, frame_h, pad_frac=pad_frac
        )
        candidate = frame[top:bottom, left:right]
        if candidate.size == 0:
            continue
        candidate_area = int((right - left) * (bottom - top))
        if candidate_area > best_area:
            # Copy so the crop does not keep a view onto the full frame.
            best_area, best_crop = candidate_area, candidate.copy()
    return best_area, best_crop
def pick_best_person_crop(
@@ -182,25 +219,9 @@ def pick_best_person_crop(
ok, frame = cap.read()
if not ok or frame is None:
continue
h, w = frame.shape[:2]
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
mp_img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
res = landmarker.detect(mp_img)
if not res.pose_landmarks:
continue
for lmk in res.pose_landmarks:
box = bbox_from_normalized_pose_landmarks(w, h, lmk)
if box is None:
continue
ex1, ey1, ex2, ey2 = expand_bbox_with_padding(*box, w, h, pad_frac=pad_frac)
crop = frame[ey1:ey2, ex1:ex2]
if crop.size == 0:
continue
area = int((ex2 - ex1) * (ey2 - ey1))
if area > best_area:
best_area = area
best_crop = crop.copy()
best_area, best_crop = _update_best_crop_from_frame(
frame, landmarker, pad_frac, best_area=best_area, best_crop=best_crop
)
cap.release()
if best_crop is None:
@@ -208,6 +229,63 @@ def pick_best_person_crop(
return best_crop
def pick_best_person_crop_in_window(
    video_path: Path,
    landmarker: PoseLandmarker,
    start_sec: float,
    end_sec: float,
    sample_fps: float,
    pad_frac: float,
) -> np.ndarray:
    """Seek through a video window and return the largest detected person crop.

    Timestamps come from ``sample_window_timestamps(start_sec, end_sec, sample_fps)``;
    frames that cannot be read are skipped.

    Raises:
        RuntimeError: If the window yields no timestamps, the video cannot be
            opened, or no person is detected in any sampled frame.
    """
    sample_times = sample_window_timestamps(start_sec, end_sec, sample_fps)
    if not sample_times:
        raise RuntimeError("No valid timestamps in segment window.")

    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    largest_area = -1
    largest_crop: np.ndarray | None = None
    try:
        for moment in sample_times:
            capture.set(cv2.CAP_PROP_POS_MSEC, moment * 1000.0)
            grabbed, image = capture.read()
            if grabbed and image is not None:
                largest_area, largest_crop = _update_best_crop_from_frame(
                    image,
                    landmarker,
                    pad_frac,
                    best_area=largest_area,
                    best_crop=largest_crop,
                )
    finally:
        # Always release the capture handle, even when detection raises.
        capture.release()

    if largest_crop is None:
        raise RuntimeError("No person detected in segment window.")
    return largest_crop
def pick_best_person_crop_from_frames(
    frames: list[tuple[float, np.ndarray]],
    landmarker: PoseLandmarker,
    start_sec: float,
    end_sec: float,
    pad_frac: float,
) -> np.ndarray:
    """Pick the largest person crop among cached ``(timestamp, frame)`` pairs.

    Only frames whose timestamp lies within ``[start_sec, end_sec]`` (with a
    1e-6 tolerance on both ends) are examined.

    Raises:
        RuntimeError: If no person is detected in any eligible frame.
    """
    earliest = start_sec - 1e-6
    latest = end_sec + 1e-6
    largest_area = -1
    largest_crop: np.ndarray | None = None
    for stamp, image in frames:
        outside_window = stamp < earliest or stamp > latest
        if outside_window:
            continue
        largest_area, largest_crop = _update_best_crop_from_frame(
            image,
            landmarker,
            pad_frac,
            best_area=largest_area,
            best_crop=largest_crop,
        )
    if largest_crop is None:
        raise RuntimeError("No person detected in cached segment frames.")
    return largest_crop
def build_label_to_pid(pid_to_label: dict) -> dict[int, str]:
label_to_pid: dict[int, str] = {}
for raw_pid, label in pid_to_label.items():
@@ -233,22 +311,8 @@ def load_name_mapping(labels_csv: Path) -> dict[str, str]:
return mapping
def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, float]:
if not checkpoint_path.is_file():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
num_classes = int(ckpt["num_classes"])
pid_to_label = ckpt.get("pid_to_label", {})
if not isinstance(pid_to_label, dict):
raise RuntimeError("Checkpoint missing valid pid_to_label dict.")
model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)
model.load_state_dict(ckpt["model_state"])
model.eval()
transform = transforms.Compose(
def _default_reid_transform() -> transforms.Compose:
return transforms.Compose(
[
transforms.Resize((256, 128)),
transforms.ToTensor(),
@@ -258,6 +322,37 @@ def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, flo
),
]
)
def load_reid_model(
    checkpoint_path: Path,
    device: torch.device | None = None,
) -> tuple[ReIDEmbedModel, torch.device, dict[int, str], transforms.Compose]:
    """Load the ReID checkpoint once so it can be reused across many inferences.

    Args:
        checkpoint_path: Checkpoint file containing ``num_classes``,
            ``model_state`` and (optionally) ``pid_to_label``.
        device: Target device; when omitted, CUDA is used if available, else CPU.

    Returns:
        ``(model in eval mode, device, label -> pid mapping, preprocessing transform)``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not a file.
        RuntimeError: If ``pid_to_label`` in the checkpoint is not a dict.
    """
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    target = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=target, weights_only=False)
    class_count = int(state["num_classes"])
    pid_mapping = state.get("pid_to_label", {})
    if not isinstance(pid_mapping, dict):
        raise RuntimeError("Checkpoint missing valid pid_to_label dict.")

    network = ReIDEmbedModel(num_classes=class_count, feat_dim=512).to(target)
    network.load_state_dict(state["model_state"])
    network.eval()
    return network, target, build_label_to_pid(pid_mapping), _default_reid_transform()
def run_inference_preloaded(
crop_bgr: np.ndarray,
model: ReIDEmbedModel,
device: torch.device,
label_to_pid: dict[int, str],
transform: transforms.Compose,
) -> tuple[str, float]:
crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
inp = transform(Image.fromarray(crop_rgb)).unsqueeze(0).to(device)
@@ -267,13 +362,17 @@ def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, flo
pred_label = int(torch.argmax(probs, dim=1).item())
conf = float(probs[0, pred_label].item())
label_to_pid = build_label_to_pid(pid_to_label)
raw_pid = label_to_pid.get(pred_label)
if raw_pid is None:
raise RuntimeError(f"Predicted label {pred_label} not found in pid mapping.")
return raw_pid, conf
def run_inference(crop_bgr: np.ndarray, checkpoint_path: Path) -> tuple[str, float]:
    """One-shot helper: load the ReID checkpoint, then classify a single crop.

    Returns the ``(raw_pid, confidence)`` pair produced by ``run_inference_preloaded``.
    """
    # load_reid_model returns (model, device, label_to_pid, transform) in the
    # same order run_inference_preloaded expects after the crop.
    loaded = load_reid_model(checkpoint_path)
    return run_inference_preloaded(crop_bgr, *loaded)
def main() -> int:
args = parse_args()
if not args.video.is_file():

View File

@@ -0,0 +1,458 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
人体身份重识别Person ReID对比学习训练脚本单机单文件
- DatasetMarket-1501 风格文件名解析 pid/cam
- PK 采样器:每个 batch 内 P 个身份 × 每张 K 样本;
- 模型ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头;
- 损失Batch-Hard Triplet + ID交叉熵联合
"""
from __future__ import annotations
import argparse
import random
import re
from pathlib import Path
from typing import Iterator
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import models, transforms
# ---------------------------------------------------------------------------
# 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502cam_id=1
# ---------------------------------------------------------------------------
# Leading "<pid>_c<cam>_" prefix, matched case-insensitively.
_NAME_RE = re.compile(r"^(?P<pid>\d+)_c(?P<cam>\d+)_", flags=re.I)


def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]:
    """Extract ``(person_id, camera_id)`` from a file stem (name without suffix).

    Returns ``(None, None)`` when the stem does not start with ``<digits>_c<digits>_``.
    """
    found = _NAME_RE.match(stem)
    if found is None:
        return None, None
    person_id = int(found["pid"])
    camera_id = int(found["cam"])
    return person_id, camera_id
class DoctorReIDDataset(Dataset):
    """Doctor ReID image dataset.

    Person and camera ids are parsed from each file name; raw person ids are
    remapped to contiguous labels ``0..num_classes-1`` (sorted by raw id).
    Each item is ``(image_tensor, label, camera_id)``.
    """

    def __init__(self, image_root: Path, augment: bool) -> None:
        """Scan ``image_root`` recursively and build labels and transforms.

        Args:
            image_root: Directory searched recursively for image files.
            augment: When true, use the training pipeline with random flip,
                color jitter and random erasing; otherwise resize + normalize only.

        Raises:
            RuntimeError: If no file with a parseable name is found.
        """
        self.image_root = Path(image_root).resolve()
        allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        candidates = sorted(
            entry
            for entry in self.image_root.rglob("*")
            if entry.is_file() and entry.suffix.lower() in allowed_suffixes
        )

        # Keep only files whose stem carries a parseable "<pid>_c<cam>_" prefix.
        kept_paths: list[Path] = []
        raw_pids: list[int] = []
        raw_cams: list[int] = []
        for entry in candidates:
            raw_pid, raw_cam = parse_market1501_style_name(entry.stem)
            if raw_pid is not None:
                kept_paths.append(entry)
                raw_pids.append(raw_pid)
                raw_cams.append(int(raw_cam))

        self.pid_to_label = {pid: index for index, pid in enumerate(sorted(set(raw_pids)))}
        self.labels: list[int] = [self.pid_to_label[pid] for pid in raw_pids]
        self.cam_ids: list[int] = raw_cams
        self.paths = kept_paths
        if not self.paths:
            raise RuntimeError(f"目录下未发现有效图像: {self.image_root}")

        # torchvision's Resize takes (height, width); 256x128 is the usual ReID input size.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        resize = transforms.Resize((256, 128))
        if not augment:
            steps = [resize, transforms.ToTensor(), normalize]
        else:
            steps = [
                resize,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02),
                transforms.ToTensor(),
                normalize,
                # RandomErasing operates on tensors, so it must follow ToTensor.
                transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)),
            ]
        self.transform = transforms.Compose(steps)

    def __len__(self) -> int:
        """Number of usable images."""
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Load image ``idx`` as RGB and return ``(tensor, label, camera_id)``."""
        picture = Image.open(self.paths[idx]).convert("RGB")
        tensor = self.transform(picture)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        camera = torch.tensor(self.cam_ids[idx], dtype=torch.long)
        return tensor, label, camera
class RandomIdentityPKSampler(Sampler[int]):
"""
PK / Random Identity SamplerRandom Identity Sampling
【原理简述】对比学习三元组需在 batch 内出现「同类多图」才能把 anchor / positive 找出来。
PK 策略:在每个 mini-batch 中固定结构为:
- 先随机选 **P** 个不同身份;
- 每个身份若无放回地抽 **K** 张图像(样本不足则从该身份放回随机抽满 K
则 batch_size = P * K。
这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet
Batch-Hard 才能选「最难的正对 / 最难的负样本」。
【注意】P 不能超过数据集中可用身份总数;若 K > 该类张数则用放回采样。
"""
def __init__(
self,
labels: list[int],
p: int,
k: int,
*,
seed: int = 0,
length: int | None = None,
) -> None:
super().__init__()
self.labels_np = np.array(labels)
self.p = int(p)
self.k = int(k)
self.seed = seed
self.epoch = 0
self.identities = sorted(np.unique(self.labels_np).tolist())
self.num_identities = len(self.identities)
idx_by_label: dict[int, list[int]] = {}
for i, lbl in enumerate(self.labels_np.tolist()):
idx_by_label.setdefault(int(lbl), []).append(i)
self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()}
if self.num_identities == 0:
raise RuntimeError("PKSampler: 无任何身份类别")
if self.p > self.num_identities:
raise ValueError(
f"P={self.p} 大于数据集身份数 {self.num_identities}"
" 请减小 --p。"
)
# 每个 epoch 内「迭代步数」batch 数(每次 batch 含 P×K 张图)
self.labels_flat = labels
if length is None:
# 约按「扫过一遍身份组合」的量级设一个稳定值
self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k)))
else:
self.num_batches_est = int(length)
self.batch_size = self.p * self.k
def __len__(self) -> int:
return self.num_batches_est * self.batch_size
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def __iter__(self) -> Iterator[int]:
rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF)
ids_arr = np.asarray(self.identities, dtype=np.int64)
for _batch_id in range(self.num_batches_est):
# 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回)
if ids_arr.size >= self.p:
chosen = rng.choice(ids_arr, size=self.p, replace=False)
else:
chosen = rng.choice(ids_arr, size=self.p, replace=True)
out: list[int] = []
for pid_pick in chosen.tolist():
pool = self.idx_by_label[int(pid_pick)]
if pool.size >= self.k:
idx_pick = rng.choice(pool, size=self.k, replace=False)
else:
idx_pick = rng.choice(pool, size=self.k, replace=True)
out.extend(int(t) for t in idx_pick)
yield from out
def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor:
    """Batch-hard triplet loss over Euclidean distances.

    For every anchor in the batch the hardest positive (farthest sample with
    the same label) and the hardest negative (closest sample with a different
    label) are selected, and ``relu(d_pos - d_neg + margin)`` is computed.
    Anchors lacking either a positive or a negative are skipped; the result is
    the mean over the remaining anchors, or a zero that stays attached to
    ``embeddings`` when no anchor qualifies.
    """
    emb32 = embeddings.float()
    pairwise = torch.cdist(emb32, emb32, p=2.0).clamp(min=1e-8)
    ids = labels.long()
    is_same = ids.unsqueeze(1).eq(ids.unsqueeze(0))

    per_anchor: list[torch.Tensor] = []
    for anchor in range(pairwise.size(0)):
        positives = is_same[anchor].clone()
        positives[anchor] = False  # an anchor is not its own positive
        if not positives.any():
            continue
        negatives = ~is_same[anchor]
        negatives[anchor] = False
        if not negatives.any():
            continue
        row = pairwise[anchor]
        farthest_positive = row[positives].max()
        nearest_negative = row[negatives].min()
        per_anchor.append(F.relu(farthest_positive - nearest_negative + margin))

    if per_anchor:
        return torch.stack(per_anchor).mean()
    # Keep the graph connected so callers can still call backward().
    return embeddings.sum() * 0.0
class ReIDEmbedModel(nn.Module):
    """ResNet-50 backbone (avgpool and fc removed) -> GAP -> embedding -> ID logits.

    The bottleneck is ``Linear(2048, feat_dim) -> BatchNorm1d -> ReLU``; the
    classifier maps the embedding to ``num_classes`` logits.
    """

    def __init__(self, num_classes: int, feat_dim: int = 512, *, pretrained: bool = True) -> None:
        """Build the network.

        Args:
            num_classes: Number of identity classes for the classifier head.
            feat_dim: Embedding dimension.
            pretrained: Initialise the backbone with ImageNet weights (default,
                matching the previous behaviour). Pass ``False`` when a
                checkpoint is loaded right afterwards, so construction does not
                need the ImageNet weights to be downloaded or cached.
        """
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet50(weights=weights)
        # Drop the last two children (avgpool, fc): output is [B, 2048, H/32, W/32].
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.bottleneck = nn.Sequential(
            nn.Linear(2048, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(embedding, logits)`` for a batch of images."""
        h = self.backbone(x)
        h = self.gap(h)
        h = h.view(h.size(0), -1)  # [B, 2048, 1, 1] -> [B, 2048]
        emb = self.bottleneck(h)
        logits = self.classifier(emb)
        return emb, logits
def collate_fn_pk(batch):
xs, ys, cams = zip(*batch, strict=True)
return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0)
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    margin: float,
    triplet_w: float,
    id_w: float,
) -> tuple[float, float]:
    """Run one training epoch with the joint triplet + ID cross-entropy loss.

    Args:
        model: Network returning ``(embedding, logits)``.
        loader: Yields ``(images, labels, camera_ids)`` batches; camera ids are unused.
        optim: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        margin: Margin for ``batch_hard_triplet_loss``.
        triplet_w: Weight of the triplet term.
        id_w: Weight of the cross-entropy term.

    Returns:
        ``(mean_triplet_loss, mean_id_loss)``, each averaged per sample over the
        epoch (0.0 when the loader is empty).
    """
    model.train()
    sum_t = 0.0
    sum_id = 0.0
    n = 0
    ce = nn.CrossEntropyLoss()
    for x, y, _cam in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        emb, logits = model(x)
        # The triplet loss is computed on L2-normalised embeddings.
        emb = F.normalize(emb, p=2, dim=1)
        loss_t = batch_hard_triplet_loss(emb, y, margin=margin)
        loss_id = ce(logits, y)
        loss = triplet_w * loss_t + id_w * loss_id
        optim.zero_grad()
        loss.backward()
        optim.step()
        bs = x.size(0)
        sum_t += loss_t.detach().item() * bs
        sum_id += loss_id.detach().item() * bs
        n += bs
    denom = max(n, 1)
    return sum_t / denom, sum_id / denom
@torch.no_grad()
def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Top-1 accuracy of the ID classifier head over ``loader``.

    The model is switched to eval mode. Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    for images, targets, _cam in loader:
        images = images.to(device)
        targets = targets.to(device)
        _emb, scores = model(images)
        predicted = scores.argmax(dim=1)
        hits += int((predicted == targets).sum().item())
        seen += targets.numel()
    return hits / max(seen, 1)
def resolve_image_root(cli: str | Path) -> Path:
    """Resolve the directory that holds the training images.

    Resolution order:
      1. ``<cli>/doctor_picture`` if that sub-directory exists;
      2. ``<cli>`` itself if it contains at least one image file.

    The image check is recursive and accepts the same suffixes the dataset
    scans for (jpg, jpeg, png, webp, bmp, case-insensitive), so any directory
    the dataset could read is accepted here as well.

    Raises:
        FileNotFoundError: If neither location qualifies.
    """
    root = Path(cli).resolve()
    nested = root / "doctor_picture"
    if nested.is_dir():
        return nested
    if root.is_dir():
        image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        has_image = any(
            entry.is_file() and entry.suffix.lower() in image_suffixes
            for entry in root.rglob("*")
        )
        if has_image:
            return root
    raise FileNotFoundError(f"未找到图像目录（可传 doctor_info_detect 或 doctor_picture: {cli}")
def parse_args():
    """Parse the command-line options of the ReID training script."""
    here = Path(__file__).resolve().parent
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--data-root", {"type": Path, "default": here, "help": "含 doctor_picture 或图片根目录的路径"}),
        ("--epochs", {"type": int, "default": 50, "help": "epoch 建议在 4060默认 50"}),
        (
            "--batch-p",
            {
                "type": int,
                "default": 5,
                "help": "PK 采样：每 batch 采样的身份数 P将自动不大于身份总数如默认 5 类）",
            },
        ),
        (
            "--batch-k",
            {"type": int, "default": 8, "help": "PK 采样：每位身份抽样张数 Kbatch_size=P×K"},
        ),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--triplet-margin", {"type": float, "default": 0.3}),
        ("--triplet-weight", {"type": float, "default": 1.0}),
        ("--id-weight", {"type": float, "default": 2.0, "help": "ID LossCrossEntropy权重"}),
        ("--workers", {"type": int, "default": 4}),
        ("--seed", {"type": int, "default": 42}),
        ("--save", {"type": Path, "default": here / "doctor_reid_best.pth", "help": "最佳权重路径"}),
    ]
    parser = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Train the doctor ReID model and keep the best checkpoint.

    Steps: seed the RNGs, build augmented/plain views of the same image
    directory, hold out a random validation split of ``max(32, 10%)`` images,
    train with PK sampling and the joint triplet + ID loss, and save a
    checkpoint whenever validation ID accuracy is at least as good as the
    best seen so far.

    Raises:
        RuntimeError: If the dataset is too small to leave any training sample
            after the validation split.
    """
    from torch.utils.data import Subset

    args = parse_args()
    rng = random.Random(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    img_root = resolve_image_root(args.data_root)
    ds_train_aug = DoctorReIDDataset(img_root, augment=True)
    ds_eval = DoctorReIDDataset(img_root, augment=False)

    n = len(ds_train_aug)
    perm = list(range(n))
    rng.shuffle(perm)
    n_val = max(32, int(0.1 * n))
    if n_val >= n:
        # Fail early with an actionable message instead of the sampler's
        # "no identity" error that an empty training split would trigger.
        raise RuntimeError(
            f"有效图像仅 {n} 张，划分 {n_val} 张验证集后训练集为空，请增加训练图像。"
        )
    val_ix = sorted(perm[:n_val])
    train_ix = perm[n_val:]

    labels_tr = [ds_train_aug.labels[i] for i in train_ix]
    num_classes = len(ds_train_aug.pid_to_label)
    # P cannot exceed the number of identities present in the training split.
    p_eff = min(args.batch_p, len(set(labels_tr)))

    # The sampler yields positions into labels_tr, i.e. indices of the Subset below.
    sampler = RandomIdentityPKSampler(
        labels_tr,
        p=p_eff,
        k=args.batch_k,
        seed=args.seed,
    )
    train_loader = DataLoader(
        Subset(ds_train_aug, train_ix),
        batch_size=sampler.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn_pk,
    )
    val_loader = DataLoader(
        Subset(ds_eval, val_ix),
        batch_size=64,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=collate_fn_pk,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6)

    # torch.save does not create missing directories.
    Path(args.save).parent.mkdir(parents=True, exist_ok=True)

    best_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        sampler.set_epoch(epoch)
        tr_t, tr_id = train_one_epoch(
            model,
            train_loader,
            optim,
            device,
            margin=args.triplet_margin,
            triplet_w=args.triplet_weight,
            id_w=args.id_weight,
        )
        scheduler.step()
        lr_now = optim.param_groups[0]["lr"]
        val_acc = estimate_id_accuracy(model, val_loader, device)
        print(
            f"epoch {epoch:03d}/{args.epochs} | "
            f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | "
            f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}"
        )
        # ">=" means ties overwrite the checkpoint with the later epoch.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "num_classes": num_classes,
                    "pid_to_label": ds_train_aug.pid_to_label,
                    "args": vars(args),
                },
                args.save,
            )
            print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})")
    print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}")
if __name__ == "__main__":
main()

View File

@@ -25,6 +25,7 @@ for w in hand_detect.pt goodbad_frame.pt haocai_classify.pt; do
test -f "weights/$w" && echo " OK weights/$w" || echo " 缺失 weights/$w"
done
test -f doctor_identity_package/doctor_info.pth && echo " OK doctor_info.pth" || echo " 缺失 doctor_info.pth"
test -f doctor_identity_package/train_reid_contrastive.py && echo " OK train_reid_contrastive.py" || echo " 缺失 train_reid_contrastive.py"
test -f input/视频中的商品信息表.xlsx && echo " OK Excel" || echo " 缺失 Excel"
echo ""

View File

@@ -129,6 +129,7 @@ def load_run_config(pack_root: Path, config_path: Path) -> Namespace:
gap_merge_enabled=bool(gm.get("enabled", False)),
gap_merge_max_gap_sec=float(gm.get("max_gap_sec", 2.0)),
doctor_identity_enabled=bool(did.get("enabled", True)),
doctor_identity_stream_enabled=bool(did.get("stream_enabled", True)),
doctor_identity_checkpoint=_rel(pack_root, doctor_ckpt_raw),
doctor_identity_labels_csv=_rel(pack_root, doctor_labels_raw),
doctor_identity_pose_min_detection_confidence=float(
@@ -137,6 +138,9 @@ def load_run_config(pack_root: Path, config_path: Path) -> Namespace:
doctor_identity_min_identity_confidence=float(did.get("min_identity_confidence", 0.0)),
doctor_identity_middle_seconds=float(did.get("middle_seconds", 10.0)),
doctor_identity_sample_fps=float(did.get("sample_fps", 3.0)),
doctor_identity_segment_sample_fps=float(
did.get("segment_sample_fps", did.get("sample_fps", 3.0))
),
doctor_identity_pad_frac=float(did.get("pad_frac", 0.15)),
basket_det_conf=float(bk.get("det_conf", p2["det_conf"])),
basket_contact_iou_threshold=legacy_contact_iou,

202
src/doctor_identity.py Normal file
View File

@@ -0,0 +1,202 @@
"""医生身份识别:离线整片 + 推流段内(与耗材并行)。"""
from __future__ import annotations
import importlib.util
import sys
from argparse import Namespace
from pathlib import Path
from typing import Any
import numpy as np
import torch
PACK_ROOT = Path(__file__).resolve().parent.parent
def _load_doctor_module(script_path: Path) -> Any:
spec = importlib.util.spec_from_file_location("doctor_identity_runtime", script_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"无法加载医生识别脚本: {script_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _doctor_script_path() -> Path:
    """Location of the doctor inference script inside the package root."""
    return PACK_ROOT.joinpath("doctor_identity_package", "infer_doctor_from_video.py")
class DoctorIdentityService:
    """Lazily loads the pose landmarker and ReID model once and reuses them.

    Resources are created on the first inference call and released by
    ``close()``. After ``close()`` the service may be used again; it reloads
    its resources on the next call.
    """

    def __init__(self, args: Namespace) -> None:
        """Store the run configuration; nothing is loaded yet."""
        self.args = args
        self._mod: Any | None = None
        self._landmarker: Any | None = None
        self._reid_model: Any | None = None
        self._reid_device: torch.device | None = None
        self._label_to_pid: dict[int, str] | None = None
        self._transform: Any | None = None
        self._name_map: dict[str, str] = {}

    def _ensure_loaded(self) -> Any:
        """Load the inference script, pose landmarker, ReID model and name map.

        Returns the loaded script module. State is committed only after every
        resource loaded successfully, so a failed attempt leaves nothing
        half-initialised.

        Raises:
            FileNotFoundError: If the script, checkpoint or labels CSV is missing.
        """
        if self._mod is not None:
            return self._mod
        script_path = _doctor_script_path()
        if not script_path.is_file():
            raise FileNotFoundError(f"缺少脚本: {script_path}")
        pack_dir = script_path.parent
        # The script imports siblings from its own directory.
        if str(pack_dir) not in sys.path:
            sys.path.insert(0, str(pack_dir))
        mod = _load_doctor_module(script_path)

        checkpoint = Path(self.args.doctor_identity_checkpoint).resolve()
        labels_csv = Path(self.args.doctor_identity_labels_csv).resolve()
        if not checkpoint.is_file():
            raise FileNotFoundError(f"缺少权重: {checkpoint}")
        if not labels_csv.is_file():
            raise FileNotFoundError(f"缺少标签映射: {labels_csv}")

        model_path = mod._ensure_pose_lite_model(pack_dir / ".mediapipe_models")
        opts = mod.PoseLandmarkerOptions(
            base_options=mod.BaseOptions(model_asset_path=str(model_path)),
            running_mode=mod.VisionRunningMode.IMAGE,
            min_pose_detection_confidence=float(
                self.args.doctor_identity_pose_min_detection_confidence
            ),
        )
        landmarker = mod.PoseLandmarker.create_from_options(opts)
        try:
            reid_model, reid_device, label_to_pid, transform = mod.load_reid_model(checkpoint)
            name_map = mod.load_name_mapping(labels_csv)
        except Exception:
            # Do not leak the landmarker when a later loading step fails.
            landmarker.close()
            raise

        self._landmarker = landmarker
        self._reid_model = reid_model
        self._reid_device = reid_device
        self._label_to_pid = label_to_pid
        self._transform = transform
        self._name_map = name_map
        self._mod = mod
        return mod

    def close(self) -> None:
        """Release the landmarker and drop all loaded state (safe to call twice).

        ``_mod`` is reset as well so that a later inference call reloads
        everything instead of running with a closed (``None``) landmarker.
        """
        if self._landmarker is not None:
            self._landmarker.close()
        self._landmarker = None
        self._reid_model = None
        self._reid_device = None
        self._label_to_pid = None
        self._transform = None
        self._name_map = {}
        self._mod = None

    def _format_result(self, raw_pid: str, conf: float) -> dict[str, Any]:
        """Build the success payload, flagging confidences below the configured minimum."""
        min_conf = float(self.args.doctor_identity_min_identity_confidence)
        name = self._name_map.get(str(raw_pid), "")
        return {
            "ok": True,
            "doctor_id": str(raw_pid),
            "doctor_name": name,
            "doctor_conf": conf,
            "low_confidence": conf < min_conf,
        }

    def infer_segment(
        self,
        *,
        start_sec: float,
        end_sec: float,
        video_path: Path | None = None,
        use_file_source: bool = False,
        frames: list[tuple[float, np.ndarray]] | None = None,
    ) -> dict[str, Any]:
        """Identify the doctor within one segment.

        The video file is used when ``use_file_source`` is set and
        ``video_path`` exists; otherwise cached ``frames`` are used.

        Returns:
            The ``_format_result`` payload on success, or
            ``{"ok": False, "reason": ...}`` on any failure (never raises).
        """
        try:
            mod = self._ensure_loaded()
            pad_frac = float(self.args.doctor_identity_pad_frac)
            # Fall back to the whole-video sample rate when no segment rate is configured.
            sample_fps = float(
                getattr(
                    self.args,
                    "doctor_identity_segment_sample_fps",
                    getattr(self.args, "doctor_identity_sample_fps", 3.0),
                )
            )
            if use_file_source and video_path is not None and video_path.is_file():
                best_crop = mod.pick_best_person_crop_in_window(
                    video_path,
                    self._landmarker,
                    start_sec,
                    end_sec,
                    sample_fps,
                    pad_frac,
                )
            elif frames:
                best_crop = mod.pick_best_person_crop_from_frames(
                    frames,
                    self._landmarker,
                    start_sec,
                    end_sec,
                    pad_frac,
                )
            else:
                return {"ok": False, "reason": "无可用视频源或缓存帧"}
            raw_pid, conf = mod.run_inference_preloaded(
                best_crop,
                self._reid_model,
                self._reid_device,
                self._label_to_pid,
                self._transform,
            )
            return self._format_result(raw_pid, conf)
        except Exception as exc:  # noqa: BLE001 - best effort: a failed segment must not stop the stream
            return {"ok": False, "reason": str(exc)}

    def infer_whole_video(self, video_path: Path) -> str:
        """Identify the doctor from the middle window of a video.

        Returns a display string; failures are reported in the string rather
        than raised.
        """
        if not bool(getattr(self.args, "doctor_identity_enabled", True)):
            return "未启用"
        try:
            mod = self._ensure_loaded()
            best_crop = mod.pick_best_person_crop(
                video_path=video_path,
                landmarker=self._landmarker,
                middle_seconds=float(self.args.doctor_identity_middle_seconds),
                sample_fps=float(self.args.doctor_identity_sample_fps),
                pad_frac=float(self.args.doctor_identity_pad_frac),
            )
            raw_pid, conf = mod.run_inference_preloaded(
                best_crop,
                self._reid_model,
                self._reid_device,
                self._label_to_pid,
                self._transform,
            )
            res = self._format_result(raw_pid, conf)
            suffix = " [低置信度]" if res.get("low_confidence") else ""
            name = res.get("doctor_name") or ""
            if name:
                return f"{name} (id={raw_pid}, conf={conf:.4f}){suffix}"
            return f"doctor_id={raw_pid} (conf={conf:.4f}){suffix}"
        except Exception as exc:  # noqa: BLE001 - reported as text in the output file
            return f"识别失败（{exc}"
def stream_doctor_enabled(args: Namespace) -> bool:
    """True when doctor identification is enabled both globally and for streaming.

    Either flag defaults to enabled when it is absent from ``args``.
    """
    if not bool(getattr(args, "doctor_identity_enabled", True)):
        return False
    return bool(getattr(args, "doctor_identity_stream_enabled", True))
def infer_doctor_text_offline(args: Namespace, video_path: Path) -> str:
    """Offline entry point: validate resources, then return the doctor info text.

    Missing resources are reported in the returned string instead of raising.
    The service is always closed before returning.
    """
    enabled = bool(getattr(args, "doctor_identity_enabled", True))
    if not enabled:
        return "未启用"

    weights_file = Path(args.doctor_identity_checkpoint).resolve()
    if not weights_file.is_file():
        return f"识别失败（缺少权重: {weights_file}"
    mapping_file = Path(args.doctor_identity_labels_csv).resolve()
    if not mapping_file.is_file():
        return f"识别失败（缺少标签映射: {mapping_file}"
    script_file = _doctor_script_path()
    if not script_file.is_file():
        return f"识别失败（缺少脚本: {script_file}"

    service = DoctorIdentityService(args)
    try:
        return service.infer_whole_video(video_path.resolve())
    finally:
        service.close()

View File

@@ -1,7 +1,6 @@
"""主流程编排:与仓库 main_pipeline.PipelineManager 逻辑一致,参数来自 YAMLSimpleNamespace"""
from __future__ import annotations
import importlib.util
import tempfile
from argparse import Namespace
from pathlib import Path
@@ -26,69 +25,11 @@ from run_segments_consumable_vote import pad_box_bottom_only as _pad_box
from ultralytics import YOLO
from basket_segmenter import build_segments_from_basket
from doctor_identity import infer_doctor_text_offline
from pack_utils import load_allowed_names_from_excel, log, resolve_allowed_class_idx
from stream_orchestrator import _haocai_infer_kwargs
def _load_doctor_module(script_path: Path) -> Any:
spec = importlib.util.spec_from_file_location("doctor_identity_runtime", script_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"无法加载医生识别脚本: {script_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _infer_doctor_text(args: Namespace, video_path: Path) -> str:
if not bool(getattr(args, "doctor_identity_enabled", True)):
return "未启用"
checkpoint = Path(args.doctor_identity_checkpoint).resolve()
labels_csv = Path(args.doctor_identity_labels_csv).resolve()
if not checkpoint.is_file():
return f"识别失败(缺少权重: {checkpoint}"
if not labels_csv.is_file():
return f"识别失败(缺少标签映射: {labels_csv}"
pack_root = Path(__file__).resolve().parent.parent
script_path = pack_root / "doctor_identity_package" / "infer_doctor_from_video.py"
if not script_path.is_file():
return f"识别失败(缺少脚本: {script_path}"
try:
doctor_mod = _load_doctor_module(script_path)
model_path = doctor_mod._ensure_pose_lite_model(script_path.parent / ".mediapipe_models")
opts = doctor_mod.PoseLandmarkerOptions(
base_options=doctor_mod.BaseOptions(model_asset_path=str(model_path)),
running_mode=doctor_mod.VisionRunningMode.IMAGE,
min_pose_detection_confidence=float(
args.doctor_identity_pose_min_detection_confidence
),
)
landmarker = doctor_mod.PoseLandmarker.create_from_options(opts)
try:
best_crop = doctor_mod.pick_best_person_crop(
video_path=video_path,
landmarker=landmarker,
middle_seconds=float(args.doctor_identity_middle_seconds),
sample_fps=float(args.doctor_identity_sample_fps),
pad_frac=float(args.doctor_identity_pad_frac),
)
finally:
landmarker.close()
raw_pid, conf = doctor_mod.run_inference(best_crop, checkpoint)
min_conf = float(args.doctor_identity_min_identity_confidence)
name_map = doctor_mod.load_name_mapping(labels_csv)
doctor_name = name_map.get(str(raw_pid), "")
suffix = " [低置信度]" if conf < min_conf else ""
if doctor_name:
return f"{doctor_name} (id={raw_pid}, conf={conf:.4f}){suffix}"
return f"doctor_id={raw_pid} (conf={conf:.4f}){suffix}"
except Exception as exc: # noqa: BLE001
return f"识别失败({exc}"
def _resolve_allowed_names(args: Namespace, excel_path: Path) -> list[str] | None:
if not getattr(args, "use_whitelist", True):
return []
@@ -394,7 +335,7 @@ class PipelineManager:
cap.release()
log("医生识别:开始执行…")
doctor_text = _infer_doctor_text(args, video_path)
doctor_text = infer_doctor_text_offline(args, video_path)
log(f"医生识别:{doctor_text}")
lines_out.append(f"医生信息:{doctor_text}")

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import gc
import time
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
@@ -23,9 +24,26 @@ from basket_segmenter import (
_select_basket_roi_tkinter,
save_basket_roi_json,
)
from doctor_identity import DoctorIdentityService, stream_doctor_enabled
from pack_utils import log, resolve_allowed_class_idx
from stream_basket_session import CachedClip, StreamBasketSession
_HAOCAI_COLS = [
"rank",
"start_sec",
"end_sec",
"product_id_top1",
"top1_name",
"top1_conf",
"product_id_top2",
"top2_name",
"top2_conf",
"product_id_top3",
"top3_name",
"top3_conf",
]
_DOCTOR_COLS = ["doctor_id", "doctor_name", "doctor_conf"]
def _validate_stream_weights(args: Namespace) -> bool:
for p, lab in (
@@ -65,6 +83,14 @@ def _resolve_basket_roi(
return roi
def _doctor_row_cells(doc: dict[str, Any] | None) -> list[str]:
if doc is None or not doc.get("ok"):
return ["", "", ""]
conf = doc.get("doctor_conf")
conf_s = f"{float(conf):.6f}" if conf is not None else ""
return [str(doc.get("doctor_id", "")), str(doc.get("doctor_name", "")), conf_s]
def _format_result_row(
rank: int,
t0: float,
@@ -73,6 +99,8 @@ def _format_result_row(
product_map: dict[str, str],
*,
legacy_12_col: bool,
include_doctor_cols: bool = False,
doctor: dict[str, Any] | None = None,
) -> str:
sep = "\t"
if not info.get("ok"):
@@ -93,6 +121,8 @@ def _format_result_row(
]
if not legacy_12_col:
row.extend(["", ""])
if include_doctor_cols:
row.extend(_doctor_row_cells(doctor))
return sep.join(row)
n1, n2, n3 = info["top_names"]
@@ -120,9 +150,105 @@ def _format_result_row(
]
if not legacy_12_col:
row.extend(["", ""])
if include_doctor_cols:
row.extend(_doctor_row_cells(doctor))
return sep.join(row)
def _log_doctor_result(rank: int, doc: dict[str, Any] | None) -> None:
if doc is None:
return
if doc.get("ok"):
name = doc.get("doctor_name") or doc.get("doctor_id", "")
conf = doc.get("doctor_conf", 0.0)
low = " [低置信度]" if doc.get("low_confidence") else ""
log(f"[stream] rank={rank} 医生: {name} (id={doc.get('doctor_id')}, conf={conf:.4f}){low}")
else:
log(f"[stream] rank={rank} 医生识别失败: {doc.get('reason', '')}")
def _process_one_clip(
    rank: int,
    clip: CachedClip,
    *,
    det: YOLO,
    hc: HaocaiOnlyClassifier,
    infer_cap: cv2.VideoCapture | None,
    use_file_infer: bool,
    is_file: bool,
    source: str,
    args: Namespace,
    cls_names: dict,
    allowed_idx: frozenset[int] | None,
    predict_kw: dict[str, Any],
    product_map: dict[str, str],
    out_path: Path,
    doctor_svc: DoctorIdentityService | None,
    include_doctor_cols: bool,
) -> None:
    """Recognise one ready clip and append its result row to the TSV.

    Consumables inference (``_infer_clip``) always runs. When ``doctor_svc``
    is given, ``doctor_svc.infer_segment`` is submitted to the same
    two-worker thread pool so both run for the segment side by side.

    Args:
        rank: Sequence number of the segment (first output column).
        clip: The clip to process; ``start_sec``, ``end_sec`` and ``frames``
            are read here.
        det, hc, cls_names, allowed_idx, predict_kw, args: Passed through to
            ``_infer_clip`` unchanged.
        infer_cap: Passed to ``_infer_clip`` as ``cap``.
        use_file_infer: Passed to ``_infer_clip``; together with ``is_file``
            it also decides ``use_file_source`` for the doctor service.
        is_file: Whether ``source`` is a local file; if so its resolved path
            is handed to the doctor service.
        source: Stream URL or file path.
        product_map: Passed to ``_format_result_row``.
        out_path: TSV file the row is appended to.
        doctor_svc: Doctor identification service, or ``None`` to skip it.
        include_doctor_cols: Whether the row carries the doctor columns.
    """
    log(
        f"[stream] 识别 rank={rank} [{clip.start_sec:.3f},{clip.end_sec:.3f}] "
        f"({len(clip.frames)} 帧)…"
    )
    # Snapshot of the cached frames for the doctor service, only needed when
    # inference does not read from the file.
    # NOTE(review): presumably copied so the doctor thread does not share the
    # list object with _infer_clip; confirm.
    frames_copy = list(clip.frames) if doctor_svc is not None and not use_file_infer else None
    video_path = Path(source).resolve() if is_file else None

    # Keyword arguments shared by the sequential and the pooled call.
    infer_kwargs: dict[str, Any] = {
        "det": det,
        "hc": hc,
        "cap": infer_cap,
        "use_file_infer": use_file_infer,
        "args": args,
        "cls_names": cls_names,
        "allowed_idx": allowed_idx,
        "predict_kw": predict_kw,
        "rank": rank,
    }

    doc: dict[str, Any] | None = None
    if doctor_svc is None:
        info = _infer_clip(clip, **infer_kwargs)
    else:
        with ThreadPoolExecutor(max_workers=2) as pool:
            fut_h = pool.submit(_infer_clip, clip, **infer_kwargs)
            fut_d = pool.submit(
                doctor_svc.infer_segment,
                start_sec=clip.start_sec,
                end_sec=clip.end_sec,
                video_path=video_path,
                use_file_source=use_file_infer and is_file,
                frames=frames_copy,
            )
            info = fut_h.result()
            try:
                doc = fut_d.result()
            except Exception as exc:  # noqa: BLE001
                # Doctor identification is auxiliary: a failure there must not
                # discard the consumables result. Record it as a failed result
                # so the row gets empty doctor cells and the reason is logged.
                doc = {"ok": False, "reason": str(exc)}

    line = _format_result_row(
        rank,
        clip.start_sec,
        clip.end_sec,
        info,
        product_map,
        legacy_12_col=bool(args.legacy_12_col_only),
        include_doctor_cols=include_doctor_cols,
        doctor=doc,
    )
    with out_path.open("a", encoding="utf-8") as f:
        f.write(line + "\n")
    _log_doctor_result(rank, doc)
    log(f"[stream] rank={rank} 已写入")
def _maybe_free_gpu() -> None:
gc.collect()
try:
@@ -267,6 +393,17 @@ class StreamBasketOrchestrator:
else:
log("[stream] 白名单已关闭,使用全 41 类")
include_doctor_cols = stream_doctor_enabled(args)
doctor_svc: DoctorIdentityService | None = None
if include_doctor_cols:
try:
doctor_svc = DoctorIdentityService(args)
log("[stream] 医生身份识别已启用(段内与耗材并行)")
except Exception as exc: # noqa: BLE001
log(f"[stream] 医生识别初始化失败,本 run 不写入医生列: {exc}")
include_doctor_cols = False
doctor_svc = None
cap = cv2.VideoCapture(source)
if not cap.isOpened():
log(f"[stream] 无法打开流: {source}")
@@ -347,23 +484,10 @@ class StreamBasketOrchestrator:
fallback = str(getattr(args, "stream_infer_fallback", "cache"))
log(f"[stream] 段内识别: JPEG 缓存帧infer_fallback={fallback}")
header = "\t".join(
[
"rank",
"start_sec",
"end_sec",
"product_id_top1",
"top1_name",
"top1_conf",
"product_id_top2",
"top2_name",
"top2_conf",
"product_id_top3",
"top3_name",
"top3_conf",
]
)
out_path.write_text(header + "\n", encoding="utf-8")
header_cols = list(_HAOCAI_COLS)
if include_doctor_cols:
header_cols.extend(_DOCTOR_COLS)
out_path.write_text("\t".join(header_cols) + "\n", encoding="utf-8")
rank = 0
frame_idx = 0
@@ -371,33 +495,24 @@ class StreamBasketOrchestrator:
nonlocal rank
for clip in session.poll_ready_clips():
rank += 1
log(
f"[stream] 识别 rank={rank} [{clip.start_sec:.3f},{clip.end_sec:.3f}] "
f"({len(clip.frames)} 帧)…"
)
info = _infer_clip(
_process_one_clip(
rank,
clip,
det=det,
hc=hc,
cap=infer_cap,
infer_cap=infer_cap,
use_file_infer=use_file_infer,
is_file=is_file,
source=source,
args=args,
cls_names=cls_names,
allowed_idx=allowed_idx,
predict_kw=predict_kw,
rank=rank,
product_map=product_map,
out_path=out_path,
doctor_svc=doctor_svc,
include_doctor_cols=include_doctor_cols,
)
line = _format_result_row(
rank,
clip.start_sec,
clip.end_sec,
info,
product_map,
legacy_12_col=bool(args.legacy_12_col_only),
)
with out_path.open("a", encoding="utf-8") as f:
f.write(line + "\n")
log(f"[stream] rank={rank} 已写入")
session.push_frame(t0, first)
process_ready()
@@ -425,34 +540,31 @@ class StreamBasketOrchestrator:
except KeyboardInterrupt:
log("[stream] 用户中断")
finally:
for clip in session.poll_ready_clips():
rank += 1
_process_one_clip(
rank,
clip,
det=det,
hc=hc,
infer_cap=infer_cap,
use_file_infer=use_file_infer,
is_file=is_file,
source=source,
args=args,
cls_names=cls_names,
allowed_idx=allowed_idx,
predict_kw=predict_kw,
product_map=product_map,
out_path=out_path,
doctor_svc=doctor_svc,
include_doctor_cols=include_doctor_cols,
)
cap.release()
if infer_cap is not None:
infer_cap.release()
for clip in session.poll_ready_clips():
rank += 1
info = _infer_clip(
clip,
det=det,
hc=hc,
cap=infer_cap,
use_file_infer=use_file_infer,
args=args,
cls_names=cls_names,
allowed_idx=allowed_idx,
predict_kw=predict_kw,
rank=rank,
)
line = _format_result_row(
rank,
clip.start_sec,
clip.end_sec,
info,
product_map,
legacy_12_col=bool(args.legacy_12_col_only),
)
with out_path.open("a", encoding="utf-8") as f:
f.write(line + "\n")
if doctor_svc is not None:
doctor_svc.close()
log(f"[stream] 结束,共 {rank} 段,结果: {out_path}")
return 0 if rank > 0 or is_file else 0