Files
OperationRoomMonitor/doctor_identity_package/train_reid_contrastive.py
2026-06-03 14:46:16 +08:00

459 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
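"""
Person re-identification (Person ReID) contrastive-learning training script
(single machine, single file).

- Dataset: Market-1501-style file names, parsed into pid / cam.
- PK sampler: every batch holds P identities x K images per identity.
- Model: ImageNet ResNet50 backbone + GAP + 512-d embedding + ID classification head.
- Loss: Batch-Hard Triplet combined with ID cross-entropy.
"""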
from __future__ import annotations
import argparse
import random
import re
from pathlib import Path
from typing import Iterator
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import models, transforms
# ---------------------------------------------------------------------------
# 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502cam_id=1
# ---------------------------------------------------------------------------
_NAME_RE = re.compile(
r"^(?P<pid>\d+)_c(?P<cam>\d+)_",
flags=re.I,
)
def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]:
"""从「不含后缀」的文件名 stem 中提取身份 ID、机位 ID。"""
m = _NAME_RE.match(stem)
if not m:
return None, None
return int(m.group("pid")), int(m.group("cam"))
class DoctorReIDDataset(Dataset):
    """Doctor ReID image dataset.

    Recursively collects image files under ``image_root``, keeps those whose
    file-name stem is accepted by ``parse_market1501_style_name`` and remaps
    the raw identity ids to contiguous labels ``0..num_classes-1`` (in
    ascending order of the raw id).

    Attributes set by ``__init__``:
        image_root: Resolved root directory.
        pid_to_label: Mapping raw pid -> contiguous label.
        labels: Contiguous label of every kept image.
        cam_ids: Camera id of every kept image.
        paths: Path of every kept image (sorted).
        transform: torchvision transform applied in ``__getitem__``.
    """

    def __init__(self, image_root: Path, augment: bool) -> None:
        """Index the images and build the transform pipeline.

        Args:
            image_root: Directory that is searched recursively for images.
            augment: When true, add random flip, colour jitter and random
                erasing to the resize/normalise pipeline.

        Raises:
            RuntimeError: If no file with a parsable name is found.
        """
        self.image_root = Path(image_root).resolve()
        exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        paths = sorted(
            p
            for p in self.image_root.rglob("*")
            if p.is_file() and p.suffix.lower() in exts
        )

        pid_raw_list: list[int] = []
        cam_raw_list: list[int] = []
        valid_paths: list[Path] = []
        for p in paths:
            pid_raw, cam_raw = parse_market1501_style_name(p.stem)
            if pid_raw is None:
                # File name does not follow the expected pattern: skip it.
                continue
            valid_paths.append(p)
            pid_raw_list.append(pid_raw)
            cam_raw_list.append(int(cam_raw))

        if not valid_paths:
            raise RuntimeError(f"目录下未发现有效图像: {self.image_root}")

        unique_pids = sorted(set(pid_raw_list))
        self.pid_to_label = {pid: i for i, pid in enumerate(unique_pids)}
        self.labels: list[int] = [self.pid_to_label[r] for r in pid_raw_list]
        self.cam_ids: list[int] = cam_raw_list
        self.paths = valid_paths

        # torchvision's Resize takes (height, width): images become 256 high, 128 wide.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(
                        brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02
                    ),
                    transforms.ToTensor(),
                    normalize,
                    # RandomErasing operates on tensors, so it comes after ToTensor.
                    transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Load one sample.

        Returns:
            ``(image, label, cam)``: the transformed image tensor, the
            contiguous identity label and the camera id (both ``torch.long``
            scalars).
        """
        # Open inside a context manager so the file handle is released
        # deterministically instead of relying on garbage collection.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        cam = torch.tensor(self.cam_ids[idx], dtype=torch.long)
        return x, y, cam
class RandomIdentityPKSampler(Sampler[int]):
"""
PK / Random Identity SamplerRandom Identity Sampling
【原理简述】对比学习三元组需在 batch 内出现「同类多图」才能把 anchor / positive 找出来。
PK 策略:在每个 mini-batch 中固定结构为:
- 先随机选 **P** 个不同身份;
- 每个身份若无放回地抽 **K** 张图像(样本不足则从该身份放回随机抽满 K
则 batch_size = P * K。
这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet
Batch-Hard 才能选「最难的正对 / 最难的负样本」。
【注意】P 不能超过数据集中可用身份总数;若 K > 该类张数则用放回采样。
"""
def __init__(
self,
labels: list[int],
p: int,
k: int,
*,
seed: int = 0,
length: int | None = None,
) -> None:
super().__init__()
self.labels_np = np.array(labels)
self.p = int(p)
self.k = int(k)
self.seed = seed
self.epoch = 0
self.identities = sorted(np.unique(self.labels_np).tolist())
self.num_identities = len(self.identities)
idx_by_label: dict[int, list[int]] = {}
for i, lbl in enumerate(self.labels_np.tolist()):
idx_by_label.setdefault(int(lbl), []).append(i)
self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()}
if self.num_identities == 0:
raise RuntimeError("PKSampler: 无任何身份类别")
if self.p > self.num_identities:
raise ValueError(
f"P={self.p} 大于数据集身份数 {self.num_identities}"
" 请减小 --p。"
)
# 每个 epoch 内「迭代步数」batch 数(每次 batch 含 P×K 张图)
self.labels_flat = labels
if length is None:
# 约按「扫过一遍身份组合」的量级设一个稳定值
self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k)))
else:
self.num_batches_est = int(length)
self.batch_size = self.p * self.k
def __len__(self) -> int:
return self.num_batches_est * self.batch_size
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def __iter__(self) -> Iterator[int]:
rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF)
ids_arr = np.asarray(self.identities, dtype=np.int64)
for _batch_id in range(self.num_batches_est):
# 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回)
if ids_arr.size >= self.p:
chosen = rng.choice(ids_arr, size=self.p, replace=False)
else:
chosen = rng.choice(ids_arr, size=self.p, replace=True)
out: list[int] = []
for pid_pick in chosen.tolist():
pool = self.idx_by_label[int(pid_pick)]
if pool.size >= self.k:
idx_pick = rng.choice(pool, size=self.k, replace=False)
else:
idx_pick = rng.choice(pool, size=self.k, replace=True)
out.extend(int(t) for t in idx_pick)
yield from out
def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor:
    """Batch-hard triplet loss over one mini-batch.

    For every anchor ``i`` in the batch:

    - hardest positive: the sample with the same label (other than ``i``)
      at the **largest** Euclidean distance;
    - hardest negative: the sample with a different label at the
      **smallest** Euclidean distance.

    The per-anchor term is ``relu(d_pos - d_neg + margin)``. Anchors that
    have no other sample of their label, or no sample of another label, are
    excluded from the mean.

    Args:
        embeddings: One embedding per row; distances are computed in float32.
        labels: One identity label per embedding row.
        margin: Triplet margin.

    Returns:
        Scalar mean loss over the valid anchors, or a zero that is still
        attached to ``embeddings`` when no anchor is valid.
    """
    lbl = labels.long()
    same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0))
    not_self = ~torch.eye(same.size(0), dtype=torch.bool, device=same.device)
    pos_mask = same & not_self  # same label, excluding the anchor itself
    neg_mask = ~same  # different label (never includes the anchor)

    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not valid.any():
        # Keep the graph connected so callers can still call backward().
        return embeddings.sum() * 0.0

    dist = torch.cdist(embeddings.float(), embeddings.float(), p=2.0).clamp(min=1e-8)
    # Masked-out entries are pushed to -inf / +inf so they never win the
    # row-wise max / min; rows without candidates are dropped via `valid`.
    hardest_pos = dist.masked_fill(~pos_mask, float("-inf")).amax(dim=1)
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).amin(dim=1)
    return F.relu(hardest_pos[valid] - hardest_neg[valid] + margin).mean()
class ReIDEmbedModel(nn.Module):
    """ResNet50-based embedding network with an identity classification head.

    Pipeline: pretrained ResNet50 with its last two child modules removed
    -> global average pooling -> Linear(2048, feat_dim) + BatchNorm1d + ReLU
    (the embedding) -> Linear(feat_dim, num_classes) (the ID logits).
    """

    def __init__(self, num_classes: int, feat_dim: int = 512) -> None:
        """Create the network.

        Args:
            num_classes: Number of identity classes for the logits.
            feat_dim: Dimensionality of the embedding.
        """
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the final two children of the torchvision model so the trunk
        # ends on a 2048-channel feature map.
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.bottleneck = nn.Sequential(
            nn.Linear(2048, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(embedding, logits)`` for a batch of images."""
        pooled = self.gap(self.backbone(x))
        flat = pooled.flatten(start_dim=1)  # one 2048-vector per sample
        embedding = self.bottleneck(flat)
        return embedding, self.classifier(embedding)
def collate_fn_pk(batch):
xs, ys, cams = zip(*batch, strict=True)
return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0)
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    margin: float,
    triplet_w: float,
    id_w: float,
) -> tuple[float, float]:
    """Train ``model`` for one pass over ``loader``.

    Each step optimises ``triplet_w * triplet_loss + id_w * cross_entropy``,
    where the triplet loss is computed on L2-normalised embeddings.

    Args:
        model: Network returning ``(embedding, logits)``.
        loader: Yields ``(images, labels, cam_ids)`` batches.
        optim: Optimiser stepping the model parameters.
        device: Device the batches are moved to.
        margin: Margin of the batch-hard triplet loss.
        triplet_w: Weight of the triplet loss.
        id_w: Weight of the ID cross-entropy loss.

    Returns:
        ``(mean_triplet_loss, mean_id_loss)``, each averaged over all samples
        seen (batch losses weighted by batch size); ``(0.0, 0.0)`` when the
        loader yields nothing.
    """
    model.train()
    sum_t = 0.0
    sum_id = 0.0
    n = 0
    ce = nn.CrossEntropyLoss()
    for x, y, _cam in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        emb, logits = model(x)
        # The triplet loss works on unit-length embeddings.
        emb = F.normalize(emb, p=2, dim=1)
        loss_t = batch_hard_triplet_loss(emb, y, margin=margin)
        loss_id = ce(logits, y)
        loss = triplet_w * loss_t + id_w * loss_id

        optim.zero_grad()
        loss.backward()
        optim.step()

        bs = x.size(0)
        sum_t += loss_t.detach().item() * bs
        sum_id += loss_id.detach().item() * bs
        n += bs
    return sum_t / max(n, 1), sum_id / max(n, 1)
@torch.no_grad()
def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of the model's ID logits over ``loader``.

    The model is switched to eval mode and left in that mode. Each batch is
    ``(images, labels, cam_ids)``; the model must return ``(embedding,
    logits)``. An empty loader yields ``0.0``.
    """
    model.eval()
    hits, seen = 0, 0
    for images, targets, _cam in loader:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images)[1]
        hits += int(logits.argmax(dim=1).eq(targets).sum().item())
        seen += targets.numel()
    return hits / max(seen, 1)
def resolve_image_root(cli: str | Path) -> Path:
p = Path(cli).resolve()
sub = p / "doctor_picture"
if sub.is_dir():
return sub
if p.is_dir():
jpgs = list(p.glob("*.jpg"))
if len(jpgs) > 0:
return p
raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture: {cli}")
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` with ``data_root``, ``epochs``,
        ``batch_p``, ``batch_k``, ``lr``, ``triplet_margin``,
        ``triplet_weight``, ``id_weight``, ``workers``, ``seed`` and ``save``.
        Path defaults are relative to the directory containing this script.
    """
    script_dir = Path(__file__).resolve().parent
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        (
            "--data-root",
            dict(type=Path, default=script_dir, help="含 doctor_picture 或图片根目录的路径"),
        ),
        ("--epochs", dict(type=int, default=50, help="epoch 建议在 4060默认 50")),
        (
            "--batch-p",
            dict(
                type=int,
                default=5,
                help="PK 采样:每 batch 采样的身份数 P将自动不大于身份总数如默认 5 类)",
            ),
        ),
        (
            "--batch-k",
            dict(type=int, default=8, help="PK 采样:每位身份抽样张数 Kbatch_size=P×K"),
        ),
        ("--lr", dict(type=float, default=3e-4)),
        ("--triplet-margin", dict(type=float, default=0.3)),
        ("--triplet-weight", dict(type=float, default=1.0)),
        ("--id-weight", dict(type=float, default=2.0, help="ID LossCrossEntropy权重")),
        ("--workers", dict(type=int, default=4)),
        ("--seed", dict(type=int, default=42)),
        (
            "--save",
            dict(type=Path, default=script_dir / "doctor_reid_best.pth", help="最佳权重路径"),
        ),
    ]

    parser = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def main() -> None:
    """Train the ReID model and keep the checkpoint with the best validation accuracy.

    Steps: parse the CLI, seed the RNGs, split the images at random into a
    training and a validation subset, train with PK sampling using the
    combined triplet + ID loss, evaluate closed-set ID accuracy after every
    epoch and save the model whenever that accuracy is at least as good as
    the best so far.
    """
    from torch.utils.data import Subset

    args = parse_args()
    rng = random.Random(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    img_root = resolve_image_root(args.data_root)
    # Same files twice: augmented view for training, plain view for evaluation.
    ds_train_aug = DoctorReIDDataset(img_root, augment=True)
    ds_eval = DoctorReIDDataset(img_root, augment=False)

    n = len(ds_train_aug)
    perm = list(range(n))
    rng.shuffle(perm)
    # Hold out max(32, 10%) images, but never more than half of the data;
    # otherwise a small dataset would leave nothing to train on.
    n_val = min(max(32, int(0.1 * n)), n // 2)
    val_ix = sorted(perm[:n_val])
    train_ix = perm[n_val:]
    labels_tr = [ds_train_aug.labels[i] for i in train_ix]
    num_classes = len(ds_train_aug.pid_to_label)
    # P cannot exceed the number of identities present in the training split.
    p_eff = min(args.batch_p, len(set(labels_tr)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sampler = RandomIdentityPKSampler(
        labels_tr,
        p=p_eff,
        k=args.batch_k,
        seed=args.seed,
    )
    train_loader = DataLoader(
        Subset(ds_train_aug, train_ix),
        batch_size=sampler.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        # Pinned host memory only helps when batches are copied to a GPU.
        pin_memory=device.type == "cuda",
        drop_last=True,
        collate_fn=collate_fn_pk,
    )
    val_loader = DataLoader(
        Subset(ds_eval, val_ix),
        batch_size=64,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=collate_fn_pk,
    )

    model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6)

    # Store plain strings instead of Path objects so the checkpoint contains
    # only basic types and stays loadable with torch.load(weights_only=True).
    saved_args = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }
    args.save.parent.mkdir(parents=True, exist_ok=True)

    best_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        sampler.set_epoch(epoch)
        tr_t, tr_id = train_one_epoch(
            model,
            train_loader,
            optim,
            device,
            margin=args.triplet_margin,
            triplet_w=args.triplet_weight,
            id_w=args.id_weight,
        )
        scheduler.step()
        lr_now = optim.param_groups[0]["lr"]
        val_acc = estimate_id_accuracy(model, val_loader, device)
        print(
            f"epoch {epoch:03d}/{args.epochs} | "
            f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | "
            f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}"
        )
        # ">=" keeps the most recent checkpoint among equally good epochs.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "num_classes": num_classes,
                    "pid_to_label": ds_train_aug.pid_to_label,
                    "args": saved_args,
                },
                args.save,
            )
            print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})")
    print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}")


if __name__ == "__main__":
    main()