#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 人体身份重识别(Person ReID)对比学习训练脚本(单机单文件)。 - Dataset:Market-1501 风格文件名解析 pid/cam; - PK 采样器:每个 batch 内 P 个身份 × 每张 K 样本; - 模型:ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头; - 损失:Batch-Hard Triplet + ID(交叉熵)联合; """ from __future__ import annotations import argparse import random import re from pathlib import Path from typing import Iterator import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader, Dataset, Sampler from torchvision import models, transforms # --------------------------------------------------------------------------- # 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502,cam_id=1 # --------------------------------------------------------------------------- _NAME_RE = re.compile( r"^(?P\d+)_c(?P\d+)_", flags=re.I, ) def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]: """从「不含后缀」的文件名 stem 中提取身份 ID、机位 ID。""" m = _NAME_RE.match(stem) if not m: return None, None return int(m.group("pid")), int(m.group("cam")) class DoctorReIDDataset(Dataset): """医生 ReID:解析 pid/cam,将 pid 重映射到 0..num_classes-1。""" def __init__(self, image_root: Path, augment: bool) -> None: self.image_root = Path(image_root).resolve() exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} paths = sorted( p for p in self.image_root.rglob("*") if p.is_file() and p.suffix.lower() in exts ) pid_raw_list: list[int] = [] cam_raw_list: list[int] = [] valid_paths: list[Path] = [] for p in paths: pid_raw, cam_raw = parse_market1501_style_name(p.stem) if pid_raw is None: continue valid_paths.append(p) pid_raw_list.append(pid_raw) cam_raw_list.append(int(cam_raw)) unique_pids = sorted(set(pid_raw_list)) self.pid_to_label = {pid: i for i, pid in enumerate(unique_pids)} self.labels: list[int] = [self.pid_to_label[r] for r in pid_raw_list] self.cam_ids: list[int] = cam_raw_list self.paths = valid_paths if len(self.paths) == 0: raise RuntimeError(f"目录下未发现有效图像: {self.image_root}") # Resize(128,256) 在 torchvision 中为 (height, width) → (256,128) 常见于 ReID。 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if augment: self.transform = transforms.Compose( [ transforms.Resize((256, 128)), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter( brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02 ), transforms.ToTensor(), normalize, # RandomErasure 需作用在 Tensor 上,故放在 ToTensor 之后; transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)), ] ) else: self.transform = transforms.Compose( [ transforms.Resize((256, 128)), transforms.ToTensor(), normalize, ] ) def __len__(self) -> int: return len(self.paths) def __getitem__(self, idx: int): img = Image.open(self.paths[idx]).convert("RGB") x = self.transform(img) y = torch.tensor(self.labels[idx], dtype=torch.long) cam = torch.tensor(self.cam_ids[idx], dtype=torch.long) return x, y, cam class RandomIdentityPKSampler(Sampler[int]): """ PK / Random Identity Sampler(Random Identity Sampling) 【原理简述】对比学习三元组需在 batch 内出现「同类多图」才能把 anchor / positive 找出来。 PK 策略:在每个 mini-batch 中固定结构为: - 先随机选 **P** 个不同身份; - 每个身份若无放回地抽 **K** 张图像(样本不足则从该身份放回随机抽满 K); 则 batch_size = P * K。 这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet, Batch-Hard 才能选「最难的正对 / 最难的负样本」。 【注意】P 不能超过数据集中可用身份总数;若 K > 该类张数则用放回采样。 """ def __init__( self, labels: list[int], p: int, k: int, *, seed: int = 0, length: int | None = None, ) -> None: super().__init__() self.labels_np = np.array(labels) self.p = int(p) self.k = int(k) self.seed = seed self.epoch = 0 self.identities = sorted(np.unique(self.labels_np).tolist()) self.num_identities = len(self.identities) idx_by_label: dict[int, list[int]] = {} for i, lbl in enumerate(self.labels_np.tolist()): idx_by_label.setdefault(int(lbl), []).append(i) self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()} if self.num_identities == 0: raise RuntimeError("PKSampler: 无任何身份类别") if self.p > self.num_identities: raise ValueError( f"P={self.p} 大于数据集身份数 {self.num_identities}。" " 请减小 --p。" ) # 每个 epoch 内「迭代步数」:batch 数(每次 batch 含 P×K 张图) self.labels_flat = labels if length is None: # 约按「扫过一遍身份组合」的量级设一个稳定值 self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k))) else: self.num_batches_est = int(length) self.batch_size = self.p * self.k def __len__(self) -> int: return self.num_batches_est * self.batch_size def set_epoch(self, epoch: int) -> None: self.epoch = epoch def __iter__(self) -> Iterator[int]: rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF) ids_arr = np.asarray(self.identities, dtype=np.int64) for _batch_id in range(self.num_batches_est): # 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回) if ids_arr.size >= self.p: chosen = rng.choice(ids_arr, size=self.p, replace=False) else: chosen = rng.choice(ids_arr, size=self.p, replace=True) out: list[int] = [] for pid_pick in chosen.tolist(): pool = self.idx_by_label[int(pid_pick)] if pool.size >= self.k: idx_pick = rng.choice(pool, size=self.k, replace=False) else: idx_pick = rng.choice(pool, size=self.k, replace=True) out.extend(int(t) for t in idx_pick) yield from out def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor: """ Batch-Hard Triplet Loss(Hermans ECCV17 一类的标准形式) 对每个 anchor(batch 里的每个样本 i)在同一 batch 中选: - Positive:与它同身份的样本 j 里面,距离**最大**的那一个( hardest positive ); - Negative:与它不同身份的样本里,距离**最小**的那一个( hardest negative ); 单项损失通常为 relu(d_pos_hard - d_neg_hard + margin) 对每个有效 anchor 求平均。 【数学含义】拉大「最难正对」相对「最易负样本」的间隔。 【实现】使用欧氏距离;若某样本在 batch 内无「异类」(极少见)或无「同类第二张」则不参与均值。 """ dist = torch.cdist(embeddings.float(), embeddings.float(), p=2.0).clamp(min=1e-8) bs = dist.size(0) lbl = labels.long() same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0)) losses: list[torch.Tensor] = [] for i in range(bs): same_pos = same[i].clone() same_pos[i] = False if not same_pos.any(): continue hardest_pos = dist[i][same_pos].max() neg_mask = ~same[i] neg_mask[i] = False if not neg_mask.any(): continue hardest_neg = dist[i][neg_mask].min() losses.append(F.relu(hardest_pos - hardest_neg + margin)) if not losses: return embeddings.sum() * 0.0 return torch.stack(losses).mean() class ReIDEmbedModel(nn.Module): """ResNet50 预训练 backbone(去掉 fc)→ BN 512 嵌入 → logits(num_classes)。""" def __init__(self, num_classes: int, feat_dim: int = 512) -> None: super().__init__() backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) self.backbone = nn.Sequential(*list(backbone.children())[:-2]) # 到 GAP 前,输出 [B,2048,7,7] self.gap = nn.AdaptiveAvgPool2d((1, 1)) self.bottleneck = nn.Sequential( nn.Linear(2048, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU(inplace=True), ) self.classifier = nn.Linear(feat_dim, num_classes) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: h = self.backbone(x) h = self.gap(h) h = h.view(h.size(0), -1) emb = self.bottleneck(h) logits = self.classifier(emb) return emb, logits def collate_fn_pk(batch): xs, ys, cams = zip(*batch, strict=True) return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0) def train_one_epoch( model: nn.Module, loader: DataLoader, optim: torch.optim.Optimizer, device: torch.device, margin: float, triplet_w: float, id_w: float, ) -> tuple[float, float, float]: model.train() sum_t = 0.0 sum_id = 0.0 n = 0 ce = nn.CrossEntropyLoss() for x, y, _cam in loader: x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) emb, logits = model(x) emb = F.normalize(emb, p=2, dim=1) loss_t = batch_hard_triplet_loss(emb, y, margin=margin) loss_id = ce(logits, y) loss = triplet_w * loss_t + id_w * loss_id optim.zero_grad() loss.backward() optim.step() bs = x.size(0) sum_t += loss_t.detach().item() * bs sum_id += loss_id.detach().item() * bs n += bs return sum_t / max(n, 1), sum_id / max(n, 1) @torch.no_grad() def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float: model.eval() correct = 0 total = 0 for x, y, _ in loader: x = x.to(device) y = y.to(device) _, logits = model(x) pred = logits.argmax(dim=1) correct += int((pred == y).sum().item()) total += y.numel() return correct / max(total, 1) def resolve_image_root(cli: str | Path) -> Path: p = Path(cli).resolve() sub = p / "doctor_picture" if sub.is_dir(): return sub if p.is_dir(): jpgs = list(p.glob("*.jpg")) if len(jpgs) > 0: return p raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture): {cli}") def parse_args(): ap = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练") ap.add_argument( "--data-root", type=Path, default=Path(__file__).resolve().parent, help="含 doctor_picture 或图片根目录的路径", ) ap.add_argument("--epochs", type=int, default=50, help="epoch 建议在 40–60(默认 50)") ap.add_argument( "--batch-p", type=int, default=5, help="PK 采样:每 batch 采样的身份数 P(将自动不大于身份总数,如默认 5 类)", ) ap.add_argument( "--batch-k", type=int, default=8, help="PK 采样:每位身份抽样张数 K;batch_size=P×K", ) ap.add_argument("--lr", type=float, default=3e-4) ap.add_argument("--triplet-margin", type=float, default=0.3) ap.add_argument("--triplet-weight", type=float, default=1.0) ap.add_argument("--id-weight", type=float, default=2.0, help="ID Loss(CrossEntropy)权重") ap.add_argument("--workers", type=int, default=4) ap.add_argument("--seed", type=int, default=42) ap.add_argument( "--save", type=Path, default=Path(__file__).resolve().parent / "doctor_reid_best.pth", help="最佳权重路径", ) return ap.parse_args() def main() -> None: args = parse_args() rng = random.Random(args.seed) torch.manual_seed(args.seed) np.random.seed(args.seed) img_root = resolve_image_root(args.data_root) ds_train_aug = DoctorReIDDataset(img_root, augment=True) ds_eval = DoctorReIDDataset(img_root, augment=False) n = len(ds_train_aug) perm = list(range(n)) rng.shuffle(perm) n_val = max(32, int(0.1 * n)) val_ix = sorted(perm[:n_val]) train_ix = perm[n_val:] labels_tr = [ds_train_aug.labels[i] for i in train_ix] num_classes = len(ds_train_aug.pid_to_label) p_eff = min(args.batch_p, len(set(labels_tr))) from torch.utils.data import Subset sampler = RandomIdentityPKSampler( labels_tr, p=p_eff, k=args.batch_k, seed=args.seed, ) train_loader = DataLoader( Subset(ds_train_aug, train_ix), batch_size=sampler.batch_size, sampler=sampler, num_workers=args.workers, pin_memory=True, drop_last=True, collate_fn=collate_fn_pk, ) val_loader = DataLoader( Subset(ds_eval, val_ix), batch_size=64, shuffle=False, num_workers=args.workers, collate_fn=collate_fn_pk, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device) optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4) scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6) best_acc = -1.0 for epoch in range(1, args.epochs + 1): sampler.set_epoch(epoch) tr_t, tr_id = train_one_epoch( model, train_loader, optim, device, margin=args.triplet_margin, triplet_w=args.triplet_weight, id_w=args.id_weight, ) scheduler.step() lr_now = optim.param_groups[0]["lr"] val_acc = estimate_id_accuracy(model, val_loader, device) print( f"epoch {epoch:03d}/{args.epochs} | " f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | " f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}" ) if val_acc >= best_acc: best_acc = val_acc torch.save( { "epoch": epoch, "model_state": model.state_dict(), "num_classes": num_classes, "pid_to_label": ds_train_aug.pid_to_label, "args": vars(args), }, args.save, ) print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})") print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}") if __name__ == "__main__": main()