Files
OperationRoomMonitor/doctor_identity_package/train_reid_contrastive.py

459 lines
16 KiB
Python
Raw Normal View History

2026-06-03 14:46:16 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
人体身份重识别Person ReID对比学习训练脚本单机单文件
- DatasetMarket-1501 风格文件名解析 pid/cam
- PK 采样器每个 batch P 个身份 × 每张 K 样本
- 模型ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头
- 损失Batch-Hard Triplet + ID交叉熵联合
"""
from __future__ import annotations
import argparse
import random
import re
from pathlib import Path
from typing import Iterator
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import models, transforms
# ---------------------------------------------------------------------------
# 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502cam_id=1
# ---------------------------------------------------------------------------
_NAME_RE = re.compile(
r"^(?P<pid>\d+)_c(?P<cam>\d+)_",
flags=re.I,
)
def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]:
"""从「不含后缀」的文件名 stem 中提取身份 ID、机位 ID。"""
m = _NAME_RE.match(stem)
if not m:
return None, None
return int(m.group("pid")), int(m.group("cam"))
class DoctorReIDDataset(Dataset):
    """Doctor ReID image dataset.

    Parses pid/cam from every file name and remaps the raw pids onto
    contiguous labels ``0..num_classes-1``.
    """

    def __init__(self, image_root: Path, augment: bool) -> None:
        """Index the images below *image_root* and build the transform pipeline.

        Args:
            image_root: Directory that is scanned recursively for images.
            augment: When true, add flip / colour-jitter / random-erasing
                augmentation on top of the resize + normalise pipeline.

        Raises:
            RuntimeError: If no file with a parsable pid/cam name is found.
        """
        self.image_root = Path(image_root).resolve()
        allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        candidates = sorted(
            entry
            for entry in self.image_root.rglob("*")
            if entry.is_file() and entry.suffix.lower() in allowed_suffixes
        )

        # Keep only the files whose name follows the pid/cam naming scheme.
        records = []
        for candidate in candidates:
            raw_pid, raw_cam = parse_market1501_style_name(candidate.stem)
            if raw_pid is not None:
                records.append((candidate, raw_pid, int(raw_cam)))

        raw_pids = [record[1] for record in records]
        distinct_pids = sorted(set(raw_pids))
        self.pid_to_label = {pid: index for index, pid in enumerate(distinct_pids)}
        self.labels = [self.pid_to_label[pid] for pid in raw_pids]
        self.cam_ids = [record[2] for record in records]
        self.paths = [record[0] for record in records]
        if not self.paths:
            raise RuntimeError(f"目录下未发现有效图像: {self.image_root}")

        # torchvision's Resize takes (height, width); (256, 128) is the usual
        # ReID input size.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        steps = [transforms.Resize((256, 128))]
        if augment:
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
            steps.append(
                transforms.ColorJitter(
                    brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02
                )
            )
        steps.append(transforms.ToTensor())
        steps.append(normalize)
        if augment:
            # RandomErasing operates on tensors, so it must come after ToTensor.
            steps.append(
                transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3))
            )
        self.transform = transforms.Compose(steps)

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Load sample *idx* and return ``(image_tensor, label, camera_id)``."""
        image = Image.open(self.paths[idx]).convert("RGB")
        pixels = self.transform(image)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        camera = torch.tensor(self.cam_ids[idx], dtype=torch.long)
        return pixels, label, camera
