459 lines
16 KiB
Python
459 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
人体身份重识别(Person ReID)对比学习训练脚本(单机单文件)。
|
||
- Dataset:Market-1501 风格文件名解析 pid/cam;
|
||
- PK 采样器:每个 batch 内 P 个身份 × 每个身份 K 张样本;
|
||
- 模型:ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头;
|
||
- 损失:Batch-Hard Triplet + ID(交叉熵)联合;
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import random
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Iterator
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from PIL import Image
|
||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||
from torch.utils.data import DataLoader, Dataset, Sampler
|
||
from torchvision import models, transforms
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 文件名解析:例如 24502_c1_s1_00001.jpg → pid=24502,cam_id=1
|
||
# ---------------------------------------------------------------------------
|
||
|
||
_NAME_RE = re.compile(
|
||
r"^(?P<pid>\d+)_c(?P<cam>\d+)_",
|
||
flags=re.I,
|
||
)
|
||
|
||
|
||
def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]:
|
||
"""从「不含后缀」的文件名 stem 中提取身份 ID、机位 ID。"""
|
||
m = _NAME_RE.match(stem)
|
||
if not m:
|
||
return None, None
|
||
return int(m.group("pid")), int(m.group("cam"))
|
||
|
||
|
||
class DoctorReIDDataset(Dataset):
    """Doctor Re-ID image dataset.

    Recursively collects image files under ``image_root`` (suffixes ``.jpg``,
    ``.jpeg``, ``.png``, ``.webp``, ``.bmp``, case-insensitive). It keeps those
    whose stem parses with :func:`parse_market1501_style_name` and remaps raw
    person ids to contiguous labels ``0..num_classes-1`` (in ascending raw-pid
    order).

    Attributes:
        image_root: resolved root directory.
        paths: kept image paths, sorted.
        labels: contiguous label per image (parallel to ``paths``).
        cam_ids: raw camera id per image (parallel to ``paths``).
        pid_to_label: raw pid -> contiguous label.
        transform: torchvision pipeline applied in ``__getitem__``.

    Each item is ``(image, label, cam_id)``. ``image`` is a normalized
    ``3x256x128`` float tensor; ``label`` and ``cam_id`` are scalar
    ``torch.long`` tensors.

    Raises:
        RuntimeError: if no parsable image is found under ``image_root``.
    """

    def __init__(self, image_root: Path, augment: bool) -> None:
        self.image_root = Path(image_root).resolve()
        exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        # Sorted, so two instances built on the same folder index files identically.
        paths = sorted(
            p
            for p in self.image_root.rglob("*")
            if p.is_file() and p.suffix.lower() in exts
        )

        pid_raw_list: list[int] = []
        cam_raw_list: list[int] = []
        valid_paths: list[Path] = []
        for p in paths:
            pid_raw, cam_raw = parse_market1501_style_name(p.stem)
            if pid_raw is None:
                continue  # file name does not follow the <pid>_c<cam>_... scheme
            valid_paths.append(p)
            pid_raw_list.append(pid_raw)
            cam_raw_list.append(int(cam_raw))

        if not valid_paths:
            raise RuntimeError(f"目录下未发现有效图像: {self.image_root}")

        unique_pids = sorted(set(pid_raw_list))
        self.pid_to_label = {pid: i for i, pid in enumerate(unique_pids)}
        self.labels: list[int] = [self.pid_to_label[r] for r in pid_raw_list]
        self.cam_ids: list[int] = cam_raw_list
        self.paths = valid_paths

        # transforms.Resize takes (height, width): (256, 128) is the usual
        # 2:1 tall person crop used in ReID.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(
                        brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02
                    ),
                    transforms.ToTensor(),
                    normalize,
                    # RandomErasing operates on tensors, so it must come after ToTensor.
                    transforms.RandomErasing(p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        # Context manager closes the underlying file deterministically; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        cam = torch.tensor(self.cam_ids[idx], dtype=torch.long)
        return x, y, cam
|
||
|
||
|
||
class RandomIdentityPKSampler(Sampler[int]):
|
||
"""
|
||
PK / Random Identity Sampler(Random Identity Sampling)
|
||
|
||
【原理简述】对比学习三元组需在 batch 内出现「同类多图」才能把 anchor / positive 找出来。
|
||
PK 策略:在每个 mini-batch 中固定结构为:
|
||
- 先随机选 **P** 个不同身份;
|
||
- 每个身份若无放回地抽 **K** 张图像(样本不足则从该身份放回随机抽满 K);
|
||
则 batch_size = P * K。
|
||
这样保证了每个身份在同一个 batch 中至少有若干张可用于 Triplet,
|
||
Batch-Hard 才能选「最难的正对 / 最难的负样本」。
|
||
【注意】P 不能超过数据集中可用身份总数;若 K > 该类张数则用放回采样。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
labels: list[int],
|
||
p: int,
|
||
k: int,
|
||
*,
|
||
seed: int = 0,
|
||
length: int | None = None,
|
||
) -> None:
|
||
super().__init__()
|
||
self.labels_np = np.array(labels)
|
||
self.p = int(p)
|
||
self.k = int(k)
|
||
self.seed = seed
|
||
self.epoch = 0
|
||
|
||
self.identities = sorted(np.unique(self.labels_np).tolist())
|
||
self.num_identities = len(self.identities)
|
||
|
||
idx_by_label: dict[int, list[int]] = {}
|
||
for i, lbl in enumerate(self.labels_np.tolist()):
|
||
idx_by_label.setdefault(int(lbl), []).append(i)
|
||
self.idx_by_label = {k_: np.array(v) for k_, v in idx_by_label.items()}
|
||
|
||
if self.num_identities == 0:
|
||
raise RuntimeError("PKSampler: 无任何身份类别")
|
||
|
||
if self.p > self.num_identities:
|
||
raise ValueError(
|
||
f"P={self.p} 大于数据集身份数 {self.num_identities}。"
|
||
" 请减小 --p。"
|
||
)
|
||
|
||
# 每个 epoch 内「迭代步数」:batch 数(每次 batch 含 P×K 张图)
|
||
self.labels_flat = labels
|
||
if length is None:
|
||
# 约按「扫过一遍身份组合」的量级设一个稳定值
|
||
self.num_batches_est = max(32, min(200, len(self.labels_flat) // max(1, self.p * self.k)))
|
||
else:
|
||
self.num_batches_est = int(length)
|
||
|
||
self.batch_size = self.p * self.k
|
||
|
||
def __len__(self) -> int:
|
||
return self.num_batches_est * self.batch_size
|
||
|
||
def set_epoch(self, epoch: int) -> None:
|
||
self.epoch = epoch
|
||
|
||
def __iter__(self) -> Iterator[int]:
|
||
rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF)
|
||
ids_arr = np.asarray(self.identities, dtype=np.int64)
|
||
|
||
for _batch_id in range(self.num_batches_est):
|
||
# 一步:无放回抽 P 个身份(若身份不够则允许放回,实际 5 类且 P≤5 时总是无放回)
|
||
if ids_arr.size >= self.p:
|
||
chosen = rng.choice(ids_arr, size=self.p, replace=False)
|
||
else:
|
||
chosen = rng.choice(ids_arr, size=self.p, replace=True)
|
||
|
||
out: list[int] = []
|
||
for pid_pick in chosen.tolist():
|
||
pool = self.idx_by_label[int(pid_pick)]
|
||
if pool.size >= self.k:
|
||
idx_pick = rng.choice(pool, size=self.k, replace=False)
|
||
else:
|
||
idx_pick = rng.choice(pool, size=self.k, replace=True)
|
||
out.extend(int(t) for t in idx_pick)
|
||
|
||
yield from out
|
||
|
||
|
||
def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float) -> torch.Tensor:
    """Batch-hard triplet loss (Hermans et al., "In Defense of the Triplet Loss", 2017).

    For every anchor ``i`` in the batch:

    - hardest positive: the same-label sample (other than ``i``) with the
      **largest** Euclidean distance to ``i``;
    - hardest negative: the different-label sample with the **smallest**
      distance to ``i``.

    The per-anchor loss is ``relu(d_pos - d_neg + margin)``. The result is the
    mean over anchors that have at least one positive and one negative; other
    anchors are ignored. Computed fully vectorized, without a per-anchor loop.

    Args:
        embeddings: embedding matrix, presumably ``(B, D)``; it is cast to
            float32 for the distance computation.
        labels: ``(B,)`` integer identity labels.
        margin: required gap between the hardest negative and hardest positive.

    Returns:
        Scalar loss tensor. If no anchor is valid, returns
        ``embeddings.sum() * 0.0`` so the result stays attached to the graph.
    """
    emb = embeddings.float()
    # Pairwise Euclidean distances, clamped to stay strictly positive.
    dist = torch.cdist(emb, emb, p=2.0).clamp(min=1e-8)

    lbl = labels.long()
    same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0))
    self_mask = torch.eye(dist.size(0), dtype=torch.bool, device=dist.device)
    pos_mask = same & ~self_mask  # same identity, excluding the anchor itself
    neg_mask = ~same  # diagonal is already False because same[i, i] is True

    # Masked-out entries can never be selected by the reductions.
    hardest_pos = dist.masked_fill(~pos_mask, float("-inf")).amax(dim=1)
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).amin(dim=1)

    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not bool(valid.any()):
        return embeddings.sum() * 0.0
    # Index before subtracting so the +/-inf placeholders never enter the arithmetic.
    return F.relu(hardest_pos[valid] - hardest_neg[valid] + margin).mean()
|
||
|
||
|
||
class ReIDEmbedModel(nn.Module):
    """ImageNet ResNet-50 trunk + embedding bottleneck + identity classifier.

    Pipeline: ResNet-50 without its final avgpool/fc (``backbone``) -> global
    average pooling (``gap``) -> Linear(2048 -> feat_dim), BatchNorm1d, ReLU
    (``bottleneck``) -> Linear(feat_dim -> num_classes) (``classifier``).

    ``forward`` returns ``(embedding, logits)``. The embedding is the
    bottleneck output of shape ``(B, feat_dim)``; it is post-ReLU and therefore
    non-negative. The submodule attribute names define the state_dict keys, so
    they must stay stable for saved checkpoints.
    """

    def __init__(self, num_classes: int, feat_dim: int = 512) -> None:
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Keep everything up to (not including) avgpool and fc. With ResNet-50's
        # stride of 32, a 256x128 crop yields a (B, 2048, 8, 4) feature map.
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.bottleneck = nn.Sequential(
            nn.Linear(2048, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        pooled = self.gap(self.backbone(x))  # (B, 2048, 1, 1)
        pooled = pooled.view(pooled.size(0), -1)  # (B, 2048)
        embedding = self.bottleneck(pooled)
        return embedding, self.classifier(embedding)
|
||
|
||
|
||
def collate_fn_pk(batch):
|
||
xs, ys, cams = zip(*batch, strict=True)
|
||
return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0)
|
||
|
||
|
||
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    margin: float,
    triplet_w: float,
    id_w: float,
) -> tuple[float, float]:
    """Run one training epoch on the joint objective ``triplet_w * L_tri + id_w * L_id``.

    ``loader`` must yield ``(images, labels, cam)`` triples; the camera ids are
    unused. ``model`` must return ``(embedding, logits)``.

    The triplet loss is computed on L2-normalized embeddings. The ID
    cross-entropy uses the logits exactly as the model returns them.

    Returns:
        ``(mean_triplet_loss, mean_id_loss)``, each a per-sample average over
        the epoch, or ``(0.0, 0.0)`` if the loader yields no batch.
    """
    model.train()
    ce = nn.CrossEntropyLoss()
    sum_t = 0.0
    sum_id = 0.0
    n = 0

    for x, y, _cam in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        emb, logits = model(x)
        # Triplet distances are measured on the unit hypersphere.
        loss_t = batch_hard_triplet_loss(F.normalize(emb, p=2, dim=1), y, margin=margin)
        loss_id = ce(logits, y)
        loss = triplet_w * loss_t + id_w * loss_id

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Weight by batch size so the epoch mean is per sample.
        bs = x.size(0)
        sum_t += loss_t.item() * bs
        sum_id += loss_id.item() * bs
        n += bs

    denom = max(n, 1)
    return sum_t / denom, sum_id / denom
|
||
|
||
|
||
@torch.no_grad()
def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 identity-classification accuracy of ``model`` over ``loader``.

    Switches the model to eval mode (and leaves it there), runs without
    autograd, and compares ``argmax(logits)`` with the labels. ``loader``
    yields ``(images, labels, cam)`` triples and ``model`` returns
    ``(embedding, logits)``. An empty loader gives 0.0.
    """
    model.eval()
    hits = 0
    seen = 0
    for images, targets, _cams in loader:
        images = images.to(device)
        targets = targets.to(device)
        _emb, logits = model(images)
        hits += int(logits.argmax(dim=1).eq(targets).sum().item())
        seen += targets.numel()
    return hits / max(seen, 1)
|
||
|
||
|
||
def resolve_image_root(cli: str | Path) -> Path:
|
||
p = Path(cli).resolve()
|
||
sub = p / "doctor_picture"
|
||
if sub.is_dir():
|
||
return sub
|
||
if p.is_dir():
|
||
jpgs = list(p.glob("*.jpg"))
|
||
if len(jpgs) > 0:
|
||
return p
|
||
raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture): {cli}")
|
||
|
||
|
||
def parse_args():
    """Parse the command-line options for Doctor Re-ID training.

    The defaults of ``--data-root`` and ``--save`` are relative to the
    directory containing this script.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练")

    # Data location
    parser.add_argument(
        "--data-root",
        type=Path,
        default=script_dir,
        help="含 doctor_picture 或图片根目录的路径",
    )

    # Schedule
    parser.add_argument("--epochs", type=int, default=50, help="epoch 建议在 40–60(默认 50)")

    # PK sampling: batch_size = P * K
    parser.add_argument(
        "--batch-p",
        type=int,
        default=5,
        help="PK 采样:每 batch 采样的身份数 P(将自动不大于身份总数,如默认 5 类)",
    )
    parser.add_argument(
        "--batch-k",
        type=int,
        default=8,
        help="PK 采样:每位身份抽样张数 K;batch_size=P×K",
    )

    # Optimisation and loss weighting (float options without help text)
    for flag, default_value in (("--lr", 3e-4), ("--triplet-margin", 0.3), ("--triplet-weight", 1.0)):
        parser.add_argument(flag, type=float, default=default_value)
    parser.add_argument("--id-weight", type=float, default=2.0, help="ID Loss(CrossEntropy)权重")

    # Runtime
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)

    # Output
    parser.add_argument(
        "--save",
        type=Path,
        default=script_dir / "doctor_reid_best.pth",
        help="最佳权重路径",
    )
    return parser.parse_args()
|
||
|
||
|
||
def main() -> None:
    """Train the Doctor Re-ID model and keep the best checkpoint.

    Steps:
    1. Seed ``random`` (used for the split), ``torch`` and ``numpy``.
    2. Index the image folder twice (with and without augmentation) and split
       image indices randomly into train / validation.
    3. Train ``ReIDEmbedModel`` with PK sampling, Adam and cosine LR decay for
       ``--epochs`` epochs.
    4. After each epoch, measure ID-classification accuracy on the validation
       images. Save a checkpoint whenever it is >= the best value so far.

    Raises:
        RuntimeError: if the two dataset scans disagree on the file list.
    """
    from torch.utils.data import Subset

    args = parse_args()
    rng = random.Random(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    img_root = resolve_image_root(args.data_root)

    # Two views over the same sorted file list. An index must mean the same
    # image in both, so that one split serves both.
    ds_train_aug = DoctorReIDDataset(img_root, augment=True)
    ds_eval = DoctorReIDDataset(img_root, augment=False)
    if ds_eval.paths != ds_train_aug.paths:
        raise RuntimeError(f"两次扫描得到的图像列表不一致(目录在训练启动时被修改?): {img_root}")

    # Random image-level split. Validation takes ~10 % of the data (at least 32
    # images) but never more than half, so small folders still leave images to
    # train on. For n >= 64 this equals max(32, 10 %).
    n = len(ds_train_aug)
    perm = list(range(n))
    rng.shuffle(perm)
    n_val = min(max(32, int(0.1 * n)), n // 2)
    val_ix = sorted(perm[:n_val])
    train_ix = perm[n_val:]

    labels_tr = [ds_train_aug.labels[i] for i in train_ix]
    num_classes = len(ds_train_aug.pid_to_label)
    # P cannot exceed the number of identities present in the training split.
    p_eff = min(args.batch_p, len(set(labels_tr)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host -> GPU copies.
    pin = device.type == "cuda"

    # The sampler yields positions into labels_tr, which are exactly the
    # positions of Subset(ds_train_aug, train_ix).
    sampler = RandomIdentityPKSampler(
        labels_tr,
        p=p_eff,
        k=args.batch_k,
        seed=args.seed,
    )
    train_loader = DataLoader(
        Subset(ds_train_aug, train_ix),
        batch_size=sampler.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=pin,
        drop_last=True,
        collate_fn=collate_fn_pk,
    )
    val_loader = DataLoader(
        Subset(ds_eval, val_ix),
        batch_size=64,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=collate_fn_pk,
    )

    model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6)

    # Store CLI args as plain values: pathlib objects are not on torch.load's
    # default weights_only allow-list, which would make the checkpoint awkward
    # to reload.
    ckpt_args = {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()}
    args.save.parent.mkdir(parents=True, exist_ok=True)

    best_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        sampler.set_epoch(epoch)
        tr_t, tr_id = train_one_epoch(
            model,
            train_loader,
            optim,
            device,
            margin=args.triplet_margin,
            triplet_w=args.triplet_weight,
            id_w=args.id_weight,
        )
        lr_now = optim.param_groups[0]["lr"]  # LR actually used during this epoch
        scheduler.step()
        val_acc = estimate_id_accuracy(model, val_loader, device)

        print(
            f"epoch {epoch:03d}/{args.epochs} | "
            f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | "
            f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}"
        )

        # ">=" prefers the later epoch when accuracy ties.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "num_classes": num_classes,
                    "pid_to_label": ds_train_aug.pid_to_label,
                    "args": ckpt_args,
                },
                args.save,
            )
            print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})")

    print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}")


if __name__ == "__main__":
    main()
|