class RandomIdentityPKSampler(Sampler[int]):
"""
PK / Random Identity SamplerRandom Identity Sampling
原理简述对比学习三元组需在 batch 内出现同类多图才能把 anchor / positive 找出来
PK 策略在每个 mini-batch 中固定结构为
- 先随机选 **P** 个不同身份
- 每个身份若无放回地抽 **K** 张图像样本不足则从该身份放回随机抽满 K
batch_size = P * K
这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet
Batch-Hard 才能选最难的正对 / 最难的负样本
注意P 不能超过数据集中可用身份总数 K > 该类张数则用放回采样
"""
def __init__(
self,
labels: list[int],
p: int,
k: int,
*,
seed: int = 0,
length: int | None = None,
) -> None:
super().__init__()
self.labels_np = np.array(labels)
self.p = int(p)
self.k = int(k)
self.seed = seed
self.epoch = 0
self.identities = sorted(np.unique(self.labels_np).tolist())
self.num_identities = len(self.identities)
idx_by_label: dict[int, list[int]] = {}
for i, lbl in enumerate(self.labels_np.tolist()):
idx_by_label.setdefault(int(lbl), []).append(i)
self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()}
if self.num_identities == 0:
raise RuntimeError("PKSampler: 无任何身份类别")
if self.p > self.num_identities:
raise ValueError(
f"P={self.p} 大于数据集身份数 {self.num_identities}"
" 请减小 --p。"
)
# 每个 epoch 内「迭代步数」batch 数(每次 batch 含 P×K 张图)
self.labels_flat = labels
if length is None:
# 约按「扫过一遍身份组合」的量级设一个稳定值
self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k)))
else:
self.num_batches_est = int(length)
self.batch_size = self.p * self.k
def __len__(self) -> int:
return self.num_batches_est * self.batch_size
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def __iter__(self) -> Iterator[int]:
rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF)
ids_arr = np.asarray(self.identities, dtype=np.int64)
for _batch_id in range(self.num_batches_est):
# 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回)
if ids_arr.size >= self.p:
chosen = rng.choice(ids_arr, size=self.p, replace=False)
else:
chosen = rng.choice(ids_arr, size=self.p, replace=True)
out: list[int] = []
for pid_pick in chosen.tolist():
pool = self.idx_by_label[int(pid_pick)]
if pool.size >= self.k:
idx_pick = rng.choice(pool, size=self.k, replace=False)
else:
idx_pick = rng.choice(pool, size=self.k, replace=True)
out.extend(int(t) for t in idx_pick)
yield from out
def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor:
    """Batch-hard triplet loss (the formulation of Hermans et al.).

    Every sample in the batch acts as an anchor. For each anchor the loss uses

    - the hardest positive: the same-identity sample (other than the anchor
      itself) with the **largest** distance, and
    - the hardest negative: the different-identity sample with the
      **smallest** distance,

    and contributes ``relu(d_pos - d_neg + margin)``. Distances are Euclidean.
    Anchors that have no positive or no negative inside the batch are left out
    of the mean.

    Args:
        embeddings: ``[B, D]`` embedding matrix.
        labels: ``[B]`` identity labels.
        margin: Triplet margin.

    Returns:
        Scalar loss. When no anchor is usable, a zero that is still attached
        to ``embeddings`` so that ``backward()`` keeps working.
    """
    dist = torch.cdist(embeddings.float(), embeddings.float(), p=2.0).clamp(min=1e-8)
    lbl = labels.long()
    same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0))

    # Positives exclude the anchor itself; negatives never include it because
    # a sample always shares its own label.
    not_self = ~torch.eye(dist.size(0), dtype=torch.bool, device=dist.device)
    pos_mask = same & not_self
    neg_mask = ~same

    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not valid.any():
        return embeddings.sum() * 0.0

    # Masked row-wise max/min replaces a per-anchor Python loop; the +/-inf
    # fill values can never be selected for a valid anchor.
    hardest_pos = dist.masked_fill(~pos_mask, float("-inf")).max(dim=1).values
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).min(dim=1).values
    return F.relu(hardest_pos[valid] - hardest_neg[valid] + margin).mean()
class ReIDEmbedModel(nn.Module):
    """ImageNet-pretrained ResNet50 trunk -> GAP -> embedding head -> ID logits.

    The embedding head is ``Linear(2048, feat_dim) + BatchNorm1d + ReLU``; the
    classifier maps the embedding to ``num_classes`` logits.
    """

    def __init__(self, num_classes: int, feat_dim: int = 512) -> None:
        """Build the network.

        Args:
            num_classes: Number of identity classes for the classification head.
            feat_dim: Dimensionality of the embedding.
        """
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the last two children of the torchvision ResNet (its own
        # pooling and fc layers) so the trunk ends on the 2048-channel
        # feature map; its spatial size depends on the input resolution.
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.bottleneck = nn.Sequential(
            nn.Linear(2048, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(embedding, logits)`` for a batch of images."""
        feature_map = self.backbone(x)
        pooled = self.gap(feature_map)
        flat = pooled.view(pooled.size(0), -1)
        emb = self.bottleneck(flat)
        return emb, self.classifier(emb)
def collate_fn_pk(batch):
xs, ys, cams = zip(*batch, strict=True)
return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0)
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    margin: float,
    triplet_w: float,
    id_w: float,
) -> tuple[float, float]:
    """Run one training epoch with the joint triplet + ID cross-entropy loss.

    Args:
        model: Network returning ``(embedding, logits)``.
        loader: Iterable of ``(images, labels, cams)`` batches.
        optim: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        margin: Margin of the batch-hard triplet loss.
        triplet_w: Weight of the triplet term in the total loss.
        id_w: Weight of the cross-entropy term in the total loss.

    Returns:
        ``(mean_triplet_loss, mean_id_loss)``, both averaged per sample over
        the epoch; ``(0.0, 0.0)`` when the loader yields no batch.
    """
    model.train()
    sum_t = 0.0
    sum_id = 0.0
    n = 0
    ce = nn.CrossEntropyLoss()
    for x, y, _cam in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        emb, logits = model(x)
        # The triplet term is computed on L2-normalised embeddings.
        emb = F.normalize(emb, p=2, dim=1)
        loss_t = batch_hard_triplet_loss(emb, y, margin=margin)
        loss_id = ce(logits, y)
        loss = triplet_w * loss_t + id_w * loss_id
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Weight the running sums by batch size so the result is a per-sample mean.
        bs = x.size(0)
        sum_t += loss_t.detach().item() * bs
        sum_id += loss_id.detach().item() * bs
        n += bs
    return sum_t / max(n, 1), sum_id / max(n, 1)
@torch.no_grad()
def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of the ID classification head over *loader*.

    The model is switched to eval mode. Each batch is ``(images, labels, _)``
    and the model is expected to return ``(embedding, logits)``. Returns
    ``0.0`` when the loader yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    for images, targets, _ in loader:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images)[1]
        predictions = logits.argmax(dim=1)
        hits += int((predictions == targets).sum().item())
        seen += targets.numel()
    return hits / max(seen, 1)
def resolve_image_root(cli: str | Path) -> Path:
p = Path(cli).resolve()
sub = p / "doctor_picture"
if sub.is_dir():
return sub
if p.is_dir():
jpgs = list(p.glob("*.jpg"))
if len(jpgs) > 0:
return p
raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture: {cli}")
def parse_args():
    """Parse the command-line options of the training script.

    Path defaults (``--data-root`` and ``--save``) are anchored at the
    directory containing this file.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练")

    # Data location.
    parser.add_argument(
        "--data-root",
        type=Path,
        default=script_dir,
        help="含 doctor_picture 或图片根目录的路径",
    )

    # Schedule and PK batch layout.
    parser.add_argument("--epochs", type=int, default=50, help="epoch 建议在 4060默认 50")
    parser.add_argument(
        "--batch-p",
        type=int,
        default=5,
        help="PK 采样:每 batch 采样的身份数 P将自动不大于身份总数如默认 5 类)",
    )
    parser.add_argument(
        "--batch-k",
        type=int,
        default=8,
        help="PK 采样:每位身份抽样张数 Kbatch_size=P×K",
    )

    # Optimisation and loss weighting.
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--triplet-margin", type=float, default=0.3)
    parser.add_argument("--triplet-weight", type=float, default=1.0)
    parser.add_argument("--id-weight", type=float, default=2.0, help="ID LossCrossEntropy权重")

    # Runtime.
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--save",
        type=Path,
        default=script_dir / "doctor_reid_best.pth",
        help="最佳权重路径",
    )
    return parser.parse_args()
def main() -> None:
    """Train the ReID model and keep the checkpoint with the best validation ID accuracy.

    Steps: parse the CLI, seed the RNGs, index the images, split them randomly
    into train/validation, train with PK-sampled batches under a cosine
    learning-rate schedule, and save a checkpoint whenever the validation
    accuracy of the ID head is at least as good as the best seen so far.
    """
    from torch.utils.data import Subset

    args = parse_args()
    rng = random.Random(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    img_root = resolve_image_root(args.data_root)
    # Two views of the same files: augmented for training, plain for evaluation.
    ds_train_aug = DoctorReIDDataset(img_root, augment=True)
    ds_eval = DoctorReIDDataset(img_root, augment=False)

    n = len(ds_train_aug)
    perm = list(range(n))
    rng.shuffle(perm)
    # Hold out 10% (at least 32 images) for validation, but never more than
    # half of the data: without the cap a dataset of <= 32 images would leave
    # the training split empty.
    n_val = min(max(32, int(0.1 * n)), n // 2)
    val_ix = sorted(perm[:n_val])
    train_ix = perm[n_val:]
    labels_tr = [ds_train_aug.labels[i] for i in train_ix]
    num_classes = len(ds_train_aug.pid_to_label)
    # P cannot exceed the number of identities present in the training split.
    p_eff = min(args.batch_p, len(set(labels_tr)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sampler = RandomIdentityPKSampler(
        labels_tr,
        p=p_eff,
        k=args.batch_k,
        seed=args.seed,
    )
    train_loader = DataLoader(
        Subset(ds_train_aug, train_ix),
        batch_size=sampler.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        # Pinned host memory is only useful for host-to-GPU copies.
        pin_memory=device.type == "cuda",
        drop_last=True,
        collate_fn=collate_fn_pk,
    )
    val_loader = DataLoader(
        Subset(ds_eval, val_ix),
        batch_size=64,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=collate_fn_pk,
    )

    model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6)

    # Create the output directory up front so the first save cannot fail
    # after a full epoch of training.
    Path(args.save).parent.mkdir(parents=True, exist_ok=True)
    # Store CLI paths as plain strings: pickled Path objects are OS-specific
    # and are rejected by torch.load in weights-only mode.
    saved_args = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }

    best_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        sampler.set_epoch(epoch)
        tr_t, tr_id = train_one_epoch(
            model,
            train_loader,
            optim,
            device,
            margin=args.triplet_margin,
            triplet_w=args.triplet_weight,
            id_w=args.id_weight,
        )
        scheduler.step()
        lr_now = optim.param_groups[0]["lr"]
        val_acc = estimate_id_accuracy(model, val_loader, device)
        print(
            f"epoch {epoch:03d}/{args.epochs} | "
            f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | "
            f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}"
        )
        # ">=" keeps the most recent checkpoint among equally good epochs.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "num_classes": num_classes,
                    "pid_to_label": ds_train_aug.pid_to_label,
                    "args": saved_args,
                },
                args.save,
            )
            print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})")
    print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}")


if __name__ == "__main__":
    main()