commit a5951461c4447057a4014db1ee340d60ca53fa90 Author: hsz <2091085305@qq.com> Date: Tue May 26 11:35:14 2026 +0800 first commit: 手术室耗材离线推理代码 Co-authored-by: Cursor diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f1a95b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# ===== 模型与权重(运行前自行下载到 weights/)===== +/weights/* +!/weights/.gitkeep + +# ===== 医生识别(checkpoint + mediapipe 模型)===== +/doctor_identity_package/doctor_info.pth +/doctor_identity_package/.mediapipe_models/ + +# ===== 输入数据(视频、Excel)===== +/input/* +!/input/.gitkeep + +# ===== 推理输出 ===== +/output/* +!/output/.gitkeep + +# ===== Python ===== +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +.venv/ +venv/ +.env + +# ===== 编译产物(actionformer nms 扩展)===== +**/build/ +*.o +*.egg-info/ +dist/ + +# ===== 运行时/中间文件 ===== +*.npy +*.pkl +*.mp4 +*.avi +*.mkv +*.xlsx +*.xls + +# ===== IDE / 系统 ===== +.DS_Store +.idea/ +.vscode/ +*.swp diff --git a/README.md b/README.md new file mode 100644 index 0000000..c6d4b47 --- /dev/null +++ b/README.md @@ -0,0 +1,97 @@ +# 手术室耗材离线推理包 + +**主入口**:`python main.py`(读取 `configs/default_config.yaml` 或 `--config` 指定 yaml) + +## 功能 + +- **输入**:主视角 MP4(`io.video`)+ 商品 Excel(`io.excel`,白名单与商品编码) +- **输出**:制表符分隔结果文件(`io.out`,默认 `output/result.txt`) +- **流程**:VideoSwin 特征 → ActionFormer 切时段 → 段内 YOLO 耗材推断 → 可选撕膜相邻段合并 → 末行医生信息 + +**输出格式与 5.17 开发包完全一致**(12 列 TSV + `医生信息:` 行)。 + +## 安装 + +```bash +cd /path/to/本目录 + +# 1. 先按 https://pytorch.org 安装与 CUDA 匹配的 torch / torchvision +# 2. 安装依赖 +pip install -r requirements.txt +pip install -e code/actionformer_release/libs/utils +``` + +## 运行 + +1. 将待分析 **MP4** 与 **Excel** 放入 `input/`(或改 yaml 为绝对路径) +2. 确认 `weights/` 内 5 个模型文件齐全 +3. 确认 `doctor_identity_package/doctor_info.pth` 存在(医生识别默认开启) +4. 推荐复制 [`configs/run_tracking_template.yaml`](configs/run_tracking_template.yaml) 并修改 `io.video` / `io.out` +5. 
执行: + +```bash +python main.py --config configs/run_your_video.yaml +``` + +### HEVC 视频(4K 主视角) + +VideoSwin 特征提取对 HEVC 原片可能解码失败,请先转 H.264: + +```bash +./scripts/remux_hevc.sh /path/to/source.mp4 +# 输出默认: input/remuxed/_h264.mp4 +``` + +然后在 yaml 中将 `io.video` 指向转码后的文件。 + +### Debug(Excel 时间段,跳过 ActionFormer) + +```bash +python main_debug.py \ + --video input/remuxed/xxx_h264.mp4 \ + --excel input/视频中的商品信息表.xlsx \ + --out output/result_debug.txt \ + --config configs/run_tracking_template.yaml +``` + +### 可视化(可选) + +主流程完成后,单独生成带手框与耗材标签的 MP4: + +```bash +python visualize_result_video.py \ + --video input/remuxed/xxx_h264.mp4 \ + --result-txt output/result.txt \ + --out-video output/result_vis.mp4 +``` + +## 输出格式 + +默认 **12 列** TSV(Tab 分隔);文件末尾一行 **`医生信息:...`**。 + +``` +rank start_sec end_sec product_id_top1 top1_name top1_conf ... +医生信息:付玉峰 (id=24503, conf=0.8552) +``` + +关闭医生识别:在 yaml 中设 `doctor_identity.enabled: false`。 + +## 目录结构 + +``` +├── main.py # 主流程入口 +├── main_debug.py # Excel 时间段 debug 入口 +├── visualize_result_video.py # 可选可视化 +├── configs/ +│ ├── default_config.yaml +│ └── run_tracking_template.yaml +├── scripts/remux_hevc.sh # HEVC → H.264 转码 +├── src/ # 配置与编排 +├── weights/ # 5 个模型 +├── input/remuxed/ # 转码后视频 +├── output/ # 结果 txt +├── doctor_identity_package/ +└── code/ # 算法子树(一般勿改) +``` + +不包含:RTSP 模拟推流、训练脚本。 diff --git a/code/actionformer_release/eval.py b/code/actionformer_release/eval.py new file mode 100644 index 0000000..d032254 --- /dev/null +++ b/code/actionformer_release/eval.py @@ -0,0 +1,127 @@ +# python imports +import argparse +import os +import glob +import time +from pprint import pprint + +# torch imports +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.utils.data + +# our code +from libs.core import load_config +from libs.datasets import make_dataset, make_data_loader +from libs.modeling import make_meta_arch +from libs.utils import valid_one_epoch, ANETdetection, fix_random_seed + + 
+################################################################################ +def main(args): + """0. load config""" + # sanity check + if os.path.isfile(args.config): + cfg = load_config(args.config) + else: + raise ValueError("Config file does not exist.") + assert len(cfg['val_split']) > 0, "Test set must be specified!" + if ".pth.tar" in args.ckpt: + assert os.path.isfile(args.ckpt), "CKPT file does not exist!" + ckpt_file = args.ckpt + else: + assert os.path.isdir(args.ckpt), "CKPT file folder does not exist!" + if args.epoch > 0: + ckpt_file = os.path.join( + args.ckpt, 'epoch_{:03d}.pth.tar'.format(args.epoch) + ) + else: + ckpt_file_list = sorted(glob.glob(os.path.join(args.ckpt, '*.pth.tar'))) + ckpt_file = ckpt_file_list[-1] + assert os.path.exists(ckpt_file) + + if args.topk > 0: + cfg['model']['test_cfg']['max_seg_num'] = args.topk + pprint(cfg) + + """1. fix all randomness""" + # fix the random seeds (this will fix everything) + _ = fix_random_seed(0, include_cuda=True) + + """2. create dataset / dataloader""" + val_dataset = make_dataset( + cfg['dataset_name'], False, cfg['val_split'], **cfg['dataset'] + ) + # set bs = 1, and disable shuffle + val_loader = make_data_loader( + val_dataset, False, None, 1, cfg['loader']['num_workers'] + ) + + """3. create model and evaluator""" + # model + model = make_meta_arch(cfg['model_name'], **cfg['model']) + # not ideal for multi GPU training, ok for now + model = nn.DataParallel(model, device_ids=cfg['devices']) + + """4. 
load ckpt""" + print("=> loading checkpoint '{}'".format(ckpt_file)) + # load ckpt, reset epoch / best rmse + checkpoint = torch.load( + ckpt_file, + map_location = lambda storage, loc: storage.cuda(cfg['devices'][0]) + ) + # load ema model instead + print("Loading from EMA model ...") + model.load_state_dict(checkpoint['state_dict_ema']) + del checkpoint + + # set up evaluator + det_eval, output_file = None, None + if not args.saveonly: + val_db_vars = val_dataset.get_attributes() + det_eval = ANETdetection( + val_dataset.json_file, + val_dataset.split[0], + tiou_thresholds = val_db_vars['tiou_thresholds'] + ) + else: + output_file = os.path.join(os.path.split(ckpt_file)[0], 'eval_results.pkl') + + """5. Test the model""" + print("\nStart testing model {:s} ...".format(cfg['model_name'])) + start = time.time() + mAP = valid_one_epoch( + val_loader, + model, + -1, + evaluator=det_eval, + output_file=output_file, + ext_score_file=cfg['test_cfg']['ext_score_file'], + tb_writer=None, + print_freq=args.print_freq + ) + end = time.time() + print("All done! 
DEFAULTS = {
    # random seed for reproducibility, a large number is preferred
    "init_rand_seed": 1234567891,
    # dataset loader, specify the dataset here
    "dataset_name": "epic",
    "devices": ['cuda:0'],  # default: single gpu
    "train_split": ('training', ),
    "val_split": ('validation', ),
    "model_name": "LocPointTransformer",
    "dataset": {
        # temporal stride of the feats
        "feat_stride": 16,
        # number of frames for each feat
        "num_frames": 32,
        # default fps, may vary across datasets; set to None to read it from the json file
        "default_fps": None,
        # input feat dim
        "input_dim": 2304,
        # number of classes
        "num_classes": 97,
        # downsampling rate of features, 1 to use original resolution
        "downsample_rate": 1,
        # max sequence length during training
        "max_seq_len": 2304,
        # threshold for truncating an action
        "trunc_thresh": 0.5,
        # set to a tuple (e.g., (0.9, 1.0)) to enable random feature cropping
        # might not be implemented by the dataloader
        "crop_ratio": None,
        # if true, force upsampling of the input features into a fixed size
        # only used for ActivityNet
        "force_upsampling": False,
    },
    "loader": {
        "batch_size": 8,
        "num_workers": 4,
    },
    # network architecture
    "model": {
        # type of backbone (convTransformer | conv)
        "backbone_type": 'convTransformer',
        # type of FPN (fpn | identity)
        "fpn_type": "identity",
        "backbone_arch": (2, 2, 5),
        # scale factor between pyramid levels
        "scale_factor": 2,
        # regression range for pyramid levels
        "regression_range": [(0, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 10000)],
        # number of heads in self-attention
        "n_head": 4,
        # window size for self attention; <=1 to use full seq (ie global attention)
        "n_mha_win_size": -1,
        # kernel size for embedding network
        "embd_kernel_size": 3,
        # (output) feature dim for embedding network
        "embd_dim": 512,
        # if attach group norm to embedding network
        "embd_with_ln": True,
        # feat dim for FPN
        "fpn_dim": 512,
        # if add ln at the end of fpn outputs
        "fpn_with_ln": True,
        # starting level for fpn
        "fpn_start_level": 0,
        # feat dim for head
        "head_dim": 512,
        # kernel size for reg/cls/center heads
        "head_kernel_size": 3,
        # number of layers in the head (including the final one)
        "head_num_layers": 3,
        # if attach group norm to heads
        "head_with_ln": True,
        # defines the max length of the buffered points
        "max_buffer_len_factor": 6.0,
        # disable abs position encoding (added to input embedding)
        "use_abs_pe": False,
        # use rel position encoding (added to self-attention)
        "use_rel_pe": False,
    },
    "train_cfg": {
        # radius | none (if to use center sampling)
        "center_sample": "radius",
        "center_sample_radius": 1.5,
        "loss_weight": 1.0,  # on reg_loss, use -1 to enable auto balancing
        "cls_prior_prob": 0.01,
        "init_loss_norm": 2000,
        # gradient clipping, not needed for pre-LN transformer
        "clip_grad_l2norm": -1,
        # cls head without data (a fix to epic-kitchens / thumos)
        "head_empty_cls": [],
        # dropout ratios for transformers
        "dropout": 0.0,
        # ratio for drop path
        "droppath": 0.1,
        # if to use label smoothing (>0.0)
        "label_smoothing": 0.0,
    },
    "test_cfg": {
        "pre_nms_thresh": 0.001,
        "pre_nms_topk": 5000,
        "iou_threshold": 0.1,
        "min_score": 0.01,
        "max_seg_num": 1000,
        "nms_method": 'soft',  # soft | hard | none
        "nms_sigma": 0.5,
        "duration_thresh": 0.05,
        "multiclass_nms": True,
        "ext_score_file": None,
        "voting_thresh": 0.75,
    },
    # optimizer (for training)
    "opt": {
        # solver
        "type": "AdamW",  # SGD or AdamW
        # solver params
        "momentum": 0.9,
        "weight_decay": 0.0,
        "learning_rate": 1e-3,
        # excluding the warmup epochs
        "epochs": 30,
        # lr scheduler: cosine / multistep
        "warmup": True,
        "warmup_epochs": 5,
        "schedule_type": "cosine",
        # in #epochs excluding warmup
        "schedule_steps": [],
        "schedule_gamma": 0.1,
    }
}


def _merge(src, dst):
    """Recursively fill ``dst`` with every key of ``src`` that it is missing.

    Keys already present in ``dst`` are kept; when both sides hold a dict for
    the same key the merge descends into it. ``dst`` is modified in place.

    Args:
        src: dict of default values (never modified).
        dst: dict to complete, typically the parsed user config.
    """
    # local import: keeps this fix independent of the file's import block
    import copy

    for k, v in src.items():
        if k not in dst:
            # Copy instead of aliasing: the defaults hold mutable dicts/lists,
            # and the merged config is mutated later (e.g. _update_config adds
            # keys under "model"), which must not leak back into ``src``.
            dst[k] = copy.deepcopy(v)
        elif isinstance(v, dict) and isinstance(dst[k], dict):
            _merge(v, dst[k])
        # otherwise the value already in ``dst`` wins, whatever its type


def load_default_config():
    """Return a fresh, independent copy of ``DEFAULTS``.

    A copy is returned so that callers editing the result cannot change the
    defaults seen by later ``load_config`` / ``load_default_config`` calls.
    """
    import copy

    return copy.deepcopy(DEFAULTS)


def _update_config(config):
    """Mirror dataset/train/test settings into ``config["model"]``.

    ``config["model"]["train_cfg"]`` and ``["test_cfg"]`` reference the very
    same dict objects as the top-level entries, so editing one edits both.

    Returns:
        The same ``config`` object, updated in place.
    """
    # fill in derived fields
    config["model"]["input_dim"] = config["dataset"]["input_dim"]
    config["model"]["num_classes"] = config["dataset"]["num_classes"]
    config["model"]["max_seq_len"] = config["dataset"]["max_seq_len"]
    config["model"]["train_cfg"] = config["train_cfg"]
    config["model"]["test_cfg"] = config["test_cfg"]
    return config


def load_config(config_file, defaults=DEFAULTS):
    """Load a YAML config, complete it from ``defaults`` and derive model fields.

    Args:
        config_file: path of the YAML file to read.
        defaults: dict supplying every key the file leaves out; not modified.

    Returns:
        The merged config dict.
    """
    with open(config_file, "r") as fd:
        # NOTE(review): FullLoader can construct some Python objects; switch
        # to yaml.safe_load if config files may come from untrusted sources
        # and no existing config relies on Python-specific tags.
        config = yaml.load(fd, Loader=yaml.FullLoader)
    if config is None:
        # an empty YAML document parses to None; treat it as "all defaults"
        config = {}
    _merge(defaults, config)
    config = _update_config(config)
    return config
def _load_json_db(self, json_file):
    """Read the annotation JSON and build the records for ``self.split``.

    Args:
        json_file: path of a JSON file whose ``'database'`` entry maps a
            video id to its metadata (``subset``, ``duration``, optional
            ``fps`` and ``annotations``).

    Returns:
        ``(dict_db, label_dict)``: a tuple with one dict per selected video
        (keys ``id``, ``fps``, ``duration``, ``segments``, ``labels``) and the
        label-name -> label-id mapping that was used.

    Raises:
        AssertionError: if a selected video has no ``fps`` and
            ``self.default_fps`` is None.
    """
    with open(json_file, 'r') as fid:
        json_data = json.load(fid)
    json_db = json_data['database']

    if self.label_dict is None:
        # derive the mapping from every annotation in the file; videos
        # without an 'annotations' entry simply contribute nothing
        label_dict = {}
        for value in json_db.values():
            for act in value.get('annotations') or ():
                label_dict[act['label']] = act['label_id']
    else:
        # reuse the mapping that is already set (the name used to be left
        # unbound on this path)
        label_dict = self.label_dict

    # collect in a list and freeze once at the end: repeated tuple
    # concatenation copies the whole database on every video
    records = []
    for key, value in json_db.items():
        # skip the video if not in the split
        if value['subset'].lower() not in self.split:
            continue

        # get fps if available
        if self.default_fps is not None:
            fps = self.default_fps
        elif 'fps' in value:
            fps = value['fps']
        else:
            # explicit raise: a bare ``assert`` is removed under ``python -O``
            raise AssertionError("Unknown video FPS.")
        duration = value['duration']

        # get annotations if available
        annotations = value.get('annotations')
        if annotations:
            valid_acts = remove_duplicate_annotations(annotations)
            num_acts = len(valid_acts)
            segments = np.zeros([num_acts, 2], dtype=np.float32)
            labels = np.zeros([num_acts, ], dtype=np.int64)
            for idx, act in enumerate(valid_acts):
                segments[idx][0] = act['segment'][0]
                segments[idx][1] = act['segment'][1]
                # single-class (proposal) mode maps every action to class 0
                if self.num_classes == 1:
                    labels[idx] = 0
                else:
                    labels[idx] = label_dict[act['label']]
        else:
            segments = None
            labels = None

        records.append({
            'id': key,
            'fps': fps,
            'duration': duration,
            'segments': segments,
            'labels': labels,
        })

    # immutable afterwards
    return tuple(records), label_dict
+ # auto batching will be disabled in the subsequent dataloader + # instead the model will need to decide how to batch / preporcess the data + video_item = self.data_list[idx] + + # load features + if self.use_hdf5: + with h5py.File(self.feat_folder, 'r') as h5_fid: + feats = np.asarray( + h5_fid[self.file_prefix + video_item['id']][()], + dtype=np.float32 + ) + else: + filename = os.path.join(self.feat_folder, + self.file_prefix + video_item['id'] + self.file_ext) + feats = np.load(filename).astype(np.float32) + + # we support both fixed length features / variable length features + # case 1: variable length features for training + if self.feat_stride > 0 and (not self.force_upsampling): + # var length features + feat_stride, num_frames = self.feat_stride, self.num_frames + # only apply down sampling here + if self.downsample_rate > 1: + feats = feats[::self.downsample_rate, :] + feat_stride = self.feat_stride * self.downsample_rate + # case 2: variable length features for input, yet resized for training + elif self.feat_stride > 0 and self.force_upsampling: + feat_stride = float( + (feats.shape[0] - 1) * self.feat_stride + self.num_frames + ) / self.max_seq_len + # center the features + num_frames = feat_stride + # case 3: fixed length features for input + else: + # deal with fixed length feature, recompute feat_stride, num_frames + seq_len = feats.shape[0] + assert seq_len <= self.max_seq_len + if self.force_upsampling: + # reset to max_seq_len + seq_len = self.max_seq_len + feat_stride = video_item['duration'] * video_item['fps'] / seq_len + # center the features + num_frames = feat_stride + feat_offset = 0.5 * num_frames / feat_stride + + # T x C -> C x T + feats = torch.from_numpy(np.ascontiguousarray(feats.transpose())) + + # resize the features if needed + if (feats.shape[-1] != self.max_seq_len) and self.force_upsampling: + resize_feats = F.interpolate( + feats.unsqueeze(0), + size=self.max_seq_len, + mode='linear', + align_corners=False + ) + feats = 
resize_feats.squeeze(0) + + # convert time stamp (in second) into temporal feature grids + # ok to have small negative values here + if video_item['segments'] is not None: + segments = torch.from_numpy( + video_item['segments'] * video_item['fps'] / feat_stride - feat_offset + ) + labels = torch.from_numpy(video_item['labels']) + # for activity net, we have a few videos with a bunch of missing frames + # here is a quick fix for training + if self.is_training: + vid_len = feats.shape[1] + feat_offset + valid_seg_list, valid_label_list = [], [] + for seg, label in zip(segments, labels): + if seg[0] >= vid_len: + # skip an action outside of the feature map + continue + # skip an action that is mostly outside of the feature map + ratio = ( + (min(seg[1].item(), vid_len) - seg[0].item()) + / (seg[1].item() - seg[0].item()) + ) + if ratio >= self.trunc_thresh: + valid_seg_list.append(seg.clamp(max=vid_len)) + # some weird bug here if not converting to size 1 tensor + valid_label_list.append(label.view(1)) + segments = torch.stack(valid_seg_list, dim=0) + labels = torch.cat(valid_label_list) + else: + segments, labels = None, None + + # return a data dict + data_dict = {'video_id' : video_item['id'], + 'feats' : feats, # C x T + 'segments' : segments, # N x 2 + 'labels' : labels, # N + 'fps' : video_item['fps'], + 'duration' : video_item['duration'], + 'feat_stride' : feat_stride, + 'feat_num_frames' : num_frames} + + # no truncation is needed + # truncate the features during training + if self.is_training and (segments is not None): + data_dict = truncate_feats( + data_dict, self.max_seq_len, self.trunc_thresh, feat_offset, self.crop_ratio + ) + + return data_dict diff --git a/code/actionformer_release/libs/datasets/data_utils.py b/code/actionformer_release/libs/datasets/data_utils.py new file mode 100644 index 0000000..8cdb5b5 --- /dev/null +++ b/code/actionformer_release/libs/datasets/data_utils.py @@ -0,0 +1,112 @@ +import os +import copy +import random +import numpy 
as np +import random +import torch + + +def trivial_batch_collator(batch): + """ + A batch collator that does nothing + """ + return batch + +def worker_init_reset_seed(worker_id): + """ + Reset random seed for each worker + """ + seed = torch.initial_seed() % 2 ** 31 + np.random.seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + +def truncate_feats( + data_dict, + max_seq_len, + trunc_thresh, + offset, + crop_ratio=None, + max_num_trials=200, + has_action=True, + no_trunc=False +): + """ + Truncate feats and time stamps in a dict item + + data_dict = {'video_id' : str + 'feats' : Tensor C x T + 'segments' : Tensor N x 2 (in feature grid) + 'labels' : Tensor N + 'fps' : float + 'feat_stride' : int + 'feat_num_frames' : in + + """ + # get the meta info + feat_len = data_dict['feats'].shape[1] + num_segs = data_dict['segments'].shape[0] + + # seq_len < max_seq_len + if feat_len <= max_seq_len: + # do nothing + if crop_ratio == None: + return data_dict + # randomly crop the seq by setting max_seq_len to a value in [l, r] + else: + max_seq_len = random.randint( + max(round(crop_ratio[0] * feat_len), 1), + min(round(crop_ratio[1] * feat_len), feat_len), + ) + # # corner case + if feat_len == max_seq_len: + return data_dict + + # otherwise, deep copy the dict + data_dict = copy.deepcopy(data_dict) + + # try a few times till a valid truncation with at least one action + for _ in range(max_num_trials): + + # sample a random truncation of the video feats + st = random.randint(0, feat_len - max_seq_len) + ed = st + max_seq_len + window = torch.as_tensor([st, ed], dtype=torch.float32) + + # compute the intersection between the sampled window and all segments + window = window[None].repeat(num_segs, 1) + left = torch.maximum(window[:, 0] - offset, data_dict['segments'][:, 0]) + right = torch.minimum(window[:, 1] + offset, data_dict['segments'][:, 1]) + inter = (right - left).clamp(min=0) + area_segs = torch.abs( + data_dict['segments'][:, 1] - 
# registry filled by @register_dataset: dataset name -> dataset class
datasets = {}


def register_dataset(name):
    """Return a class decorator that records the class in ``datasets[name]``."""
    def decorator(cls):
        datasets[name] = cls
        return cls
    return decorator


def make_dataset(name, is_training, split, **kwargs):
    """Instantiate the dataset registered under ``name``.

    Args:
        name: key previously passed to ``register_dataset``.
        is_training: forwarded as the first positional argument.
        split: forwarded as the second positional argument.
        **kwargs: forwarded unchanged to the dataset constructor.

    Raises:
        KeyError: if no dataset was registered under ``name``.
    """
    return datasets[name](is_training, split, **kwargs)


def make_data_loader(dataset, is_training, generator, batch_size, num_workers):
    """Build the ``DataLoader`` used for training or evaluation.

    Training enables shuffling, ``drop_last`` and per-worker reseeding;
    evaluation disables all three. Samples are passed through
    ``trivial_batch_collator`` rather than torch's default collation.

    Args:
        dataset: dataset to iterate over.
        is_training: selects the training or the evaluation behaviour above.
        generator: forwarded to ``DataLoader`` as ``generator`` (may be None).
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes; 0 loads in-process.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
        worker_init_fn=(worker_init_reset_seed if is_training else None),
        shuffle=is_training,
        drop_last=is_training,
        generator=generator,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no workers, so only ask for it when workers actually exist
        persistent_workers=num_workers > 0,
    )
def _load_json_db(self, json_file):
    """Read the annotation JSON and build the records for ``self.split``.

    A video is kept only if its subset is in ``self.split`` and its feature
    file exists in every folder of ``self.feat_folder``.

    Args:
        json_file: path of a JSON file whose ``'database'`` entry maps a
            video id to its metadata (``subset``, optional ``fps``,
            ``duration``, ``annotations`` and ``offset``).

    Returns:
        ``(dict_db, label_dict)``: a tuple with one dict per kept video (keys
        ``id``, ``fps``, ``duration``, ``segments``, ``labels``, ``offset``)
        and the label-name -> label-id mapping that was used.

    Raises:
        AssertionError: if a kept video has no ``fps`` and
            ``self.default_fps`` is None.
    """
    with open(json_file, 'r') as fid:
        json_data = json.load(fid)
    json_db = json_data['database']

    if self.label_dict is None:
        # derive the mapping from every annotation in the file
        label_dict = {}
        for value in json_db.values():
            for act in value.get('annotations') or ():
                label_dict[act['label']] = act['label_id']
    else:
        # reuse the mapping that is already set (the name used to be left
        # unbound on this path)
        label_dict = self.label_dict

    # collect in a list and freeze once at the end: repeated tuple
    # concatenation copies the whole database on every video
    records = []
    for key, value in json_db.items():
        # skip the video if not in the split
        if value['subset'].lower() not in self.split:
            continue
        # or does not have the feature file in every feature folder
        feat_files = [
            os.path.join(folder, self.file_prefix + key + self.file_ext)
            for folder in self.feat_folder
        ]
        if not all(os.path.exists(feat_file) for feat_file in feat_files):
            continue

        # get fps if available
        if self.default_fps is not None:
            fps = self.default_fps
        elif 'fps' in value:
            fps = value['fps']
        else:
            # explicit raise: a bare ``assert`` is removed under ``python -O``
            raise AssertionError("Unknown video FPS.")

        # a missing duration falls back to a very large placeholder
        duration = value.get('duration', 1e8)

        # get annotations if available
        annotations = value.get('annotations')
        if annotations:
            num_acts = len(annotations)
            segments = np.zeros([num_acts, 2], dtype=np.float32)
            labels = np.zeros([num_acts, ], dtype=np.int64)
            for idx, act in enumerate(annotations):
                segments[idx][0] = act['segment'][0]
                segments[idx][1] = act['segment'][1]
                labels[idx] = label_dict[act['label']]
        else:
            segments = None
            labels = None

        records.append({
            'id': key,
            'fps': fps,
            'duration': duration,
            'segments': segments,
            'labels': labels,
            'offset': value.get('offset'),  # only for test
        })

    # immutable afterwards
    return tuple(records), label_dict
@register_dataset("epic")
class EpicKitchensDataset(Dataset):
    """EPIC-Kitchens dataset served from pre-extracted per-video features.

    Annotations come from a JSON database (``json_file``); features are read
    with ``np.load`` from ``feat_folder`` (array stored under the key
    ``'feats'``).  Each item is a dict holding the feature tensor (C x T) and,
    when annotations exist, the action segments converted to feature-grid
    coordinates.
    """

    def __init__(
        self,
        is_training,      # if in training mode
        split,            # split, a tuple/list allowing concat of subsets
        feat_folder,      # folder for features
        json_file,        # json file for annotations
        feat_stride,      # temporal stride of the feats
        num_frames,       # number of frames for each feat
        default_fps,      # default fps
        downsample_rate,  # downsample rate for feats
        max_seq_len,      # maximum sequence length during training
        trunc_thresh,     # threshold for truncate an action segment
        crop_ratio,       # a tuple (e.g., (0.9, 1.0)) for random cropping
        input_dim,        # input feat dim
        num_classes,      # number of action categories
        file_prefix,      # feature file prefix if any
        file_ext,         # feature file extension if any
        force_upsampling  # force to upsample to max_seq_len
    ):
        # file path
        assert os.path.exists(feat_folder) and os.path.exists(json_file)
        assert isinstance(split, (tuple, list))
        # identity test: `== None` would broadcast on array-like crop ratios
        assert crop_ratio is None or len(crop_ratio) == 2
        self.feat_folder = feat_folder
        self.file_prefix = file_prefix if file_prefix is not None else ''
        self.file_ext = file_ext
        self.json_file = json_file

        # split / training mode
        self.split = split
        self.is_training = is_training

        # features meta info
        self.feat_stride = feat_stride
        self.num_frames = num_frames
        self.input_dim = input_dim
        self.default_fps = default_fps
        self.downsample_rate = downsample_rate
        self.max_seq_len = max_seq_len
        self.trunc_thresh = trunc_thresh
        self.num_classes = num_classes
        self.label_dict = None
        self.crop_ratio = crop_ratio

        # load database and select the subset
        dict_db, label_dict = self._load_json_db(self.json_file)
        # "empty" noun categories on epic-kitchens
        assert len(label_dict) <= num_classes
        self.data_list = dict_db
        self.label_dict = label_dict

        # dataset specific attributes
        empty_label_ids = self.find_empty_cls(label_dict, num_classes)
        self.db_attributes = {
            'dataset_name': 'epic-kitchens-100',
            'tiou_thresholds': np.linspace(0.1, 0.5, 5),
            'empty_label_ids': empty_label_ids
        }

    def find_empty_cls(self, label_dict, num_classes):
        """Return the class ids in ``range(num_classes)`` that no label maps to."""
        if len(label_dict) == num_classes:
            return []
        used_ids = set(label_dict.values())
        return [cls_id for cls_id in range(num_classes) if cls_id not in used_ids]

    def get_attributes(self):
        """Return the dataset attribute dict built in ``__init__``."""
        return self.db_attributes

    def _load_json_db(self, json_file):
        """Parse the annotation JSON.

        Returns:
            ``(dict_db, label_dict)``: a tuple with one dict per video whose
            subset is in ``self.split``, and the label-name -> label-id map.
        """
        with open(json_file, 'r') as fid:
            json_data = json.load(fid)
        json_db = json_data['database']

        # reuse an existing label map; otherwise build it from every video
        if self.label_dict is None:
            label_dict = {}
            for value in json_db.values():
                # videos without annotations contribute no labels
                for act in value.get('annotations') or ():
                    label_dict[act['label']] = act['label_id']
        else:
            label_dict = self.label_dict

        # fill in the db (immutable afterwards)
        entries = []
        for key, value in json_db.items():
            # skip the video if not in the split
            if value['subset'].lower() not in self.split:
                continue

            # get fps if available
            if self.default_fps is not None:
                fps = self.default_fps
            elif 'fps' in value:
                fps = value['fps']
            else:
                # raised explicitly so the check survives `python -O`
                raise AssertionError("Unknown video FPS.")

            # get video duration if available
            duration = value.get('duration', 1e8)

            # get annotations if available
            annotations = value.get('annotations') or ()
            if len(annotations) > 0:
                num_acts = len(annotations)
                segments = np.zeros([num_acts, 2], dtype=np.float32)
                labels = np.zeros([num_acts, ], dtype=np.int64)
                for idx, act in enumerate(annotations):
                    segments[idx][0] = act['segment'][0]
                    segments[idx][1] = act['segment'][1]
                    labels[idx] = label_dict[act['label']]
            else:
                segments = None
                labels = None

            entries.append({'id': key,
                            'fps': fps,
                            'duration': duration,
                            'segments': segments,
                            'labels': labels})

        return tuple(entries), label_dict

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load one video's features and return its data dict.

        Batching is left to the model; during training the item is passed
        through ``truncate_feats`` when it has annotations.
        """
        video_item = self.data_list[idx]

        # load features
        filename = os.path.join(self.feat_folder,
                                self.file_prefix + video_item['id'] + self.file_ext)
        with np.load(filename) as data:
            feats = data['feats'].astype(np.float32)

        # deal with downsampling (= increased feat stride)
        feats = feats[::self.downsample_rate, :]
        feat_stride = self.feat_stride * self.downsample_rate
        feat_offset = 0.5 * self.num_frames / feat_stride
        # T x C -> C x T
        feats = torch.from_numpy(np.ascontiguousarray(feats.transpose()))

        # convert time stamp (in second) into temporal feature grids
        # ok to have small negative values here
        if video_item['segments'] is not None:
            segments = torch.from_numpy(
                video_item['segments'] * video_item['fps'] / feat_stride - feat_offset
            )
            labels = torch.from_numpy(video_item['labels'])
        else:
            segments, labels = None, None

        # return a data dict
        data_dict = {'video_id': video_item['id'],
                     'feats': feats,          # C x T
                     'segments': segments,    # N x 2
                     'labels': labels,        # N
                     'fps': video_item['fps'],
                     'duration': video_item['duration'],
                     'feat_stride': feat_stride,
                     'feat_num_frames': self.num_frames}

        # truncate the features during training
        if self.is_training and (segments is not None):
            data_dict = truncate_feats(
                data_dict, self.max_seq_len, self.trunc_thresh, feat_offset, self.crop_ratio
            )

        return data_dict
@register_dataset("thumos")
class THUMOS14Dataset(Dataset):
    """THUMOS14 dataset served from pre-extracted per-video feature files.

    Annotations come from a JSON database (``json_file``); each video's
    features are loaded with ``np.load`` from ``feat_folder``.  Videos whose
    feature file does not exist are dropped when the database is built.
    """

    def __init__(
        self,
        is_training,      # if in training mode
        split,            # split, a tuple/list allowing concat of subsets
        feat_folder,      # folder for features
        json_file,        # json file for annotations
        feat_stride,      # temporal stride of the feats
        num_frames,       # number of frames for each feat
        default_fps,      # default fps
        downsample_rate,  # downsample rate for feats
        max_seq_len,      # maximum sequence length during training
        trunc_thresh,     # threshold for truncate an action segment
        crop_ratio,       # a tuple (e.g., (0.9, 1.0)) for random cropping
        input_dim,        # input feat dim
        num_classes,      # number of action categories
        file_prefix,      # feature file prefix if any
        file_ext,         # feature file extension if any
        force_upsampling  # force to upsample to max_seq_len
    ):
        # file path
        assert os.path.exists(feat_folder) and os.path.exists(json_file)
        assert isinstance(split, (tuple, list))
        # identity test: `== None` would broadcast on array-like crop ratios
        assert crop_ratio is None or len(crop_ratio) == 2
        self.feat_folder = feat_folder
        self.file_prefix = file_prefix if file_prefix is not None else ''
        self.file_ext = file_ext
        self.json_file = json_file

        # split / training mode
        self.split = split
        self.is_training = is_training

        # features meta info
        self.feat_stride = feat_stride
        self.num_frames = num_frames
        self.input_dim = input_dim
        self.default_fps = default_fps
        self.downsample_rate = downsample_rate
        self.max_seq_len = max_seq_len
        self.trunc_thresh = trunc_thresh
        self.num_classes = num_classes
        self.label_dict = None
        self.crop_ratio = crop_ratio

        # load database and select the subset
        dict_db, label_dict = self._load_json_db(self.json_file)
        assert len(label_dict) == num_classes
        self.data_list = dict_db
        self.label_dict = label_dict

        # dataset specific attributes
        self.db_attributes = {
            'dataset_name': 'thumos-14',
            'tiou_thresholds': np.linspace(0.3, 0.7, 5),
            # no label id is marked as empty for this dataset
            'empty_label_ids': [],
        }

    def get_attributes(self):
        """Return the dataset attribute dict built in ``__init__``."""
        return self.db_attributes

    def _load_json_db(self, json_file):
        """Parse the annotation JSON.

        Returns:
            ``(dict_db, label_dict)``: a tuple with one dict per video that is
            in ``self.split`` and has a feature file, and the label-name ->
            label-id map.
        """
        with open(json_file, 'r') as fid:
            json_data = json.load(fid)
        json_db = json_data['database']

        # reuse an existing label map; otherwise build it from every video
        if self.label_dict is None:
            label_dict = {}
            for value in json_db.values():
                # videos without annotations contribute no labels
                for act in value.get('annotations') or ():
                    label_dict[act['label']] = act['label_id']
        else:
            label_dict = self.label_dict

        # fill in the db (immutable afterwards)
        entries = []
        for key, value in json_db.items():
            # skip the video if not in the split
            if value['subset'].lower() not in self.split:
                continue
            # or does not have the feature file
            feat_file = os.path.join(self.feat_folder,
                                     self.file_prefix + key + self.file_ext)
            if not os.path.exists(feat_file):
                continue

            # get fps if available
            if self.default_fps is not None:
                fps = self.default_fps
            elif 'fps' in value:
                fps = value['fps']
            else:
                # raised explicitly so the check survives `python -O`
                raise AssertionError("Unknown video FPS.")

            # get video duration if available
            duration = value.get('duration', 1e8)

            # get annotations if available
            annotations = value.get('annotations') or ()
            if len(annotations) > 0:
                # a fun fact of THUMOS: cliffdiving (4) is a subset of diving (7)
                # our code can now handle this corner case
                segments = np.asarray(
                    [act['segment'] for act in annotations], dtype=np.float32
                )
                labels = np.asarray(
                    [label_dict[act['label']] for act in annotations], dtype=np.int64
                )
            else:
                segments = None
                labels = None

            entries.append({'id': key,
                            'fps': fps,
                            'duration': duration,
                            'segments': segments,
                            'labels': labels})

        return tuple(entries), label_dict

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load one video's features and return its data dict.

        Batching is left to the model; during training the item is passed
        through ``truncate_feats`` when it has annotations.
        """
        video_item = self.data_list[idx]

        # load features
        filename = os.path.join(self.feat_folder,
                                self.file_prefix + video_item['id'] + self.file_ext)
        feats = np.load(filename).astype(np.float32)

        # deal with downsampling (= increased feat stride)
        feats = feats[::self.downsample_rate, :]
        feat_stride = self.feat_stride * self.downsample_rate
        feat_offset = 0.5 * self.num_frames / feat_stride
        # T x C -> C x T
        feats = torch.from_numpy(np.ascontiguousarray(feats.transpose()))

        # convert time stamp (in second) into temporal feature grids
        # ok to have small negative values here
        if video_item['segments'] is not None:
            segments = torch.from_numpy(
                video_item['segments'] * video_item['fps'] / feat_stride - feat_offset
            )
            labels = torch.from_numpy(video_item['labels'])
        else:
            segments, labels = None, None

        # return a data dict
        data_dict = {'video_id': video_item['id'],
                     'feats': feats,          # C x T
                     'segments': segments,    # N x 2
                     'labels': labels,        # N
                     'fps': video_item['fps'],
                     'duration': video_item['duration'],
                     'feat_stride': feat_stride,
                     'feat_num_frames': self.num_frames}

        # truncate the features during training
        if self.is_training and (segments is not None):
            data_dict = truncate_feats(
                data_dict, self.max_seq_len, self.trunc_thresh, feat_offset, self.crop_ratio
            )

        return data_dict
@register_backbone("convTransformer")
class ConvTransformerBackbone(nn.Module):
    """
    A backbone that combines convolutions with transformers.

    ``forward`` returns a feature pyramid: one level after the stem and one
    additional level per branch block, each with its matching mask.
    """
    def __init__(
        self,
        n_in,                  # input feature dimension
        n_embd,                # embedding dimension (after convolution)
        n_head,                # number of head for self-attention in transformers
        n_embd_ks,             # conv kernel size of the embedding network
        max_len,               # max sequence length
        arch = (2, 2, 5),      # (#convs, #stem transformers, #branch transformers)
        mha_win_size = None,   # size of local window for mha; None -> -1 for every level
        scale_factor = 2,      # downsampling rate for the branch
        with_ln = False,       # if to attach layernorm after conv
        attn_pdrop = 0.0,      # dropout rate for the attention map
        proj_pdrop = 0.0,      # dropout rate for the projection / MLP
        path_pdrop = 0.0,      # dropout rate for drop path
        use_abs_pe = False,    # use absolute position embedding
        use_rel_pe = False,    # use relative position embedding
    ):
        super().__init__()
        assert len(arch) == 3
        if mha_win_size is None:
            # fresh list per instance (no shared mutable default), sized to
            # the stem level plus one entry per branch level
            mha_win_size = [-1] * (1 + arch[2])
        assert len(mha_win_size) == (1 + arch[2])
        self.n_in = n_in
        self.arch = arch
        self.mha_win_size = mha_win_size
        self.max_len = max_len
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.use_abs_pe = use_abs_pe
        self.use_rel_pe = use_rel_pe

        # feature projection: one 1x1 conv per input stream when n_in is a list
        if isinstance(n_in, (list, tuple)):
            assert isinstance(n_embd, (list, tuple)) and len(n_in) == len(n_embd)
            self.proj = nn.ModuleList([
                MaskedConv1D(c0, c1, 1) for c0, c1 in zip(n_in, n_embd)
            ])
            n_in = n_embd = sum(n_embd)
        else:
            self.proj = None

        # embedding network using convs
        self.embd = nn.ModuleList()
        self.embd_norm = nn.ModuleList()
        for idx in range(arch[0]):
            n_in = n_embd if idx > 0 else n_in
            self.embd.append(
                MaskedConv1D(
                    n_in, n_embd, n_embd_ks,
                    stride=1, padding=n_embd_ks//2, bias=(not with_ln)
                )
            )
            if with_ln:
                self.embd_norm.append(LayerNorm(n_embd))
            else:
                self.embd_norm.append(nn.Identity())

        # position embedding (1, C, T), rescaled by 1/sqrt(n_embd)
        if self.use_abs_pe:
            pos_embd = get_sinusoid_encoding(self.max_len, n_embd) / (n_embd**0.5)
            self.register_buffer("pos_embd", pos_embd, persistent=False)

        # stem network using (vanilla) transformer
        self.stem = nn.ModuleList()
        for idx in range(arch[1]):
            self.stem.append(
                TransformerBlock(
                    n_embd, n_head,
                    n_ds_strides=(1, 1),
                    attn_pdrop=attn_pdrop,
                    proj_pdrop=proj_pdrop,
                    path_pdrop=path_pdrop,
                    mha_win_size=self.mha_win_size[0],
                    use_rel_pe=self.use_rel_pe
                )
            )

        # main branch using transformer with pooling
        self.branch = nn.ModuleList()
        for idx in range(arch[2]):
            self.branch.append(
                TransformerBlock(
                    n_embd, n_head,
                    n_ds_strides=(self.scale_factor, self.scale_factor),
                    attn_pdrop=attn_pdrop,
                    proj_pdrop=proj_pdrop,
                    path_pdrop=path_pdrop,
                    mha_win_size=self.mha_win_size[1 + idx],
                    use_rel_pe=self.use_rel_pe
                )
            )

        # init weights
        self.apply(self.__init_weights__)

    def __init_weights__(self, module):
        # set nn.Linear/nn.Conv1d bias term to 0
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0.)

    def forward(self, x, mask):
        """Run the backbone.

        Args:
            x: batch size, feature channel, sequence length
            mask: batch size, 1, sequence length (bool)

        Returns:
            ``(out_feats, out_masks)``: tuples with ``1 + len(self.branch)``
            entries each.
        """
        # unpacking also rejects inputs that are not 3-D
        _, _, T = x.size()

        # feature projection
        if isinstance(self.n_in, (list, tuple)):
            x = torch.cat(
                [proj(s, mask)[0]
                    for proj, s in zip(self.proj, x.split(self.n_in, dim=1))
                ], dim=1
            )

        # embedding network
        for idx in range(len(self.embd)):
            x, mask = self.embd[idx](x, mask)
            x = self.relu(self.embd_norm[idx](x))

        if self.use_abs_pe:
            if self.training:
                # training: using fixed length position embeddings
                assert T <= self.max_len, "Reached max length."
                pe = self.pos_embd
            elif T >= self.max_len:
                # inference: re-interpolate position embeddings for over-length sequences
                pe = F.interpolate(
                    self.pos_embd, T, mode='linear', align_corners=False)
            else:
                pe = self.pos_embd
            # add pe to x
            x = x + pe[:, :, :T] * mask.to(x.dtype)

        # stem transformer
        for idx in range(len(self.stem)):
            x, mask = self.stem[idx](x, mask)

        # prep for outputs
        out_feats = (x, )
        out_masks = (mask, )

        # main branch with downsampling
        for idx in range(len(self.branch)):
            x, mask = self.branch[idx](x, mask)
            out_feats += (x, )
            out_masks += (mask, )

        return out_feats, out_masks
isinstance(n_in, (list, tuple)): + assert isinstance(n_embd, (list, tuple)) and len(n_in) == len(n_embd) + self.proj = nn.ModuleList([ + MaskedConv1D(c0, c1, 1) for c0, c1 in zip(n_in, n_embd) + ]) + n_in = n_embd = sum(n_embd) + else: + self.proj = None + + # embedding network using convs + self.embd = nn.ModuleList() + self.embd_norm = nn.ModuleList() + for idx in range(arch[0]): + n_in = n_embd if idx > 0 else n_in + self.embd.append( + MaskedConv1D( + n_in, n_embd, n_embd_ks, + stride=1, padding=n_embd_ks//2, bias=(not with_ln) + ) + ) + if with_ln: + self.embd_norm.append(LayerNorm(n_embd)) + else: + self.embd_norm.append(nn.Identity()) + + # stem network using convs + self.stem = nn.ModuleList() + for idx in range(arch[1]): + self.stem.append(ConvBlock(n_embd, 3, 1)) + + # main branch using convs with pooling + self.branch = nn.ModuleList() + for idx in range(arch[2]): + self.branch.append(ConvBlock(n_embd, 3, self.scale_factor)) + + # init weights + self.apply(self.__init_weights__) + + def __init_weights__(self, module): + # set nn.Linear bias term to 0 + if isinstance(module, (nn.Linear, nn.Conv1d)): + if module.bias is not None: + torch.nn.init.constant_(module.bias, 0.) 
def forward(self, x, mask):
    """Run the conv backbone and collect the feature pyramid.

    Args:
        x: batch size, feature channel, sequence length
        mask: batch size, 1, sequence length (bool)

    Returns:
        ``(out_feats, out_masks)``: two tuples with one entry produced after
        the stem and one more after every branch block.
    """
    # unpacking doubles as a check that x is 3-D
    _batch, _channels, _length = x.size()

    # feature projection: each input stream goes through its own layer
    if isinstance(self.n_in, (list, tuple)):
        streams = x.split(self.n_in, dim=1)
        projected = []
        for layer, stream in zip(self.proj, streams):
            projected.append(layer(stream, mask)[0])
        x = torch.cat(projected, dim=1)

    # embedding network
    for conv, norm in zip(self.embd, self.embd_norm):
        x, mask = conv(x, mask)
        x = self.relu(norm(x))

    # stem conv
    for block in self.stem:
        x, mask = block(x, mask)

    # first pyramid level is the stem output
    feats, masks = [x], [mask]

    # main branch with downsampling
    for block in self.branch:
        x, mask = block(x, mask)
        feats.append(x)
        masks.append(mask)

    return tuple(feats), tuple(masks)
# helper functions for Transformer blocks
def get_sinusoid_encoding(n_position, d_hid):
    """Build a sinusoid position encoding table.

    Args:
        n_position: number of positions (T).
        d_hid: encoding dimension (C).

    Returns:
        float32 tensor of size (1, d_hid, n_position); even channels hold the
        sine terms and odd channels the cosine terms.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    # channels 2i and 2i+1 share the exponent 2i / d_hid
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    # (n_position, d_hid) table of angles, built by broadcasting
    sinusoid_table = positions / np.power(10000.0, exponents)[None, :]
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    # return a tensor of size 1 C T
    return torch.from_numpy(sinusoid_table).float().unsqueeze(0).transpose(1, 2)
class MaskedMHCA(nn.Module):
    """
    Multi Head Conv Attention with mask

    Add a depthwise convolution within a standard MHA
    The extra conv op can be used to
    (1) encode relative position information (replacing position encoding);
    (2) downsample the features if needed;
    (3) match the feature channels

    Note: With current implementation, the downsampled feature will be aligned
    to every s+1 time step, where s is the downsampling stride. This allows us
    to easily interpolate the corresponding positional embeddings.

    Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    """

    def __init__(
        self,
        n_embd,          # dimension of the output features
        n_head,          # number of heads in multi-head self-attention
        n_qx_stride=1,   # downsampling stride for query and input
        n_kv_stride=1,   # downsampling stride for key and value
        attn_pdrop=0.0,  # dropout rate for the attention map
        proj_pdrop=0.0,  # dropout rate for projection op
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        self.scale = 1.0 / math.sqrt(self.n_channels)

        # conv/pooling operations
        assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
        assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
        self.n_qx_stride = n_qx_stride
        self.n_kv_stride = n_kv_stride

        # query conv (depthwise); strided by the *query* stride so the query
        # length and qx_mask follow n_qx_stride (identical to the previous
        # behavior whenever n_qx_stride == n_kv_stride)
        kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
        stride, padding = self.n_qx_stride, kernel_size // 2
        self.query_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        self.query_norm = LayerNorm(self.n_embd)

        # key, value conv (depthwise)
        kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
        stride, padding = self.n_kv_stride, kernel_size // 2
        self.key_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        self.key_norm = LayerNorm(self.n_embd)
        self.value_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        self.value_norm = LayerNorm(self.n_embd)

        # key, query, value projections for all heads
        # it is OK to ignore masking, as the mask will be attached on the attention
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

    def forward(self, x, mask):
        """Apply masked conv attention.

        Args:
            x: batch size, feature channel, sequence length
            mask: batch size, 1, sequence length (bool)

        Returns:
            ``(out, qx_mask)``: attended features and the mask produced by the
            query conv.
        """
        B, C, _ = x.size()

        # query conv -> (B, nh * hs, T')
        q, qx_mask = self.query_conv(x, mask)
        q = self.query_norm(q)
        # key, value conv -> (B, nh * hs, T'')
        k, kv_mask = self.key_conv(x, mask)
        k = self.key_norm(k)
        v, _ = self.value_conv(x, mask)
        v = self.value_norm(v)

        # projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # move head forward to be the batch dim
        # (B, nh * hs, T'/T'') -> (B, nh, T'/T'', hs)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)

        # self-attention: (B, nh, T', hs) x (B, nh, hs, T'') -> (B, nh, T', T'')
        att = (q * self.scale) @ k.transpose(-2, -1)
        # prevent q from attending to invalid tokens
        att = att.masked_fill(torch.logical_not(kv_mask[:, :, None, :]), float('-inf'))
        # softmax attn
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        # (B, nh, T', T'') x (B, nh, T'', hs) -> (B, nh, T', hs)
        out = att @ (v * kv_mask[:, :, :, None].to(v.dtype))
        # re-assemble all head outputs side by side
        out = out.transpose(2, 3).contiguous().view(B, C, -1)

        # output projection, masked by the query-side mask
        out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype)
        return out, qx_mask
def __init__(
    self,
    n_embd,           # dimension of the output features
    n_head,           # number of heads in multi-head self-attention
    window_size,      # size of the local attention window
    n_qx_stride=1,    # downsampling stride for query and input
    n_kv_stride=1,    # downsampling stride for key and value
    attn_pdrop=0.0,   # dropout rate for the attention map
    proj_pdrop=0.0,   # dropout rate for projection op
    use_rel_pe=False  # use relative position encoding
):
    """Build the depthwise convs, projections and dropout layers for local attention."""
    super().__init__()
    assert n_embd % n_head == 0
    self.n_embd = n_embd
    self.n_head = n_head
    self.n_channels = n_embd // n_head
    self.scale = 1.0 / math.sqrt(self.n_channels)
    self.window_size = window_size
    self.window_overlap = window_size // 2
    # NOTE(review): an odd window size is intended, but only
    # `window_size > 1` is actually enforced here
    assert self.window_size > 1 and self.n_head >= 1
    self.use_rel_pe = use_rel_pe

    # conv/pooling operations
    assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
    assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
    self.n_qx_stride = n_qx_stride
    self.n_kv_stride = n_kv_stride

    # query conv (depthwise); strided by the *query* stride so the query
    # length follows n_qx_stride (identical to the previous behavior
    # whenever n_qx_stride == n_kv_stride)
    kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
    stride, padding = self.n_qx_stride, kernel_size // 2
    self.query_conv = MaskedConv1D(
        self.n_embd, self.n_embd, kernel_size,
        stride=stride, padding=padding, groups=self.n_embd, bias=False
    )
    self.query_norm = LayerNorm(self.n_embd)

    # key, value conv (depthwise)
    kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
    stride, padding = self.n_kv_stride, kernel_size // 2
    self.key_conv = MaskedConv1D(
        self.n_embd, self.n_embd, kernel_size,
        stride=stride, padding=padding, groups=self.n_embd, bias=False
    )
    self.key_norm = LayerNorm(self.n_embd)
    self.value_conv = MaskedConv1D(
        self.n_embd, self.n_embd, kernel_size,
        stride=stride, padding=padding, groups=self.n_embd, bias=False
    )
    self.value_norm = LayerNorm(self.n_embd)

    # key, query, value projections for all heads
    # it is OK to ignore masking, as the mask will be attached on the attention
    self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
    self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
    self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

    # regularization
    self.attn_drop = nn.Dropout(attn_pdrop)
    self.proj_drop = nn.Dropout(proj_pdrop)

    # output projection
    self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

    # relative position encoding: one learned bias per head and window slot
    if self.use_rel_pe:
        self.rel_pe = nn.Parameter(
            torch.zeros(1, 1, self.n_head, self.window_size))
        trunc_normal_(self.rel_pe, std=(2.0 / self.n_embd)**0.5)
Chunk size = 2w, overlap size = w""" + # x: B x nh, T, hs + # non-overlapping chunks of size = 2w -> B x nh, T//2w, 2w, hs + x = x.view( + x.size(0), + x.size(1) // (window_overlap * 2), + window_overlap * 2, + x.size(2), + ) + + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(x.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 + chunk_stride = list(x.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + + # B x nh, #chunks = T//w - 1, 2w, hs + return x.as_strided(size=chunk_size, stride=chunk_stride) + + @staticmethod + def _pad_and_transpose_last_two_dims(x, padding): + """pads rows and then flips rows and columns""" + # padding value is not important because it will be overwritten + x = nn.functional.pad(x, padding) + x = x.view(*x.size()[:-2], x.size(-1), x.size(-2)) + return x + + @staticmethod + def _mask_invalid_locations(input_tensor, affected_seq_len): + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) + beginning_mask = beginning_mask_2d[None, :, None, :] + ending_mask = beginning_mask.flip(dims=(1, 3)) + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] + beginning_mask = beginning_mask.expand(beginning_input.size()) + # `== 1` converts to bool or uint8 + beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] + ending_mask = ending_mask.expand(ending_input.size()) + # `== 1` converts to bool or uint8 + ending_input.masked_fill_(ending_mask == 1, -float("inf")) + + @staticmethod + def _pad_and_diagonalize(x): + """ + shift every row 1 step right, converting columns into diagonals. 
+ Example:: + chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492, + -1.8348, 0.7672, 0.2986, 0.0285, + -0.7584, 0.4206, -0.0405, 0.1599, + 2.0514, -1.1600, 0.5372, 0.2629 ] + window_overlap = num_rows = 4 + (pad & diagonalize) => + [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 + 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 + 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ + total_num_heads, num_chunks, window_overlap, hidden_dim = x.size() + # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). + x = nn.functional.pad( + x, (0, window_overlap + 1) + ) + # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap + x = x.view(total_num_heads, num_chunks, -1) + # total_num_heads x num_chunks x window_overlap*window_overlap + x = x[:, :, :-window_overlap] + x = x.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) + x = x[:, :, :, :-1] + return x + + def _sliding_chunks_query_key_matmul( + self, query, key, num_heads, window_overlap + ): + """ + Matrix multiplication of query and key tensors using with a sliding window attention pattern. 
This implementation splits the input into overlapping chunks of size 2w with an overlap of size w (window_overlap) + """ + # query / key: B*nh, T, hs + bnh, seq_len, head_dim = query.size() + batch_size = bnh // num_heads + assert seq_len % (window_overlap * 2) == 0 + assert query.size() == key.size() + + chunks_count = seq_len // window_overlap - 1 + + # B * num_heads, head_dim, #chunks=(T//w - 1), 2w + chunk_query = self._chunk(query, window_overlap) + chunk_key = self._chunk(key, window_overlap) + + # matrix multiplication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + diagonal_chunked_attention_scores = torch.einsum( + "bcxd,bcyd->bcxy", (chunk_query, chunk_key)) + + # convert diagonals into columns + # B * num_heads, #chunks, 2w, 2w+1 + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + diagonal_chunked_attention_scores, padding=(0, 0, 0, 1) + ) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
+ diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) + + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] + # - copying the lower triangle + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seq_len, 2 * window_overlap + 1 + ).transpose(2, 1) + + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores + + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs, value, num_heads, window_overlap + ): + """ + Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
Returned tensor will be of the + same shape as `attn_probs` + """ + bnh, seq_len, head_dim = value.size() + batch_size = bnh // num_heads + assert seq_len % (window_overlap * 2) == 0 + assert attn_probs.size(3) == 2 * window_overlap + 1 + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( + batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + ) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim) + + def forward(self, x, mask): + # x: batch size, feature channel, sequence length, + # mask: batch size, 1, sequence length (bool) + B, C, T = x.size() + + # step 1: depth convolutions + # query conv -> (B, nh * hs, T') + q, qx_mask = self.query_conv(x, mask) + q = self.query_norm(q) + # key, value conv -> (B, nh * hs, T'') + k, kv_mask = self.key_conv(x, mask) + k = self.key_norm(k) + v, _ = self.value_conv(x, mask) + v = self.value_norm(v) + + # step 2: query, key, value transforms & reshape + # projections + q = self.query(q) + k = 
self.key(k) + v = self.value(v) + # (B, nh * hs, T) -> (B, nh, T, hs) + q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) + k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) + v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) + # view as (B * nh, T, hs) + q = q.view(B * self.n_head, -1, self.n_channels).contiguous() + k = k.view(B * self.n_head, -1, self.n_channels).contiguous() + v = v.view(B * self.n_head, -1, self.n_channels).contiguous() + + # step 3: compute local self-attention with rel pe and masking + q *= self.scale + # chunked query key attention -> B, T, nh, 2w+1 = window_size + att = self._sliding_chunks_query_key_matmul( + q, k, self.n_head, self.window_overlap) + + # rel pe + if self.use_rel_pe: + att += self.rel_pe + # kv_mask -> B, T'', 1 + inverse_kv_mask = torch.logical_not( + kv_mask[:, :, :, None].view(B, -1, 1)) + # 0 for valid slot, -inf for masked ones + float_inverse_kv_mask = inverse_kv_mask.type_as(q).masked_fill( + inverse_kv_mask, -1e4) + # compute the diagonal mask (for each local window) + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_inverse_kv_mask.new_ones(size=float_inverse_kv_mask.size()), + float_inverse_kv_mask, + 1, + self.window_overlap + ) + att += diagonal_mask + + # ignore input masking for now + att = nn.functional.softmax(att, dim=-1) + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + att = att.masked_fill( + torch.logical_not(kv_mask.squeeze(1)[:, :, None, None]), 0.0) + att = self.attn_drop(att) + + # step 4: compute attention value product + output projection + # chunked attn value product -> B, nh, T, hs + out = self._sliding_chunks_matmul_attn_probs_value( + att, v, self.n_head, self.window_overlap) + # transpose to B, nh, hs, T -> B, nh*hs, T + out = out.transpose(2, 3).contiguous().view(B, C, -1) + # output projection + skip connection + out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype) + return out, qx_mask + + 
+class TransformerBlock(nn.Module): + """ + A simple (post layer norm) Transformer block + Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py + """ + def __init__( + self, + n_embd, # dimension of the input features + n_head, # number of attention heads + n_ds_strides=(1, 1), # downsampling strides for q & x, k & v + n_out=None, # output dimension, if None, set to input dim + n_hidden=None, # dimension of the hidden layer in MLP + act_layer=nn.GELU, # nonlinear activation used in MLP, default GELU + attn_pdrop=0.0, # dropout rate for the attention map + proj_pdrop=0.0, # dropout rate for the projection / MLP + path_pdrop=0.0, # drop path rate + mha_win_size=-1, # > 0 to use window mha + use_rel_pe=False # if to add rel position encoding to attention + ): + super().__init__() + assert len(n_ds_strides) == 2 + # layer norm for order (B C T) + self.ln1 = LayerNorm(n_embd) + self.ln2 = LayerNorm(n_embd) + + # specify the attention module + if mha_win_size > 1: + self.attn = LocalMaskedMHCA( + n_embd, + n_head, + window_size=mha_win_size, + n_qx_stride=n_ds_strides[0], + n_kv_stride=n_ds_strides[1], + attn_pdrop=attn_pdrop, + proj_pdrop=proj_pdrop, + use_rel_pe=use_rel_pe # only valid for local attention + ) + else: + self.attn = MaskedMHCA( + n_embd, + n_head, + n_qx_stride=n_ds_strides[0], + n_kv_stride=n_ds_strides[1], + attn_pdrop=attn_pdrop, + proj_pdrop=proj_pdrop + ) + + # input + if n_ds_strides[0] > 1: + kernel_size, stride, padding = \ + n_ds_strides[0] + 1, n_ds_strides[0], (n_ds_strides[0] + 1)//2 + self.pool_skip = nn.MaxPool1d( + kernel_size, stride=stride, padding=padding) + else: + self.pool_skip = nn.Identity() + + # two layer mlp + if n_hidden is None: + n_hidden = 4 * n_embd # default + if n_out is None: + n_out = n_embd + # ok to use conv1d here with stride=1 + self.mlp = nn.Sequential( + nn.Conv1d(n_embd, n_hidden, 1), + act_layer(), + nn.Dropout(proj_pdrop, inplace=True), + nn.Conv1d(n_hidden, n_out, 1), + 
nn.Dropout(proj_pdrop, inplace=True), + ) + + # drop path + if path_pdrop > 0.0: + self.drop_path_attn = AffineDropPath(n_embd, drop_prob = path_pdrop) + self.drop_path_mlp = AffineDropPath(n_out, drop_prob = path_pdrop) + else: + self.drop_path_attn = nn.Identity() + self.drop_path_mlp = nn.Identity() + + def forward(self, x, mask, pos_embd=None): + # pre-LN transformer: https://arxiv.org/pdf/2002.04745.pdf + out, out_mask = self.attn(self.ln1(x), mask) + out_mask_float = out_mask.to(out.dtype) + out = self.pool_skip(x) * out_mask_float + self.drop_path_attn(out) + # FFN + out = out + self.drop_path_mlp(self.mlp(self.ln2(out)) * out_mask_float) + # optionally add pos_embd to the output + if pos_embd is not None: + out += pos_embd * out_mask_float + return out, out_mask + + +class ConvBlock(nn.Module): + """ + A simple conv block similar to the basic block used in ResNet + """ + def __init__( + self, + n_embd, # dimension of the input features + kernel_size=3, # conv kernel size + n_ds_stride=1, # downsampling stride for the current layer + expansion_factor=2, # expansion factor of feat dims + n_out=None, # output dimension, if None, set to input dim + act_layer=nn.ReLU, # nonlinear activation used after conv, default ReLU + ): + super().__init__() + # must use odd sized kernel + assert (kernel_size % 2 == 1) and (kernel_size > 1) + padding = kernel_size // 2 + if n_out is None: + n_out = n_embd + + # 1x3 (strided) -> 1x3 (basic block in resnet) + width = n_embd * expansion_factor + self.conv1 = MaskedConv1D( + n_embd, width, kernel_size, n_ds_stride, padding=padding) + self.conv2 = MaskedConv1D( + width, n_out, kernel_size, 1, padding=padding) + + # attach downsampling conv op + if n_ds_stride > 1: + # 1x1 strided conv (same as resnet) + self.downsample = MaskedConv1D(n_embd, n_out, 1, n_ds_stride) + else: + self.downsample = None + + self.act = act_layer() + + def forward(self, x, mask, pos_embd=None): + identity = x + out, out_mask = self.conv1(x, mask) + out = 
self.act(out) + out, out_mask = self.conv2(out, out_mask) + + # downsampling + if self.downsample is not None: + identity, _ = self.downsample(x, mask) + + # residual connection + out += identity + out = self.act(out) + + return out, out_mask + + +# drop path: from https://github.com/facebookresearch/SlowFast/blob/master/slowfast/models/common.py +class Scale(nn.Module): + """ + Multiply the output regression range by a learnable constant value + """ + def __init__(self, init_value=1.0): + """ + init_value : initial value for the scalar + """ + super().__init__() + self.scale = nn.Parameter( + torch.tensor(init_value, dtype=torch.float32), + requires_grad=True + ) + + def forward(self, x): + """ + input -> scale * input + """ + return x * self.scale + + +# The follow code is modified from +# https://github.com/facebookresearch/SlowFast/blob/master/slowfast/models/common.py +def drop_path(x, drop_prob=0.0, training=False): + """ + Stochastic Depth per sample. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + mask.floor_() # binarize + output = x.div(keep_prob) * mask + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class AffineDropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) with a per channel scaling factor (and zero init) + See: https://arxiv.org/pdf/2103.17239.pdf + """ + + def __init__(self, num_dim, drop_prob=0.0, init_scale_value=1e-4): + super().__init__() + self.scale = nn.Parameter( + init_scale_value * 
torch.ones((1, num_dim, 1)), + requires_grad=True + ) + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(self.scale * x, self.drop_prob, self.training) diff --git a/code/actionformer_release/libs/modeling/loc_generators.py b/code/actionformer_release/libs/modeling/loc_generators.py new file mode 100644 index 0000000..d2f7471 --- /dev/null +++ b/code/actionformer_release/libs/modeling/loc_generators.py @@ -0,0 +1,84 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .models import register_generator + + +class BufferList(nn.Module): + """ + Similar to nn.ParameterList, but for buffers + + Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/anchor_generator.py + """ + + def __init__(self, buffers): + super().__init__() + for i, buffer in enumerate(buffers): + # Use non-persistent buffer so the values are not saved in checkpoint + self.register_buffer(str(i), buffer, persistent=False) + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.values()) + +@register_generator('point') +class PointGenerator(nn.Module): + """ + A generator for temporal "points" + + max_seq_len can be much larger than the actual seq length + """ + def __init__( + self, + max_seq_len, # max sequence length that the generator will buffer + fpn_strides, # strides of fpn levels + regression_range, # regression range (on feature grids) + use_offset=False # if to align the points at grid centers + ): + super().__init__() + # sanity check, # fpn levels and length divisible + fpn_levels = len(fpn_strides) + assert len(regression_range) == fpn_levels + + # save params + self.max_seq_len = max_seq_len + self.fpn_levels = fpn_levels + self.fpn_strides = fpn_strides + self.regression_range = regression_range + self.use_offset = use_offset + + # generate all points and buffer the list + self.buffer_points = self._generate_points() + + def _generate_points(self): + 
points_list = [] + # loop over all points at each pyramid level + for l, stride in enumerate(self.fpn_strides): + reg_range = torch.as_tensor( + self.regression_range[l], dtype=torch.float) + fpn_stride = torch.as_tensor(stride, dtype=torch.float) + points = torch.arange(0, self.max_seq_len, stride)[:, None] + # add offset if necessary (not in our current model) + if self.use_offset: + points += 0.5 * stride + # pad the time stamp with additional regression range / stride + reg_range = reg_range[None].repeat(points.shape[0], 1) + fpn_stride = fpn_stride[None].repeat(points.shape[0], 1) + # size: T x 4 (ts, reg_range, stride) + points_list.append(torch.cat((points, reg_range, fpn_stride), dim=1)) + + return BufferList(points_list) + + def forward(self, feats): + # feats will be a list of torch tensors + assert len(feats) == self.fpn_levels + pts_list = [] + feat_lens = [feat.shape[-1] for feat in feats] + for feat_len, buffer_pts in zip(feat_lens, self.buffer_points): + assert feat_len <= buffer_pts.shape[0], "Reached max buffer length for point generator" + pts = buffer_pts[:feat_len, :] + pts_list.append(pts) + return pts_list \ No newline at end of file diff --git a/code/actionformer_release/libs/modeling/losses.py b/code/actionformer_release/libs/modeling/losses.py new file mode 100644 index 0000000..0e3d370 --- /dev/null +++ b/code/actionformer_release/libs/modeling/losses.py @@ -0,0 +1,168 @@ +import torch +from torch.nn import functional as F + +@torch.jit.script +def sigmoid_focal_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, + reduction: str = "none", +) -> torch.Tensor: + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Taken from + https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py + # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + Args: + inputs: A float tensor of arbitrary shape. 
+ The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = 0.25. + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + Returns: + Loss tensor with the reduction option applied. + """ + inputs = inputs.float() + targets = targets.float() + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss + + +@torch.jit.script +def ctr_giou_loss_1d( + input_offsets: torch.Tensor, + target_offsets: torch.Tensor, + reduction: str = 'none', + eps: float = 1e-8, +) -> torch.Tensor: + """ + Generalized Intersection over Union Loss (Hamid Rezatofighi et. al) + https://arxiv.org/abs/1902.09630 + + This is an implementation that assumes a 1D event is represented using + the same center point with different offsets, e.g., + (t1, t2) = (c - o_1, c + o_2) with o_i >= 0 + + Reference code from + https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py + + Args: + input/target_offsets (Tensor): 1D offsets of size (N, 2) + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. 
+ eps (float): small number to prevent division by zero + """ + input_offsets = input_offsets.float() + target_offsets = target_offsets.float() + # check all 1D events are valid + assert (input_offsets >= 0.0).all(), "predicted offsets must be non-negative" + assert (target_offsets >= 0.0).all(), "GT offsets must be non-negative" + + lp, rp = input_offsets[:, 0], input_offsets[:, 1] + lg, rg = target_offsets[:, 0], target_offsets[:, 1] + + # intersection key points + lkis = torch.min(lp, lg) + rkis = torch.min(rp, rg) + + # iou + intsctk = rkis + lkis + unionk = (lp + rp) + (lg + rg) - intsctk + iouk = intsctk / unionk.clamp(min=eps) + + # giou is reduced to iou in our setting, skip unnecessary steps + loss = 1.0 - iouk + + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss + +@torch.jit.script +def ctr_diou_loss_1d( + input_offsets: torch.Tensor, + target_offsets: torch.Tensor, + reduction: str = 'none', + eps: float = 1e-8, +) -> torch.Tensor: + """ + Distance-IoU Loss (Zheng et. al) + https://arxiv.org/abs/1911.08287 + + This is an implementation that assumes a 1D event is represented using + the same center point with different offsets, e.g., + (t1, t2) = (c - o_1, c + o_2) with o_i >= 0 + + Reference code from + https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py + + Args: + input/target_offsets (Tensor): 1D offsets of size (N, 2) + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. 
+ eps (float): small number to prevent division by zero + """ + input_offsets = input_offsets.float() + target_offsets = target_offsets.float() + # check all 1D events are valid + assert (input_offsets >= 0.0).all(), "predicted offsets must be non-negative" + assert (target_offsets >= 0.0).all(), "GT offsets must be non-negative" + + lp, rp = input_offsets[:, 0], input_offsets[:, 1] + lg, rg = target_offsets[:, 0], target_offsets[:, 1] + + # intersection key points + lkis = torch.min(lp, lg) + rkis = torch.min(rp, rg) + + # iou + intsctk = rkis + lkis + unionk = (lp + rp) + (lg + rg) - intsctk + iouk = intsctk / unionk.clamp(min=eps) + + # smallest enclosing box + lc = torch.max(lp, lg) + rc = torch.max(rp, rg) + len_c = lc + rc + + # offset between centers + rho = 0.5 * (rp - lp - rg + lg) + + # diou + loss = 1.0 - iouk + torch.square(rho / len_c.clamp(min=eps)) + + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss diff --git a/code/actionformer_release/libs/modeling/meta_archs.py b/code/actionformer_release/libs/modeling/meta_archs.py new file mode 100644 index 0000000..fa6e93f --- /dev/null +++ b/code/actionformer_release/libs/modeling/meta_archs.py @@ -0,0 +1,753 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from .models import register_meta_arch, make_backbone, make_neck, make_generator +from .blocks import MaskedConv1D, Scale, LayerNorm +from .losses import ctr_diou_loss_1d, sigmoid_focal_loss + +from ..utils import batched_nms + +class PtTransformerClsHead(nn.Module): + """ + 1D Conv heads for classification + """ + def __init__( + self, + input_dim, + feat_dim, + num_classes, + prior_prob=0.01, + num_layers=3, + kernel_size=3, + act_layer=nn.ReLU, + with_ln=False, + empty_cls = [] + ): + super().__init__() + self.act = act_layer() + + # build the head + self.head = nn.ModuleList() + self.norm = nn.ModuleList() 
+ for idx in range(num_layers-1): + if idx == 0: + in_dim = input_dim + out_dim = feat_dim + else: + in_dim = feat_dim + out_dim = feat_dim + self.head.append( + MaskedConv1D( + in_dim, out_dim, kernel_size, + stride=1, + padding=kernel_size//2, + bias=(not with_ln) + ) + ) + if with_ln: + self.norm.append(LayerNorm(out_dim)) + else: + self.norm.append(nn.Identity()) + + # classifier + self.cls_head = MaskedConv1D( + feat_dim, num_classes, kernel_size, + stride=1, padding=kernel_size//2 + ) + + # use prior in model initialization to improve stability + # this will overwrite other weight init + if prior_prob > 0: + bias_value = -(math.log((1 - prior_prob) / prior_prob)) + torch.nn.init.constant_(self.cls_head.conv.bias, bias_value) + + # a quick fix to empty categories: + # the weights assocaited with these categories will remain unchanged + # we set their bias to a large negative value to prevent their outputs + if len(empty_cls) > 0: + bias_value = -(math.log((1 - 1e-6) / 1e-6)) + for idx in empty_cls: + torch.nn.init.constant_(self.cls_head.conv.bias[idx], bias_value) + + def forward(self, fpn_feats, fpn_masks): + assert len(fpn_feats) == len(fpn_masks) + + # apply the classifier for each pyramid level + out_logits = tuple() + for _, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)): + cur_out = cur_feat + for idx in range(len(self.head)): + cur_out, _ = self.head[idx](cur_out, cur_mask) + cur_out = self.act(self.norm[idx](cur_out)) + cur_logits, _ = self.cls_head(cur_out, cur_mask) + out_logits += (cur_logits, ) + + # fpn_masks remains the same + return out_logits + + +class PtTransformerRegHead(nn.Module): + """ + Shared 1D Conv heads for regression + Simlar logic as PtTransformerClsHead with separated implementation for clarity + """ + def __init__( + self, + input_dim, + feat_dim, + fpn_levels, + num_layers=3, + kernel_size=3, + act_layer=nn.ReLU, + with_ln=False + ): + super().__init__() + self.fpn_levels = fpn_levels + self.act = act_layer() + + 
# build the conv head + self.head = nn.ModuleList() + self.norm = nn.ModuleList() + for idx in range(num_layers-1): + if idx == 0: + in_dim = input_dim + out_dim = feat_dim + else: + in_dim = feat_dim + out_dim = feat_dim + self.head.append( + MaskedConv1D( + in_dim, out_dim, kernel_size, + stride=1, + padding=kernel_size//2, + bias=(not with_ln) + ) + ) + if with_ln: + self.norm.append(LayerNorm(out_dim)) + else: + self.norm.append(nn.Identity()) + + self.scale = nn.ModuleList() + for idx in range(fpn_levels): + self.scale.append(Scale()) + + # segment regression + self.offset_head = MaskedConv1D( + feat_dim, 2, kernel_size, + stride=1, padding=kernel_size//2 + ) + + def forward(self, fpn_feats, fpn_masks): + assert len(fpn_feats) == len(fpn_masks) + assert len(fpn_feats) == self.fpn_levels + + # apply the classifier for each pyramid level + out_offsets = tuple() + for l, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)): + cur_out = cur_feat + for idx in range(len(self.head)): + cur_out, _ = self.head[idx](cur_out, cur_mask) + cur_out = self.act(self.norm[idx](cur_out)) + cur_offsets, _ = self.offset_head(cur_out, cur_mask) + out_offsets += (F.relu(self.scale[l](cur_offsets)), ) + + # fpn_masks remains the same + return out_offsets + + +@register_meta_arch("LocPointTransformer") +class PtTransformer(nn.Module): + """ + Transformer based model for single stage action localization + """ + def __init__( + self, + backbone_type, # a string defines which backbone we use + fpn_type, # a string defines which fpn we use + backbone_arch, # a tuple defines #layers in embed / stem / branch + scale_factor, # scale factor between branch layers + input_dim, # input feat dim + max_seq_len, # max sequence length (used for training) + max_buffer_len_factor, # max buffer size (defined a factor of max_seq_len) + n_head, # number of heads for self-attention in transformer + n_mha_win_size, # window size for self attention; -1 to use full seq + embd_kernel_size, # kernel 
size of the embedding network + embd_dim, # output feat channel of the embedding network + embd_with_ln, # attach layernorm to embedding network + fpn_dim, # feature dim on FPN + fpn_with_ln, # if to apply layer norm at the end of fpn + fpn_start_level, # start level of fpn + head_dim, # feature dim for head + regression_range, # regression range on each level of FPN + head_num_layers, # number of layers in the head (including the classifier) + head_kernel_size, # kernel size for reg/cls heads + head_with_ln, # attache layernorm to reg/cls heads + use_abs_pe, # if to use abs position encoding + use_rel_pe, # if to use rel position encoding + num_classes, # number of action classes + train_cfg, # other cfg for training + test_cfg # other cfg for testing + ): + super().__init__() + # re-distribute params to backbone / neck / head + self.fpn_strides = [scale_factor**i for i in range( + fpn_start_level, backbone_arch[-1]+1 + )] + self.reg_range = regression_range + assert len(self.fpn_strides) == len(self.reg_range) + self.scale_factor = scale_factor + # #classes = num_classes + 1 (background) with last category as background + # e.g., num_classes = 10 -> 0, 1, ..., 9 as actions, 10 as background + self.num_classes = num_classes + + # check the feature pyramid and local attention window size + self.max_seq_len = max_seq_len + if isinstance(n_mha_win_size, int): + self.mha_win_size = [n_mha_win_size]*(1 + backbone_arch[-1]) + else: + assert len(n_mha_win_size) == (1 + backbone_arch[-1]) + self.mha_win_size = n_mha_win_size + max_div_factor = 1 + for l, (s, w) in enumerate(zip(self.fpn_strides, self.mha_win_size)): + stride = s * (w // 2) * 2 if w > 1 else s + assert max_seq_len % stride == 0, "max_seq_len must be divisible by fpn stride and window size" + if max_div_factor < stride: + max_div_factor = stride + self.max_div_factor = max_div_factor + + # training time config + self.train_center_sample = train_cfg['center_sample'] + assert self.train_center_sample in 
['radius', 'none'] + self.train_center_sample_radius = train_cfg['center_sample_radius'] + self.train_loss_weight = train_cfg['loss_weight'] + self.train_cls_prior_prob = train_cfg['cls_prior_prob'] + self.train_dropout = train_cfg['dropout'] + self.train_droppath = train_cfg['droppath'] + self.train_label_smoothing = train_cfg['label_smoothing'] + + # test time config + self.test_pre_nms_thresh = test_cfg['pre_nms_thresh'] + self.test_pre_nms_topk = test_cfg['pre_nms_topk'] + self.test_iou_threshold = test_cfg['iou_threshold'] + self.test_min_score = test_cfg['min_score'] + self.test_max_seg_num = test_cfg['max_seg_num'] + self.test_nms_method = test_cfg['nms_method'] + assert self.test_nms_method in ['soft', 'hard', 'none'] + self.test_duration_thresh = test_cfg['duration_thresh'] + self.test_multiclass_nms = test_cfg['multiclass_nms'] + self.test_nms_sigma = test_cfg['nms_sigma'] + self.test_voting_thresh = test_cfg['voting_thresh'] + + # we will need a better way to dispatch the params to backbones / necks + # backbone network: conv + transformer + assert backbone_type in ['convTransformer', 'conv'] + if backbone_type == 'convTransformer': + self.backbone = make_backbone( + 'convTransformer', + **{ + 'n_in' : input_dim, + 'n_embd' : embd_dim, + 'n_head': n_head, + 'n_embd_ks': embd_kernel_size, + 'max_len': max_seq_len, + 'arch' : backbone_arch, + 'mha_win_size': self.mha_win_size, + 'scale_factor' : scale_factor, + 'with_ln' : embd_with_ln, + 'attn_pdrop' : 0.0, + 'proj_pdrop' : self.train_dropout, + 'path_pdrop' : self.train_droppath, + 'use_abs_pe' : use_abs_pe, + 'use_rel_pe' : use_rel_pe + } + ) + else: + self.backbone = make_backbone( + 'conv', + **{ + 'n_in': input_dim, + 'n_embd': embd_dim, + 'n_embd_ks': embd_kernel_size, + 'arch': backbone_arch, + 'scale_factor': scale_factor, + 'with_ln' : embd_with_ln + } + ) + if isinstance(embd_dim, (list, tuple)): + embd_dim = sum(embd_dim) + + # fpn network: convs + assert fpn_type in ['fpn', 'identity'] + 
@property
def device(self):
    """Return one device on which this module's parameters live.

    The devices of all parameters are collected into a set and the first
    element of that set is returned. No check is made that the set holds a
    single device. Raises ``IndexError`` if the module has no parameters.
    """
    unique_devices = {param.device for param in self.parameters()}
    return list(unique_devices)[0]
@torch.no_grad()
def preprocessing(self, video_list, padding_val=0.0):
    """
    Batch the per-video features and build the matching validity masks.

    Args:
        video_list: list of dicts; each provides a 'feats' tensor whose last
            dimension is the temporal length (the first dimension of the
            first entry is used as the channel count).
        padding_val: value written into the padded time steps.

    Returns:
        (batched_inputs, batched_masks): features of shape (B, C, T_pad) and
        a boolean mask of shape (B, 1, T_pad) that is True on real (unpadded)
        time steps. Both are moved to ``self.device``.

    In training mode every sample is padded to ``self.max_seq_len``. In eval
    mode exactly one video is accepted; it is padded to ``self.max_seq_len``
    or, when longer, to the next multiple of ``self.max_div_factor``.
    """
    clips = [item['feats'] for item in video_list]
    feats_lens = torch.as_tensor([clip.shape[-1] for clip in clips])
    longest = feats_lens.max(0).values.item()

    if self.training:
        assert longest <= self.max_seq_len, "Input length must be smaller than max_seq_len during training"
        # always pad up to the configured training length
        padded_len = self.max_seq_len
        batched_inputs = clips[0].new_full(
            [len(clips), clips[0].shape[0], padded_len], padding_val
        )
        # copy each clip into the leading part of its padded row
        for row, clip in zip(batched_inputs, clips):
            row[..., :clip.shape[-1]].copy_(clip)
    else:
        assert len(video_list) == 1, "Only support batch_size = 1 during inference"
        if longest <= self.max_seq_len:
            # short input: pad to max_seq_len
            padded_len = self.max_seq_len
        else:
            # long input: round up to the next divisible size
            div = self.max_div_factor
            padded_len = (longest + (div - 1)) // div * div
        right_pad = [0, padded_len - feats_lens[0]]
        batched_inputs = F.pad(
            clips[0], right_pad, value=padding_val
        ).unsqueeze(0)

    # True where the time index is inside the original sequence
    batched_masks = torch.arange(padded_len)[None, :] < feats_lens[:, None]
    batched_masks = batched_masks.unsqueeze(1)

    return batched_inputs.to(self.device), batched_masks.to(self.device)
lengths of all segments -> F T x N + lens = gt_segment[:, 1] - gt_segment[:, 0] + lens = lens[None, :].repeat(num_pts, 1) + + # compute the distance of every point to each segment boundary + # auto broadcasting for all reg target-> F T x N x2 + gt_segs = gt_segment[None].expand(num_pts, num_gts, 2) + left = concat_points[:, 0, None] - gt_segs[:, :, 0] + right = gt_segs[:, :, 1] - concat_points[:, 0, None] + reg_targets = torch.stack((left, right), dim=-1) + + if self.train_center_sample == 'radius': + # center of all segments F T x N + center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1]) + # center sampling based on stride radius + # compute the new boundaries: + # concat_points[:, 3] stores the stride + t_mins = \ + center_pts - concat_points[:, 3, None] * self.train_center_sample_radius + t_maxs = \ + center_pts + concat_points[:, 3, None] * self.train_center_sample_radius + # prevent t_mins / maxs from over-running the action boundary + # left: torch.maximum(t_mins, gt_segs[:, :, 0]) + # right: torch.minimum(t_maxs, gt_segs[:, :, 1]) + # F T x N (distance to the new boundary) + cb_dist_left = concat_points[:, 0, None] \ + - torch.maximum(t_mins, gt_segs[:, :, 0]) + cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \ + - concat_points[:, 0, None] + # F T x N x 2 + center_seg = torch.stack( + (cb_dist_left, cb_dist_right), -1) + # F T x N + inside_gt_seg_mask = center_seg.min(-1)[0] > 0 + else: + # inside an gt action + inside_gt_seg_mask = reg_targets.min(-1)[0] > 0 + + # limit the regression range for each location + max_regress_distance = reg_targets.max(-1)[0] + # F T x N + inside_regress_range = torch.logical_and( + (max_regress_distance >= concat_points[:, 1, None]), + (max_regress_distance <= concat_points[:, 2, None]) + ) + + # if there are still more than one actions for one moment + # pick the one with the shortest duration (easiest to regress) + lens.masked_fill_(inside_gt_seg_mask==0, float('inf')) + lens.masked_fill_(inside_regress_range==0, 
def losses(
    self, fpn_masks,
    out_cls_logits, out_offsets,
    gt_cls_labels, gt_offsets
):
    """
    Compute the classification and regression losses for one mini-batch.

    Args:
        fpn_masks: per-level validity masks, concatenated along dim 1.
        out_cls_logits: per-level classification logits, concatenated along dim 1.
        out_offsets: per-level predicted offsets, concatenated along dim 1.
        gt_cls_labels: per-sample classification targets (stacked here).
        gt_offsets: per-sample regression targets (stacked here).

    Returns:
        dict with 'cls_loss', 'reg_loss' and 'final_loss'.

    Side effect: ``self.loss_normalizer`` is updated as a moving average of
    the number of positive locations (at least 1).
    """
    # valid locations over all levels
    valid = torch.cat(fpn_masks, dim=1)

    # a location is positive when it is valid and has any class target set
    cls_targets = torch.stack(gt_cls_labels)
    positives = torch.logical_and((cls_targets.sum(-1) > 0), valid)

    # regression inputs restricted to the positive locations
    pos_pred_offsets = torch.cat(out_offsets, dim=1)[positives]
    pos_gt_offsets = torch.stack(gt_offsets)[positives]

    # moving average of the positive count
    num_pos = positives.sum().item()
    momentum = self.loss_normalizer_momentum
    self.loss_normalizer = momentum * self.loss_normalizer + (
        1 - momentum
    ) * max(num_pos, 1)

    # targets of the valid locations, with optional label smoothing
    smoothed_target = cls_targets[valid]
    smoothed_target *= 1 - self.train_label_smoothing
    smoothed_target += self.train_label_smoothing / (self.num_classes + 1)

    # 1. classification: focal loss summed over valid locations
    valid_logits = torch.cat(out_cls_logits, dim=1)[valid]
    cls_loss = sigmoid_focal_loss(
        valid_logits,
        smoothed_target,
        reduction='sum'
    )
    cls_loss /= self.loss_normalizer

    # 2. regression: only defined on positive locations
    if num_pos != 0:
        reg_loss = ctr_diou_loss_1d(
            pos_pred_offsets,
            pos_gt_offsets,
            reduction='sum'
        )
        reg_loss /= self.loss_normalizer
    else:
        # zero-valued term built from the (empty) predictions
        reg_loss = 0 * pos_pred_offsets.sum()

    # fixed weight when positive, otherwise balance by the loss ratio
    if self.train_loss_weight > 0:
        loss_weight = self.train_loss_weight
    else:
        loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01)

    return {'cls_loss'   : cls_loss,
            'reg_loss'   : reg_loss,
            'final_loss' : cls_loss + reg_loss * loss_weight}
@torch.no_grad()
def inference_single_video(
    self,
    points,
    fpn_masks,
    out_cls_logits,
    out_offsets,
):
    """
    Decode the per-level outputs of one video into candidate segments.

    Args:
        points: per-level point tensors; column 0 is read as the point
            location and column 3 as the stride.
        fpn_masks: per-level validity masks, one entry per point.
        out_cls_logits: per-level class logits, one row per point.
        out_offsets: per-level (left, right) offsets, one row per point.

    Returns:
        dict with 'segments' (N x 2), 'scores' (N) and 'labels' (N),
        concatenated over all levels. Coordinates are still on the feature
        grid; no NMS is applied here.
    """
    level_segs, level_scores, level_labels = [], [], []

    for logits, offs, level_pts, valid in zip(
        out_cls_logits, out_offsets, points, fpn_masks
    ):
        # probabilities of masked-out points become zero
        probs = (logits.sigmoid() * valid.unsqueeze(-1)).flatten()

        # 1. drop everything at or below the pre-NMS score threshold
        above_thresh = probs > self.test_pre_nms_thresh
        probs = probs[above_thresh]
        flat_idxs = above_thresh.nonzero(as_tuple=True)[0]

        # 2. keep the k best-scoring candidates
        keep_k = min(self.test_pre_nms_topk, flat_idxs.size(0))
        probs, order = probs.sort(descending=True)
        probs = probs[:keep_k].clone()
        flat_idxs = flat_idxs[order[:keep_k]].clone()

        # flat index -> (point index, class index)
        loc_idxs = torch.div(
            flat_idxs, self.num_classes, rounding_mode='floor'
        )
        labels = torch.fmod(flat_idxs, self.num_classes)

        # 3. offsets / points of the surviving candidates
        sel_offsets = offs[loc_idxs]
        sel_pts = level_pts[loc_idxs]

        # 4. segment boundaries, offsets are scaled back by the stride
        centers, strides = sel_pts[:, 0], sel_pts[:, 3]
        starts = centers - sel_offsets[:, 0] * strides
        ends = centers + sel_offsets[:, 1] * strides

        # 5. drop segments not longer than the duration threshold
        long_enough = (ends - starts) > self.test_duration_thresh

        level_segs.append(torch.stack((starts, ends), -1)[long_enough])
        level_scores.append(probs[long_enough])
        level_labels.append(labels[long_enough])

    # merge the levels
    return {'segments' : torch.cat(level_segs),
            'scores'   : torch.cat(level_scores),
            'labels'   : torch.cat(level_labels)}
import os


def _registrar(registry, name):
    """Build a class decorator that stores the class in `registry` under `name`."""
    def decorator(cls):
        registry[name] = cls
        return cls
    return decorator


# backbone (e.g., conv / transformer)
backbones = {}
def register_backbone(name):
    """Class decorator: register a backbone class under `name`."""
    return _registrar(backbones, name)

# neck (e.g., FPN)
necks = {}
def register_neck(name):
    """Class decorator: register a neck class under `name`."""
    return _registrar(necks, name)

# location generator (point, segment, etc)
generators = {}
def register_generator(name):
    """Class decorator: register a location generator class under `name`."""
    return _registrar(generators, name)

# meta arch (the actual implementation of each model)
meta_archs = {}
def register_meta_arch(name):
    """Class decorator: register a meta architecture class under `name`."""
    return _registrar(meta_archs, name)

# builder functions: look the class up by name and instantiate it with kwargs
def make_backbone(name, **kwargs):
    """Instantiate the backbone registered as `name` (KeyError if unknown)."""
    return backbones[name](**kwargs)

def make_neck(name, **kwargs):
    """Instantiate the neck registered as `name` (KeyError if unknown)."""
    return necks[name](**kwargs)

def make_meta_arch(name, **kwargs):
    """Instantiate the meta architecture registered as `name` (KeyError if unknown)."""
    return meta_archs[name](**kwargs)

def make_generator(name, **kwargs):
    """Instantiate the generator registered as `name` (KeyError if unknown)."""
    return generators[name](**kwargs)
# downsampling rate between two fpn levels + start_level=0, # start fpn level + end_level=-1, # end fpn level + with_ln=True, # if to apply layer norm at the end + ): + super().__init__() + assert isinstance(in_channels, list) or isinstance(in_channels, tuple) + + self.in_channels = in_channels + self.out_channel = out_channel + self.scale_factor = scale_factor + + self.start_level = start_level + if end_level == -1: + self.end_level = len(in_channels) + else: + self.end_level = end_level + assert self.end_level <= len(in_channels) + assert (self.start_level >= 0) and (self.start_level < self.end_level) + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + self.fpn_norms = nn.ModuleList() + for i in range(self.start_level, self.end_level): + # disable bias if using layer norm + l_conv = MaskedConv1D( + in_channels[i], out_channel, 1, bias=(not with_ln) + ) + # use depthwise conv here for efficiency + fpn_conv = MaskedConv1D( + out_channel, out_channel, 3, + padding=1, bias=(not with_ln), groups=out_channel + ) + # layer norm for order (B C T) + if with_ln: + fpn_norm = LayerNorm(out_channel) + else: + fpn_norm = nn.Identity() + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + self.fpn_norms.append(fpn_norm) + + def forward(self, inputs, fpn_masks): + # inputs must be a list / tuple + assert len(inputs) == len(self.in_channels) + assert len(fpn_masks) == len(self.in_channels) + + # build laterals, fpn_masks will remain the same with 1x1 convs + laterals = [] + for i in range(len(self.lateral_convs)): + x, _ = self.lateral_convs[i]( + inputs[i + self.start_level], fpn_masks[i + self.start_level] + ) + laterals.append(x) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + laterals[i - 1] += F.interpolate( + laterals[i], scale_factor=self.scale_factor, mode='nearest' + ) + + # fpn conv / norm -> outputs + # mask will remain the same + fpn_feats = tuple() + 
@register_neck('identity')
class FPNIdentity(nn.Module):
    """
    Pass-through neck: selects the levels in [start_level, end_level) and
    applies an optional LayerNorm to each, leaving the masks untouched.
    """
    def __init__(
        self,
        in_channels,      # input feature channels, len(in_channels) = #levels
        out_channel,      # output feature channel
        scale_factor=2.0, # downsampling rate between two fpn levels
        start_level=0,    # start fpn level
        end_level=-1,     # end fpn level
        with_ln=True,     # if to apply layer norm at the end
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channel = out_channel
        self.scale_factor = scale_factor

        # -1 means "up to the last input level"
        self.start_level = start_level
        self.end_level = len(in_channels) if end_level == -1 else end_level
        assert self.end_level <= len(in_channels)
        assert 0 <= self.start_level < self.end_level

        self.fpn_norms = nn.ModuleList()
        for level_channels in self.in_channels[self.start_level:self.end_level]:
            # no projection here, so the channel count must already match
            assert level_channels == self.out_channel
            # layer norm for order (B C T)
            self.fpn_norms.append(
                LayerNorm(out_channel) if with_ln else nn.Identity()
            )

    def forward(self, inputs, fpn_masks):
        """Return (features, masks) tuples for the selected levels."""
        # one input / mask per backbone level is required
        assert len(inputs) == len(self.in_channels)
        assert len(fpn_masks) == len(self.in_channels)

        # norm i is applied to input level (start_level + i); masks are passed through
        levels = range(self.start_level, self.start_level + len(self.fpn_norms))
        fpn_feats = tuple(
            norm(inputs[lvl]) for norm, lvl in zip(self.fpn_norms, levels)
        )
        new_fpn_masks = tuple(fpn_masks[lvl] for lvl in levels)

        return fpn_feats, new_fpn_masks
+++ b/code/actionformer_release/libs/modeling/weight_init.py @@ -0,0 +1,61 @@ +# from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py +import torch +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/code/actionformer_release/libs/utils/__init__.py b/code/actionformer_release/libs/utils/__init__.py new file mode 100644 index 0000000..00f24e8 --- /dev/null +++ b/code/actionformer_release/libs/utils/__init__.py @@ -0,0 +1,10 @@ +from .nms import batched_nms +from .metrics import ANETdetection, remove_duplicate_annotations +from .train_utils import (make_optimizer, make_scheduler, save_checkpoint, + AverageMeter, train_one_epoch, valid_one_epoch, + fix_random_seed, ModelEma) +from .postprocessing import postprocess_results + +__all__ = ['batched_nms', 'make_optimizer', 'make_scheduler', 'save_checkpoint', + 'AverageMeter', 'train_one_epoch', 'valid_one_epoch', 'ANETdetection', + 'postprocess_results', 'fix_random_seed', 'ModelEma', 'remove_duplicate_annotations'] diff --git a/code/actionformer_release/libs/utils/csrc/nms_cpu.cpp b/code/actionformer_release/libs/utils/csrc/nms_cpu.cpp new file mode 100644 index 0000000..d8e8cad --- /dev/null +++ b/code/actionformer_release/libs/utils/csrc/nms_cpu.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include + +// 1D NMS (CPU) helper functions, ported from +// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/nms.cpp + +using namespace at; + +#define CHECK_CPU(x) \ + TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CPU_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) + +Tensor nms_1d_cpu(Tensor segs, Tensor scores, float iou_threshold) { + if (segs.numel() == 0) { + return at::empty({0}, segs.options().dtype(at::kLong)); + } + 
auto x1_t = segs.select(1, 0).contiguous(); + auto x2_t = segs.select(1, 1).contiguous(); + + Tensor areas_t = x2_t - x1_t + 1e-6; + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + + auto nsegs = segs.size(0); + Tensor select_t = at::ones({nsegs}, segs.options().dtype(at::kBool)); + + auto select = select_t.data_ptr(); + auto order = order_t.data_ptr(); + auto x1 = x1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto areas = areas_t.data_ptr(); + + for (int64_t _i = 0; _i < nsegs; _i++) { + if (select[_i] == false) continue; + auto i = order[_i]; + auto ix1 = x1[i]; + auto ix2 = x2[i]; + auto iarea = areas[i]; + + for (int64_t _j = _i + 1; _j < nsegs; _j++) { + if (select[_j] == false) continue; + auto j = order[_j]; + auto xx1 = std::max(ix1, x1[j]); + auto xx2 = std::min(ix2, x2[j]); + + auto inter = std::max(0.f, xx2 - xx1); + auto ovr = inter / (iarea + areas[j] - inter); + if (ovr >= iou_threshold) select[_j] = false; + } + } + return order_t.masked_select(select_t); +} + +Tensor nms_1d(Tensor segs, Tensor scores, float iou_threshold) { + CHECK_CPU_INPUT(segs); + CHECK_CPU_INPUT(scores); + return nms_1d_cpu(segs, scores, iou_threshold); + +} + +Tensor softnms_1d_cpu(Tensor segs, Tensor scores, Tensor dets, float iou_threshold, + float sigma, float min_score, int method) { + if (segs.numel() == 0) { + return at::empty({0}, segs.options().dtype(at::kLong)); + } + + auto x1_t = segs.select(1, 0).contiguous(); + auto x2_t = segs.select(1, 1).contiguous(); + auto scores_t = scores.clone(); + + Tensor areas_t = x2_t - x1_t + 1e-6; + + auto nsegs = segs.size(0); + auto x1 = x1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto sc = scores_t.data_ptr(); + auto areas = areas_t.data_ptr(); + auto de = dets.data_ptr(); + + int64_t pos = 0; + Tensor inds_t = at::arange(nsegs, segs.options().dtype(at::kLong)); + auto inds = inds_t.data_ptr(); + + for (int64_t i = 0; i < nsegs; i++) { + auto max_score = sc[i]; + auto max_pos = i; + + // get seg with max 
score + pos = i + 1; + while (pos < nsegs) { + if (max_score < sc[pos]) { + max_score = sc[pos]; + max_pos = pos; + } + pos = pos + 1; + } + // swap the current seg (i) and the seg with max score (max_pos) + auto ix1 = de[i * 3 + 0] = x1[max_pos]; + auto ix2 = de[i * 3 + 1] = x2[max_pos]; + auto iscore = de[i * 3 + 2] = sc[max_pos]; + auto iarea = areas[max_pos]; + auto iind = inds[max_pos]; + + x1[max_pos] = x1[i]; + x2[max_pos] = x2[i]; + sc[max_pos] = sc[i]; + areas[max_pos] = areas[i]; + inds[max_pos] = inds[i]; + + x1[i] = ix1; + x2[i] = ix2; + sc[i] = iscore; + areas[i] = iarea; + inds[i] = iind; + + // reset pos + pos = i + 1; + while (pos < nsegs) { + auto xx1 = std::max(ix1, x1[pos]); + auto xx2 = std::min(ix2, x2[pos]); + + auto inter = std::max(0.f, xx2 - xx1); + auto ovr = inter / (iarea + areas[pos] - inter); + + float weight = 1.; + if (method == 0) { + // vanilla nms + if (ovr >= iou_threshold) weight = 0; + } else if (method == 1) { + // linear + if (ovr >= iou_threshold) weight = 1 - ovr; + } else if (method == 2) { + // gaussian + weight = std::exp(-(ovr * ovr) / sigma); + } + sc[pos] *= weight; + + // if the score falls below threshold, discard the segment by + // swapping with last seg update N + if (sc[pos] < min_score) { + x1[pos] = x1[nsegs - 1]; + x2[pos] = x2[nsegs - 1]; + sc[pos] = sc[nsegs - 1]; + areas[pos] = areas[nsegs - 1]; + inds[pos] = inds[nsegs - 1]; + nsegs = nsegs - 1; + pos = pos - 1; + } + + pos = pos + 1; + } + } + return inds_t.slice(0, 0, nsegs); +} + +Tensor softnms_1d(Tensor segs, Tensor scores, Tensor dets, float iou_threshold, + float sigma, float min_score, int method) { + // softnms is not implemented on GPU + CHECK_CPU_INPUT(segs) + CHECK_CPU_INPUT(scores) + CHECK_CPU_INPUT(dets) + return softnms_1d_cpu(segs, scores, dets, iou_threshold, sigma, min_score, method); +} + +// bind to torch interface +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "nms", &nms_1d, "nms (CPU) ", + py::arg("segs"), py::arg("scores"), 
py::arg("iou_threshold") + ); + m.def( + "softnms", &softnms_1d, "softnms (CPU) ", + py::arg("segs"), py::arg("scores"), py::arg("dets"), py::arg("iou_threshold"), + py::arg("sigma"), py::arg("min_score"), py::arg("method") + ); +} diff --git a/code/actionformer_release/libs/utils/lr_schedulers.py b/code/actionformer_release/libs/utils/lr_schedulers.py new file mode 100644 index 0000000..3290fd8 --- /dev/null +++ b/code/actionformer_release/libs/utils/lr_schedulers.py @@ -0,0 +1,211 @@ +import math +import warnings +from collections import Counter +from bisect import bisect_right + +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + """ + Sets the learning rate of each parameter group to follow a linear warmup schedule + between warmup_start_lr and base_lr followed by a cosine annealing schedule between + base_lr and eta_min. + + .. warning:: + It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` + after each iteration as calling it after each epoch will keep the starting lr at + warmup_start_lr for the first epoch which is 0 in most cases. + + .. warning:: + passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. + It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of + :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing + epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling + train and validation methods. + + Example: + >>> layer = nn.Linear(10, 1) + >>> optimizer = Adam(layer.parameters(), lr=0.02) + >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) + >>> # + >>> # the default case + >>> for epoch in range(40): + ... # train(...) + ... # validate(...) + ... scheduler.step() + >>> # + >>> # passing epoch param case + >>> for epoch in range(40): + ... 
scheduler.step(epoch) + ... # train(...) + ... # validate(...) + """ + + def __init__( + self, + optimizer, + warmup_epochs, + max_epochs, + warmup_start_lr = 0.0, + eta_min = 1e-8, + last_epoch = -1, + ): + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + """ + Compute learning rate using chainable form of the scheduler + """ + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + elif self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif self.last_epoch == self.warmup_epochs: + return self.base_lrs + elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) / + ( + 1 + + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)) + ) * (group["lr"] - self.eta_min) + 
class LinearWarmupMultiStepLR(_LRScheduler):
    """
    Linear warmup from ``warmup_start_lr`` to each group's base lr, followed by
    a multi-step schedule that multiplies the lr by ``gamma`` every time the
    number of post-warmup steps reaches one of the milestones.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupMultiStepLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.

    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an
        EPOCH_DEPRECATION_WARNING. It calls :func:`_get_closed_form_lr()` for this
        scheduler instead of :func:`get_lr()`. When passing the epoch param to
        :func:`.step()`, call :func:`.step()` before the train and validation methods.
    """

    def __init__(
        self,
        optimizer,
        warmup_epochs,
        milestones,
        warmup_start_lr = 0.0,
        gamma = 0.1,
        last_epoch = -1,
    ):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Number of scheduler steps used for linear warmup.
            milestones (list): Step indices, counted from the end of warmup,
                at which the lr is decayed. Must be increasing.
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            gamma (float): Multiplicative factor of learning rate decay.
                Default: 0.1.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        # a Counter so that a repeated milestone decays the lr more than once
        self.milestones = Counter(milestones)
        self.gamma = gamma

        super(LinearWarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def _warmup_span(self):
        """Number of warmup increments; never 0 so the ramp cannot divide by zero."""
        return max(1, self.warmup_epochs - 1)

    def get_lr(self):
        """
        Compute learning rate using chainable form of the scheduler
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            # starting warm up
            return [self.warmup_start_lr] * len(self.base_lrs)
        elif self.last_epoch < self.warmup_epochs:
            # linear warm up (0 ~ self.warmup_epochs - 1): add one constant increment
            span = self._warmup_span()
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / span
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        elif self.last_epoch == self.warmup_epochs:
            # end of warm up (reset to base lrs)
            return self.base_lrs
        elif (self.last_epoch - self.warmup_epochs) not in self.milestones:
            # in between the steps
            return [group['lr'] for group in self.optimizer.param_groups]

        # on a milestone: decay once per occurrence of that milestone
        return [
            group['lr'] * self.gamma ** self.milestones[self.last_epoch - self.warmup_epochs]
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        """
        Called when epoch is passed as a param to the `step` function of the scheduler.
        """
        if self.last_epoch < self.warmup_epochs:
            # the original divided by (warmup_epochs - 1), which is 0 for warmup_epochs == 1
            span = self._warmup_span()
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / span
                for base_lr in self.base_lrs
            ]

        milestones = list(sorted(self.milestones.elements()))
        return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch - self.warmup_epochs)
                for base_lr in self.base_lrs]
def _encode_label_id(raw_label, label_offset):
    """Collapse a scalar or hierarchical (tuple/list) label into one integer id."""
    if isinstance(raw_label, (tuple, list)):
        # positional encoding, last element least significant. The original
        # used `label_offset**i + int(x)`, which maps e.g. (1, 2) and (2, 1)
        # to the same id.
        label_id = 0
        for power, part in enumerate(raw_label[::-1]):
            label_id += (label_offset ** power) * int(part)
        return label_id
    # load label_id directly
    return int(raw_label)


def remove_duplicate_annotations(ants, tol=1e-3):
    """Drop annotations shorter than ``tol`` and near-duplicate annotations.

    Two annotations are duplicates when their start and end times both differ
    by at most ``tol`` and their ``label_id`` is equal; the first one is kept.

    Args:
        ants (list[dict]): annotations with ``'segment'`` ([start, end]) and
            ``'label_id'`` entries.
        tol (float): tolerance for both the minimum duration and the match.

    Returns:
        list[dict]: the retained annotations, in their original order.
    """
    valid_events = []
    for event in ants:
        s, e, l = event['segment'][0], event['segment'][1], event['label_id']
        if (e - s) < tol:
            # very short annotation
            continue
        is_duplicate = any(
            abs(s - kept['segment'][0]) <= tol
            and abs(e - kept['segment'][1]) <= tol
            and l == kept['label_id']
            for kept in valid_events
        )
        if not is_duplicate:
            valid_events.append(event)
    return valid_events


def load_gt_seg_from_json(json_file, split=None, label='label_id', label_offset=0):
    """Load ground-truth segments from an annotation json into a DataFrame.

    Args:
        json_file (str): path to a json file with a top-level ``'database'`` dict
            mapping video id -> {``'subset'``, ``'annotations'``}.
        split (str | None): keep only videos whose lower-cased ``'subset'`` equals
            this value; ``None`` keeps every video.
        label (str): annotation key holding the label (int or tuple/list of ints).
        label_offset (int): base used to combine tuple/list labels into one id.

    Returns:
        pd.DataFrame: columns ``'video-id'``, ``'t-start'``, ``'t-end'``, ``'label'``.
    """
    # load json file
    with open(json_file, "r", encoding="utf8") as f:
        json_db = json.load(f)
    json_db = json_db['database']

    vids, starts, stops, labels = [], [], [], []
    for k, v in json_db.items():
        # filter based on split
        if (split is not None) and v['subset'].lower() != split:
            continue
        # remove duplicated instances
        ants = remove_duplicate_annotations(v['annotations'])
        # video id
        vids += [k] * len(ants)
        # for each event, grab the start/end time and label
        for event in ants:
            starts.append(float(event['segment'][0]))
            stops.append(float(event['segment'][1]))
            labels.append(_encode_label_id(event[label], label_offset))

    # move to pd dataframe
    gt_base = pd.DataFrame({
        'video-id' : vids,
        't-start' : starts,
        't-end': stops,
        'label': labels
    })

    return gt_base


def load_pred_seg_from_json(json_file, label='label_id', label_offset=0):
    """Load predicted segments from a json file into a DataFrame.

    Args:
        json_file (str): path to a json file with a top-level ``'database'`` dict
            mapping video id -> list of events with ``'segment'``, the label key
            and ``'scores'``.
        label (str): event key holding the label (int or tuple/list of ints).
        label_offset (int): base used to combine tuple/list labels into one id.

    Returns:
        pd.DataFrame: columns ``'video-id'``, ``'t-start'``, ``'t-end'``,
        ``'label'``, ``'score'``.
    """
    # load json file
    with open(json_file, "r", encoding="utf8") as f:
        json_db = json.load(f)
    json_db = json_db['database']

    vids, starts, stops, labels, scores = [], [], [], [], []
    for k, v in json_db.items():
        # video id
        vids += [k] * len(v)
        # for each event
        for event in v:
            starts.append(float(event['segment'][0]))
            stops.append(float(event['segment'][1]))
            labels.append(_encode_label_id(event[label], label_offset))
            scores.append(float(event['scores']))

    # move to pd dataframe
    pred_base = pd.DataFrame({
        'video-id' : vids,
        't-start' : starts,
        't-end': stops,
        'label': labels,
        'score': scores
    })

    return pred_base
class ANETdetection(object):
    """Adapted from https://github.com/activitynet/ActivityNet/blob/master/Evaluation/eval_detection.py

    Evaluates temporal action detections against a ground-truth annotation
    file: per-class average precision and top-kx recall at several tIoU
    thresholds.
    """

    def __init__(
        self,
        ant_file,
        split=None,
        tiou_thresholds=np.linspace(0.1, 0.5, 5),
        top_k=(1, 5),
        label='label_id',
        label_offset=0,
        num_workers=8,
        dataset_name=None,
    ):
        """
        Args:
            ant_file (str): path to the ground-truth json file.
            split (str | None): subset name forwarded to the ground-truth loader.
            tiou_thresholds (1d array): tIoU thresholds to evaluate at.
            top_k (tuple): the k values of the top-kx recall.
            label (str): annotation key holding the label.
            label_offset (int): forwarded to the ground-truth loader.
            num_workers (int): number of joblib workers used per evaluation.
            dataset_name (str | None): name printed in the report; defaults to
                the annotation file name without ``.json``.
        """
        self.tiou_thresholds = tiou_thresholds
        self.top_k = top_k
        self.ap = None
        # set by evaluate(); initialised here so the attribute always exists
        self.recall = None
        self.num_workers = num_workers
        if dataset_name is not None:
            self.dataset_name = dataset_name
        else:
            self.dataset_name = os.path.basename(ant_file).replace('.json', '')

        # Import ground truth and predictions
        self.split = split
        self.ground_truth = load_gt_seg_from_json(
            ant_file, split=self.split, label=label, label_offset=label_offset)

        # map the labels present in gt to contiguous class indices
        self.activity_index = {j: i for i, j in enumerate(sorted(self.ground_truth['label'].unique()))}
        self.ground_truth['label'] = self.ground_truth['label'].replace(self.activity_index)

    def _get_predictions_with_label(self, prediction_by_label, label_name, cidx):
        """Get all predictions of the given label. Return an empty DataFrame if
        there are no predictions with the given label.
        """
        try:
            res = prediction_by_label.get_group(cidx).reset_index(drop=True)
            return res
        except KeyError:
            # only "no such group" is expected here; anything else should surface
            print('Warning: No predictions of label \'%s\' were provdied.' % label_name)
            return pd.DataFrame()

    def wrapper_compute_average_precision(self, preds):
        """Computes average precision for each class in the subset.

        Returns an array of shape (len(tiou_thresholds), number of classes).
        """
        ap = np.zeros((len(self.tiou_thresholds), len(self.activity_index)))

        # Adaptation to query faster
        ground_truth_by_label = self.ground_truth.groupby('label')
        prediction_by_label = preds.groupby('label')

        results = Parallel(n_jobs=self.num_workers)(
            delayed(compute_average_precision_detection)(
                ground_truth=ground_truth_by_label.get_group(cidx).reset_index(drop=True),
                prediction=self._get_predictions_with_label(prediction_by_label, label_name, cidx),
                tiou_thresholds=self.tiou_thresholds,
            ) for label_name, cidx in self.activity_index.items())

        for i, cidx in enumerate(self.activity_index.values()):
            ap[:,cidx] = results[i]

        return ap

    def wrapper_compute_topkx_recall(self, preds):
        """Computes Top-kx recall for each class in the subset.

        Returns an array of shape
        (len(tiou_thresholds), len(top_k), number of classes).
        """
        recall = np.zeros((len(self.tiou_thresholds), len(self.top_k), len(self.activity_index)))

        # Adaptation to query faster
        ground_truth_by_label = self.ground_truth.groupby('label')
        prediction_by_label = preds.groupby('label')

        results = Parallel(n_jobs=self.num_workers)(
            delayed(compute_topkx_recall_detection)(
                ground_truth=ground_truth_by_label.get_group(cidx).reset_index(drop=True),
                prediction=self._get_predictions_with_label(prediction_by_label, label_name, cidx),
                tiou_thresholds=self.tiou_thresholds,
                top_k=self.top_k,
            ) for label_name, cidx in self.activity_index.items())

        for i, cidx in enumerate(self.activity_index.values()):
            recall[...,cidx] = results[i]

        return recall

    def evaluate(self, preds, verbose=True):
        """Evaluates a prediction file. For the detection task we measure the
        interpolated mean average precision to measure the performance of a
        method.
        preds can be (1) a pd.DataFrame; or (2) a json file where the data will be loaded;
        or (3) a python dict item with numpy arrays as the values

        Returns:
            tuple: (mAP per tIoU threshold, average mAP, mean recall per
            tIoU threshold and k).

        Raises:
            FileNotFoundError: preds is a string that is not an existing file.
            TypeError: preds is none of the supported types.
        """
        if isinstance(preds, pd.DataFrame):
            assert 'label' in preds
            # work on a copy so the label remapping below does not leak into
            # the caller's DataFrame
            preds = preds.copy()
        elif isinstance(preds, str):
            if not os.path.isfile(preds):
                raise FileNotFoundError('Prediction file not found: %s' % preds)
            preds = load_pred_seg_from_json(preds)
        elif isinstance(preds, dict):
            # move to pd dataframe
            # did not check dtype here, can accept both numpy / pytorch tensors
            preds = pd.DataFrame({
                'video-id' : preds['video-id'],
                't-start' : preds['t-start'].tolist(),
                't-end': preds['t-end'].tolist(),
                'label': preds['label'].tolist(),
                'score': preds['score'].tolist()
            })
        else:
            raise TypeError('Unsupported type for preds: %s' % type(preds).__name__)
        # always reset ap
        self.ap = None

        # make the label ids consistent
        # NOTE(review): labels absent from the ground truth are left unchanged by
        # replace() and could coincide with a remapped class index - confirm
        # callers only emit labels that exist in the ground truth.
        preds['label'] = preds['label'].replace(self.activity_index)

        # compute mAP
        self.ap = self.wrapper_compute_average_precision(preds)
        self.recall = self.wrapper_compute_topkx_recall(preds)
        mAP = self.ap.mean(axis=1)
        mRecall = self.recall.mean(axis=2)
        average_mAP = mAP.mean()

        # print results
        if verbose:
            # print the results
            print('[RESULTS] Action detection results on {:s}.'.format(
                self.dataset_name)
            )
            block = ''
            for tiou, tiou_mAP, tiou_mRecall in zip(self.tiou_thresholds, mAP, mRecall):
                block += '\n|tIoU = {:.2f}: '.format(tiou)
                block += 'mAP = {:>4.2f} (%) '.format(tiou_mAP*100)
                for idx, k in enumerate(self.top_k):
                    block += 'Recall@{:d}x = {:>4.2f} (%) '.format(k, tiou_mRecall[idx]*100)
            print(block)
            print('Average mAP: {:>4.2f} (%)'.format(average_mAP*100))

        # return the results
        return mAP, average_mAP, mRecall
def compute_average_precision_detection(
    ground_truth,
    prediction,
    tiou_thresholds=np.linspace(0.1, 0.5, 5)
):
    """Compute average precision (detection task) between ground truth and
    predictions data frames. If multiple predictions match the same
    ground-truth segment, only the one with the highest score is counted as
    a true positive. This code is greatly inspired by Pascal VOC devkit.
    Parameters
    ----------
    ground_truth : df
        Data frame containing the ground truth instances.
        Required fields: ['video-id', 't-start', 't-end']
    prediction : df
        Data frame containing the prediction instances.
        Required fields: ['video-id', 't-start', 't-end', 'score']
    tiou_thresholds : 1darray, optional
        Temporal intersection over union threshold.
    Outputs
    -------
    ap : 1darray
        Average precision score, one entry per tIoU threshold
        (all zeros when `prediction` is empty).
    """
    ap = np.zeros(len(tiou_thresholds))
    if prediction.empty:
        return ap

    npos = float(len(ground_truth))
    # lock_gt[t, g] holds the index of the prediction matched to ground truth g
    # at threshold t, or -1 while g is still unmatched
    lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1
    # Sort predictions by decreasing score order.
    # NOTE(review): .loc with positional sort indices assumes `prediction` has a
    # default 0..N-1 index - confirm callers pass a reset index.
    sort_idx = prediction['score'].values.argsort()[::-1]
    prediction = prediction.loc[sort_idx].reset_index(drop=True)

    # Initialize true positive and false positive vectors.
    tp = np.zeros((len(tiou_thresholds), len(prediction)))
    fp = np.zeros((len(tiou_thresholds), len(prediction)))

    # Adaptation to query faster
    ground_truth_gbvn = ground_truth.groupby('video-id')

    # Assigning true positive to truly ground truth instances.
    for idx, this_pred in prediction.iterrows():

        try:
            # Check if there is at least one ground truth in the video associated.
            ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id'])
        except Exception as e:
            # no ground truth for this video: false positive at every threshold
            fp[:, idx] = 1
            continue

        # reset_index() keeps the old index in an 'index' column, which is used
        # below to address the columns of lock_gt
        this_gt = ground_truth_videoid.reset_index()
        tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values,
                               this_gt[['t-start', 't-end']].values)
        # We would like to retrieve the predictions with highest tiou score.
        tiou_sorted_idx = tiou_arr.argsort()[::-1]
        for tidx, tiou_thr in enumerate(tiou_thresholds):
            for jdx in tiou_sorted_idx:
                if tiou_arr[jdx] < tiou_thr:
                    # candidates are sorted by tIoU, so none of the rest can match
                    fp[tidx, idx] = 1
                    break
                if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0:
                    # this ground truth is already taken by a higher-scored prediction
                    continue
                # Assign as true positive after the filters above.
                tp[tidx, idx] = 1
                lock_gt[tidx, this_gt.loc[jdx]['index']] = idx
                break

            # every overlapping ground truth was already matched
            if fp[tidx, idx] == 0 and tp[tidx, idx] == 0:
                fp[tidx, idx] = 1

    tp_cumsum = np.cumsum(tp, axis=1).astype(float)
    fp_cumsum = np.cumsum(fp, axis=1).astype(float)
    recall_cumsum = tp_cumsum / npos

    precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)

    for tidx in range(len(tiou_thresholds)):
        ap[tidx] = interpolated_prec_rec(precision_cumsum[tidx,:], recall_cumsum[tidx,:])

    return ap
def compute_topkx_recall_detection(
    ground_truth,
    prediction,
    tiou_thresholds=np.linspace(0.1, 0.5, 5),
    top_k=(1, 5),
):
    """Compute top-kx recall (detection task) between ground truth and
    predictions data frames. A ground-truth instance counts as recalled at
    (threshold, k) when any of the k * (ground truths in that video)
    highest-scored predictions of the video reaches the tIoU threshold.
    Parameters
    ----------
    ground_truth : df
        Data frame containing the ground truth instances.
        Required fields: ['video-id', 't-start', 't-end']
    prediction : df
        Data frame containing the prediction instances.
        Required fields: ['video-id', 't-start', 't-end', 'score']
    tiou_thresholds : 1darray, optional
        Temporal intersection over union threshold.
    top_k: tuple, optional
        Top-kx results of a action category where x stands for the number of
        instances for the action category in the video.
    Outputs
    -------
    recall : 2darray
        Recall score, shape (len(tiou_thresholds), len(top_k)).
    """
    if prediction.empty:
        return np.zeros((len(tiou_thresholds), len(top_k)))

    # Initialize true positive vectors.
    tp = np.zeros((len(tiou_thresholds), len(top_k)))
    n_gts = 0

    # Adaptation to query faster
    ground_truth_gbvn = ground_truth.groupby('video-id')
    prediction_gbvn = prediction.groupby('video-id')

    for videoid, _ in ground_truth_gbvn.groups.items():
        ground_truth_videoid = ground_truth_gbvn.get_group(videoid)
        # counted before the lookup below, so videos without any prediction
        # still contribute to the denominator
        n_gts += len(ground_truth_videoid)
        try:
            prediction_videoid = prediction_gbvn.get_group(videoid)
        except Exception as e:
            continue

        this_gt = ground_truth_videoid.reset_index()
        this_pred = prediction_videoid.reset_index()

        # Sort predictions by decreasing score order.
        score_sort_idx = this_pred['score'].values.argsort()[::-1]
        top_kx_idx = score_sort_idx[:max(top_k) * len(this_gt)]
        # rows: kept predictions (best first), columns: ground truths of the video
        tiou_arr = k_segment_iou(this_pred[['t-start', 't-end']].values[top_kx_idx],
                                 this_gt[['t-start', 't-end']].values)

        for tidx, tiou_thr in enumerate(tiou_thresholds):
            for kidx, k in enumerate(top_k):
                tiou = tiou_arr[:k * len(this_gt)]
                # number of ground truths hit by at least one kept prediction
                tp[tidx, kidx] += ((tiou >= tiou_thr).sum(axis=0) > 0).sum()

    recall = tp / n_gts

    return recall


def k_segment_iou(target_segments, candidate_segments):
    """Stack segment_iou() of every target segment against all candidates.

    The result has one row per target segment and one column per candidate.
    """
    return np.stack(
        [segment_iou(target_segment, candidate_segments) \
            for target_segment in target_segments]
    )


def segment_iou(target_segment, candidate_segments):
    """Compute the temporal intersection over union between a
    target segment and all the test segments.
    Parameters
    ----------
    target_segment : 1d array
        Temporal target segment containing [starting, ending] times.
    candidate_segments : 2d array
        Temporal candidate segments containing N x [starting, ending] times.
    Outputs
    -------
    tiou : 1d array
        Temporal intersection over union score of the N's candidate segments.
    """
    tt1 = np.maximum(target_segment[0], candidate_segments[:, 0])
    tt2 = np.minimum(target_segment[1], candidate_segments[:, 1])
    # Intersection including Non-negative overlap score.
    segments_intersection = (tt2 - tt1).clip(0)
    # Segment union.
    segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \
      + (target_segment[1] - target_segment[0]) - segments_intersection
    # Compute overlap as the ratio of the intersection
    # over union of two segments.
    tIoU = segments_intersection.astype(float) / segments_union
    return tIoU


def interpolated_prec_rec(prec, rec):
    """Interpolated AP - VOCdevkit from VOC 2011.

    `prec` and `rec` are the cumulative precision / recall values; the return
    value is the area under the precision envelope.
    """
    # pad both curves with sentinel end points
    mprec = np.hstack([[0], prec, [0]])
    mrec = np.hstack([[0], rec, [1]])
    # make precision non-increasing from left to right
    for i in range(len(mprec) - 1)[::-1]:
        mprec[i] = max(mprec[i], mprec[i + 1])
    # positions where recall changes
    idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
    ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx])
    return ap
class NMSop(torch.autograd.Function):
    """Hard 1D NMS backed by the compiled ``nms_1d_cpu`` extension."""

    @staticmethod
    def forward(
        ctx, segs, scores, cls_idxs,
        iou_threshold, min_score, max_num
    ):
        """Filter by score, run NMS and return the kept segs / scores / class ids.

        The score filter is applied only when ``min_score > 0`` and the cap
        only when ``max_num > 0``.
        """
        # vanilla nms will not change the score, so we can filter segs first
        if min_score > 0:
            valid_mask = scores > min_score
            segs, scores = segs[valid_mask], scores[valid_mask]
            cls_idxs = cls_idxs[valid_mask]
            # (the original also built `valid_inds` here but never used it)

        # nms op; return inds that is sorted by descending order
        inds = nms_1d_cpu.nms(
            segs.contiguous().cpu(),
            scores.contiguous().cpu(),
            iou_threshold=float(iou_threshold))
        # cap by max number
        if max_num > 0:
            inds = inds[:min(max_num, len(inds))]
        # return the sorted segs / scores
        sorted_segs = segs[inds]
        sorted_scores = scores[inds]
        sorted_cls_idxs = cls_idxs[inds]
        return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()


class SoftNMSop(torch.autograd.Function):
    """Soft 1D NMS backed by the compiled ``nms_1d_cpu`` extension."""

    @staticmethod
    def forward(
        ctx, segs, scores, cls_idxs,
        iou_threshold, sigma, min_score, method, max_num
    ):
        """Run soft-NMS and return the kept segs / rescored scores / class ids.

        The cap is applied only when ``max_num > 0``.
        """
        # pre allocate memory for sorted results; already on the CPU, so it is
        # passed to the op as-is (the original's extra .cpu() was a no-op)
        dets = segs.new_empty((segs.size(0), 3), device='cpu')
        # softnms op, return dets that stores the sorted segs / scores
        inds = nms_1d_cpu.softnms(
            segs.cpu(),
            scores.cpu(),
            dets,
            iou_threshold=float(iou_threshold),
            sigma=float(sigma),
            min_score=float(min_score),
            method=int(method))
        # cap by max number
        if max_num > 0:
            n_segs = min(len(inds), max_num)
        else:
            n_segs = len(inds)
        # columns 0-1 of dets are the segment, column 2 the score
        sorted_segs = dets[:n_segs, :2]
        sorted_scores = dets[:n_segs, 2]
        sorted_cls_idxs = cls_idxs[inds]
        sorted_cls_idxs = sorted_cls_idxs[:n_segs]
        return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()
def seg_voting(nms_segs, all_segs, all_scores, iou_threshold, score_offset=1.5):
    """
    blur localization results by incorporating side segs.
    this is known as bounding box voting in object detection literature.
    slightly boost the performance around iou_threshold

    Args:
        nms_segs (Tensor): N_nms x 2 segments kept by NMS.
        all_segs (Tensor): N_all x 2 candidate segments that vote.
        all_scores (Tensor): N_all scores of the candidates.
        iou_threshold (float): candidates with IoU >= this value vote.
        score_offset (float): accepted for interface compatibility; the vote
            weights are computed from the raw scores.

    Returns:
        Tensor: N_nms x 2 refined segments. A segment that receives no votes
        (zero total weight) is returned unchanged instead of becoming NaN.
    """
    # computer overlap between nms and all segs
    # construct the distance matrix of # N_nms x # N_all
    num_nms_segs, num_all_segs = nms_segs.shape[0], all_segs.shape[0]
    ex_nms_segs = nms_segs[:, None].expand(num_nms_segs, num_all_segs, 2)
    ex_all_segs = all_segs[None, :].expand(num_nms_segs, num_all_segs, 2)

    # compute intersection
    left = torch.maximum(ex_nms_segs[:, :, 0], ex_all_segs[:, :, 0])
    right = torch.minimum(ex_nms_segs[:, :, 1], ex_all_segs[:, :, 1])
    inter = (right-left).clamp(min=0)

    # lens of all segments
    nms_seg_lens = ex_nms_segs[:, :, 1] - ex_nms_segs[:, :, 0]
    all_seg_lens = ex_all_segs[:, :, 1] - ex_all_segs[:, :, 0]

    # iou
    iou = inter / (nms_seg_lens + all_seg_lens - inter)

    # get neighbors (# N_nms x # N_all) / weights
    seg_weights = (iou >= iou_threshold).to(all_scores.dtype) * all_scores[None, :] * iou
    weight_sum = torch.sum(seg_weights, dim=1, keepdim=True)
    has_votes = weight_sum != 0
    # normalise only rows that received votes; dividing the others by zero
    # would turn the refined segment into NaN
    seg_weights = torch.where(
        has_votes, seg_weights / weight_sum, torch.zeros_like(seg_weights)
    )
    refined_segs = seg_weights @ all_segs

    # fall back to the NMS segment where nobody voted
    return torch.where(has_votes, refined_segs, nms_segs)
def batched_nms(
    segs,
    scores,
    cls_idxs,
    iou_threshold,
    min_score,
    max_seg_num,
    use_soft_nms=True,
    multiclass=True,
    sigma=0.5,
    voting_thresh=0.75,
):
    """Run (soft-)NMS over 1D segments, per class or class-agnostic.

    Args:
        segs: segments to suppress; only ``segs.shape[0]`` and indexing are
            used here (assumed N x 2 [start, end] - TODO confirm with callers).
        scores: one score per segment.
        cls_idxs: one class index per segment.
        iou_threshold: overlap threshold forwarded to the NMS op.
        min_score: score threshold forwarded to the NMS op.
        max_seg_num: cap forwarded to the NMS op and applied again to the
            merged result.
        use_soft_nms: use SoftNMSop instead of NMSop.
        multiclass: run NMS independently for every class id in ``cls_idxs``.
        sigma: soft-NMS parameter forwarded to SoftNMSop.
        voting_thresh: when > 0 (class-agnostic mode only), refine the kept
            segments with seg_voting at this IoU threshold.

    Returns:
        tuple: (segs, scores, cls_idxs) sorted by descending score and
        truncated to at most ``max_seg_num`` entries.
    """
    # Based on Detectron2 implementation,
    num_segs = segs.shape[0]
    # corner case, no prediction outputs
    if num_segs == 0:
        return torch.zeros([0, 2]),\
               torch.zeros([0,]),\
               torch.zeros([0,], dtype=cls_idxs.dtype)

    if multiclass:
        # multiclass nms: apply nms on each class independently
        new_segs, new_scores, new_cls_idxs = [], [], []
        for class_id in torch.unique(cls_idxs):
            curr_indices = torch.where(cls_idxs == class_id)[0]
            # soft_nms vs nms
            if use_soft_nms:
                # the literal 2 is the `method` argument of SoftNMSop
                sorted_segs, sorted_scores, sorted_cls_idxs = SoftNMSop.apply(
                    segs[curr_indices],
                    scores[curr_indices],
                    cls_idxs[curr_indices],
                    iou_threshold,
                    sigma,
                    min_score,
                    2,
                    max_seg_num
                )
            else:
                sorted_segs, sorted_scores, sorted_cls_idxs = NMSop.apply(
                    segs[curr_indices],
                    scores[curr_indices],
                    cls_idxs[curr_indices],
                    iou_threshold,
                    min_score,
                    max_seg_num
                )
            # disable seg voting for multiclass nms, no sufficient segs

            # fill in the class index
            new_segs.append(sorted_segs)
            new_scores.append(sorted_scores)
            new_cls_idxs.append(sorted_cls_idxs)

        # cat the results
        new_segs = torch.cat(new_segs)
        new_scores = torch.cat(new_scores)
        new_cls_idxs = torch.cat(new_cls_idxs)

    else:
        # class agnostic
        if use_soft_nms:
            new_segs, new_scores, new_cls_idxs = SoftNMSop.apply(
                segs, scores, cls_idxs, iou_threshold,
                sigma, min_score, 2, max_seg_num
            )
        else:
            new_segs, new_scores, new_cls_idxs = NMSop.apply(
                segs, scores, cls_idxs, iou_threshold,
                min_score, max_seg_num
            )
        # seg voting
        if voting_thresh > 0:
            new_segs = seg_voting(
                new_segs,
                segs,
                scores,
                voting_thresh
            )

    # sort based on scores and return
    # truncate the results based on max_seg_num
    _, idxs = new_scores.sort(descending=True)
    max_seg_num = min(max_seg_num, new_segs.shape[0])
    # needed for multiclass NMS
    new_segs = new_segs[idxs[:max_seg_num]]
    new_scores = new_scores[idxs[:max_seg_num]]
    new_cls_idxs = new_cls_idxs[idxs[:max_seg_num]]
    return new_segs, new_scores, new_cls_idxs
def load_results_from_pkl(filename):
    """Load a results object from a pickle file.

    Raises:
        FileNotFoundError: ``filename`` is not an existing file.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # files produced by trusted runs.
    with open(filename, "rb") as f:
        results = pickle.load(f)
    return results


def load_results_from_json(filename):
    """Load results from a json file, unwrapping a top-level ``'results'`` key.

    Raises:
        FileNotFoundError: ``filename`` is not an existing file.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    with open(filename, "r", encoding="utf8") as f:
        results = json.load(f)
    # for activity net external classification scores
    if 'results' in results:
        results = results['results']
    return results


def results_to_dict(results):
    """convert result arrays into dict used by json files

    Args:
        results (dict): parallel sequences under ``'video-id'``, ``'t-start'``,
            ``'t-end'``, ``'label'`` and ``'score'``.

    Returns:
        dict: video id (in sorted order) -> list of
        ``{"label", "score", "segment"}`` entries in input order.
    """
    # video ids and allocate the dict
    results_dict = {vidx: [] for vidx in sorted(set(results['video-id']))}

    # fill in the dict
    for vidx, start, end, label, score in zip(
        results['video-id'],
        results['t-start'],
        results['t-end'],
        results['label'],
        results['score']
    ):
        results_dict[vidx].append(
            {
                "label" : int(label),
                "score" : float(score),
                "segment": [float(start), float(end)],
            }
        )
    return results_dict
def results_to_array(results, num_pred):
    """Group flat results per video and keep the ``num_pred`` best per video.

    Returns a dict mapping each video id (in sorted order) to a dict of numpy
    arrays ``'label'``, ``'score'`` and ``'segment'``, ordered by descending
    score.
    """
    # one empty bucket per video id
    per_video = {}
    for video_id in sorted(list(set(results['video-id']))):
        per_video[video_id] = {'label': [], 'score': [], 'segment': []}

    # distribute the flat entries into their buckets
    columns = zip(
        results['video-id'],
        results['t-start'],
        results['t-end'],
        results['label'],
        results['score'],
    )
    for video_id, t_start, t_end, cls, conf in columns:
        bucket = per_video[video_id]
        bucket['label'].append(int(cls))
        bucket['score'].append(float(conf))
        bucket['segment'].append([float(t_start), float(t_end)])

    # convert to arrays and keep the top-scoring entries
    for bucket in per_video.values():
        labels = np.asarray(bucket['label'])
        scores = np.asarray(bucket['score'])
        segments = np.asarray(bucket['segment'])

        # the score should be already sorted, just for safety
        order = np.argsort(scores)[::-1][:num_pred]
        bucket['label'] = labels[order]
        bucket['score'] = scores[order]
        bucket['segment'] = segments[order]

    return per_video
def postprocess_results(results, cls_score_file, num_pred=200, topk=2):
    """Combine detections with external video-level classification scores.

    Every kept segment of a video is duplicated once per top-scoring class of
    that video, and its score becomes
    ``sqrt(class score * detection score)``.

    Args:
        results (dict | str): flat result arrays, or the path of a pickle
            file holding them.
        cls_score_file (str): path to the external classification scores
            (json when the name contains ``'.json'``, pickle otherwise),
            mapping video id -> per-class scores.
        num_pred (int): maximum number of detections kept per video.
        topk (int): maximum number of classes assigned to each detection.

    Returns:
        dict: ``'video-id'`` (list) plus numpy arrays ``'t-start'``,
        ``'t-end'``, ``'label'`` and ``'score'``.
    """
    # load results and convert to dict
    if isinstance(results, str):
        results = load_results_from_pkl(results)
    # array -> dict
    results = results_to_array(results, num_pred)

    # load external classification scores
    if '.json' in cls_score_file:
        cls_scores = load_results_from_json(cls_score_file)
    else:
        cls_scores = load_results_from_pkl(cls_score_file)

    # dict for processed results
    processed_results = {
        'video-id': [],
        't-start' : [],
        't-end': [],
        'label': [],
        'score': []
    }

    # process each video
    for vid, result in results.items():
        # pick top k cls scores and idx
        curr_cls_scores = np.asarray(cls_scores[vid])
        topk_cls_idx = np.argsort(curr_cls_scores)[::-1][:topk]
        topk_cls_score = curr_cls_scores[topk_cls_idx]
        # can be smaller than `topk` when fewer classes are available; the
        # original tiled by `topk` itself, which then produced columns of
        # different lengths
        num_cls = len(topk_cls_idx)

        # model outputs
        pred_score, pred_segment, pred_label = \
            result['score'], result['segment'], result['label']
        num_segs = min(num_pred, len(pred_score))

        # duplicate all segment and assign the topk labels
        # K x 1 @ 1 N -> K x N -> KN
        # multiply the scores
        new_pred_score = np.sqrt(topk_cls_score[:, None] @ pred_score[None, :]).flatten()
        new_pred_segment = np.tile(pred_segment, (num_cls, 1))
        new_pred_label = np.tile(topk_cls_idx[:, None], (1, num_segs)).flatten()

        # add to result
        processed_results['video-id'].extend([vid]*num_segs*num_cls)
        processed_results['t-start'].append(new_pred_segment[:, 0])
        processed_results['t-end'].append(new_pred_segment[:, 1])
        processed_results['label'].append(new_pred_label)
        processed_results['score'].append(new_pred_score)

    for key in ('t-start', 't-end', 'label', 'score'):
        chunks = processed_results[key]
        # np.concatenate raises on an empty list (no videos at all)
        processed_results[key] = (
            np.concatenate(chunks, axis=0) if chunks else np.empty((0,))
        )

    return processed_results
def fix_random_seed(seed, include_cuda=True):
    """Seed the torch, numpy and python RNGs and configure cudnn.

    Args:
        seed (int): seed applied to every generator.
        include_cuda (bool): when True, also seed CUDA and switch cudnn / torch
            to deterministic behaviour; when False, enable cudnn benchmarking.

    Returns:
        torch.Generator: the generator returned by ``torch.manual_seed``.
    """
    rng_generator = torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if include_cuda:
        # training: disable cudnn benchmark to ensure the reproducibility
        cudnn.enabled = True
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # this is needed for CUDA >= 10.2
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True, warn_only=True)
    else:
        cudnn.enabled = True
        cudnn.benchmark = True
    return rng_generator


def save_checkpoint(state, is_best, file_folder,
                    file_name='checkpoint.pth.tar'):
    """save checkpoint to file

    Args:
        state (dict): checkpoint content; it is not modified.
        is_best (bool): also write ``model_best.pth.tar`` without the
            ``'optimizer'`` / ``'scheduler'`` entries.
        file_folder (str): output folder, created (including parents) if missing.
        file_name (str): file name of the regular checkpoint.
    """
    # makedirs also handles nested folders and an already existing folder
    os.makedirs(file_folder, exist_ok=True)
    torch.save(state, os.path.join(file_folder, file_name))
    if is_best:
        # skip the optimization / scheduler state; filter into a new dict so
        # the caller's `state` keeps its entries
        best_state = {
            key: value for key, value in state.items()
            if key not in ('optimizer', 'scheduler')
        }
        torch.save(best_state, os.path.join(file_folder, 'model_best.pth.tar'))


def print_model_params(model):
    """Print name, min, max and mean of every named parameter of ``model``."""
    for name, param in model.named_parameters():
        print(name, param.min().item(), param.max().item(), param.mean().item())
    return
def make_scheduler(
    optimizer,
    optimizer_config,
    num_iters_per_epoch,
    last_epoch=-1
):
    """Create a learning-rate scheduler for ``optimizer``.

    Every scheduler returned by this function is meant to be stepped once
    per iteration (not per epoch), so all epoch counts from the config are
    converted to iteration counts with ``num_iters_per_epoch``.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.
        optimizer_config: mapping read for the keys ``warmup``, ``epochs``,
            ``schedule_type`` (``"cosine"`` or ``"multistep"``) and, where
            applicable, ``warmup_epochs``, ``schedule_steps`` and
            ``schedule_gamma``.
        num_iters_per_epoch: number of optimizer steps in one epoch.
        last_epoch: forwarded to the scheduler constructor.

    Returns:
        The constructed scheduler.

    Raises:
        TypeError: if ``schedule_type`` is not supported.
    """
    schedule_type = optimizer_config["schedule_type"]

    if optimizer_config["warmup"]:
        max_epochs = optimizer_config["epochs"] + optimizer_config["warmup_epochs"]
        max_steps = max_epochs * num_iters_per_epoch

        # get warmup params
        warmup_epochs = optimizer_config["warmup_epochs"]
        warmup_steps = warmup_epochs * num_iters_per_epoch

        # with linear warmup: call our custom schedulers
        if schedule_type == "cosine":
            # Cosine
            scheduler = LinearWarmupCosineAnnealingLR(
                optimizer,
                warmup_steps,
                max_steps,
                last_epoch=last_epoch
            )

        elif schedule_type == "multistep":
            # Multi step
            steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
            scheduler = LinearWarmupMultiStepLR(
                optimizer,
                warmup_steps,
                steps,
                gamma=optimizer_config["schedule_gamma"],
                last_epoch=last_epoch
            )
        else:
            raise TypeError("Unsupported scheduler!")

    else:
        max_epochs = optimizer_config["epochs"]
        max_steps = max_epochs * num_iters_per_epoch

        # without warmup: call default schedulers
        if schedule_type == "cosine":
            # step per iteration
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                max_steps,
                last_epoch=last_epoch
            )

        elif schedule_type == "multistep":
            # milestones are expressed in iterations, like everything else here
            steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer,
                steps,
                # read gamma from the same config key as the warmup branch;
                # the previous code referenced an undefined `schedule_config`
                gamma=optimizer_config["schedule_gamma"],
                last_epoch=last_epoch
            )
        else:
            raise TypeError("Unsupported scheduler!")

    return scheduler
class ModelEma(torch.nn.Module):
    """Keeps an exponential moving average of another model's state.

    A deep copy of the tracked model is stored in ``self.module`` (put in
    eval mode) and blended towards the live model on every ``update`` call.
    """

    def __init__(self, model, decay=0.999, device=None):
        super().__init__()
        # private copy that accumulates the moving average of the weights
        shadow = deepcopy(model)
        shadow.eval()
        self.module = shadow
        self.decay = decay
        # when set, the averaging is performed on this device instead of the
        # tracked model's device
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite every EMA state tensor with ``update_fn(ema, live)``."""
        target_device = self.device
        ema_tensors = self.module.state_dict().values()
        live_tensors = model.state_dict().values()
        with torch.no_grad():
            for ema_t, live_t in zip(ema_tensors, live_tensors):
                if target_device is not None:
                    live_t = live_t.to(device=target_device)
                ema_t.copy_(update_fn(ema_t, live_t))

    def update(self, model):
        """Blend the live model into the average using ``self.decay``."""
        def _blend(ema_t, live_t):
            return self.decay * ema_t + (1. - self.decay) * live_t

        self._update(model, update_fn=_blend)

    def set(self, model):
        """Copy the live model's state into the average unchanged."""
        def _take_live(ema_t, live_t):
            return live_t

        self._update(model, update_fn=_take_live)
def valid_one_epoch(
    val_loader,
    model,
    curr_epoch,
    ext_score_file = None,
    evaluator = None,
    output_file = None,
    tb_writer = None,
    print_freq = 20
):
    """Run the model over the validation set and evaluate or dump the results.

    Args:
        val_loader: iterable of batches; each batch is passed to ``model``.
        model: called as ``model(video_list)``; expected to return one dict
            per video with ``video_id``, ``segments``, ``labels``, ``scores``.
        curr_epoch: epoch index used as the step when logging mAP.
        ext_score_file: optional path; when it is a string and an evaluator is
            given, results go through ``postprocess_results`` first.
        evaluator: object whose ``evaluate(results, verbose=True)`` returns a
            3-tuple with mAP in the middle.
        output_file: path of the pickle written when no evaluator is given.
        tb_writer: optional writer exposing ``add_scalar``.
        print_freq: print timing every ``print_freq`` iterations.

    Returns:
        The mAP from the evaluator, or 0.0 when results were only dumped.
    """
    # either evaluate the results or save the results
    assert (evaluator is not None) or (output_file is not None)

    # set up meters
    batch_time = AverageMeter()
    # switch to evaluate mode
    model.eval()
    # dict for results (for our evaluation code)
    results = {
        'video-id': [],
        't-start' : [],
        't-end': [],
        'label': [],
        'score': []
    }

    def _gather(chunks, empty_dtype):
        """Concatenate per-video tensors into one NumPy array."""
        if not chunks:
            # torch.cat raises on an empty list; this happens when no video
            # produced any segment. NOTE(review): the dtype used for the empty
            # case is a best guess - confirm against the evaluator.
            return torch.empty(0, dtype=empty_dtype).numpy()
        # .cpu() is a no-op for CPU tensors and required for CUDA ones
        return torch.cat(chunks).cpu().numpy()

    # synchronising is only meaningful (and only possible) with CUDA present
    cuda_available = torch.cuda.is_available()

    # loop over validation set
    start = time.time()
    for iter_idx, video_list in enumerate(val_loader, 0):
        # forward the model (wo. grad)
        with torch.no_grad():
            output = model(video_list)

            # unpack the results into ANet format
            for vid_out in output:
                num_segs = vid_out['segments'].shape[0]
                if num_segs > 0:
                    results['video-id'].extend([vid_out['video_id']] * num_segs)
                    results['t-start'].append(vid_out['segments'][:, 0])
                    results['t-end'].append(vid_out['segments'][:, 1])
                    results['label'].append(vid_out['labels'])
                    results['score'].append(vid_out['scores'])

        # printing
        if (iter_idx != 0) and iter_idx % (print_freq) == 0:
            # measure elapsed time (sync all kernels)
            if cuda_available:
                torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # print timing
            print('Test: [{0:05d}/{1:05d}]\t'
                  'Time {batch_time.val:.2f} ({batch_time.avg:.2f})'.format(
                  iter_idx, len(val_loader), batch_time=batch_time))

    # gather all stats and evaluate
    results['t-start'] = _gather(results['t-start'], torch.float32)
    results['t-end'] = _gather(results['t-end'], torch.float32)
    results['label'] = _gather(results['label'], torch.int64)
    results['score'] = _gather(results['score'], torch.float32)

    if evaluator is not None:
        if ext_score_file is not None and isinstance(ext_score_file, str):
            results = postprocess_results(results, ext_score_file)
        # call the evaluator
        _, mAP, _ = evaluator.evaluate(results, verbose=True)
    else:
        # dump to a pickle file that can be directly used for evaluation
        with open(output_file, "wb") as f:
            pickle.dump(results, f)
        mAP = 0.0

    # log mAP to tb_writer
    if tb_writer is not None:
        tb_writer.add_scalar('validation/mAP', mAP, curr_epoch)

    return mAP
@dataclass
class ImageRecord:
    """Metadata describing one saved frame image."""

    name: str  # image file name
    path: str  # path of the saved image, stored as a string
    label_category: str  # product name
    size: str  # specification
    sample: bool = False  # True for every N-th image (see --sample-every)
    # YOLO format [x_center, y_center, w, h], normalized to 0–1; None when
    # detection is not enabled or nothing was detected
    bbox_xywhn: Optional[list[float]] = None
    detection_score: Optional[float] = None
def _ffprobe_video_meta(path: Path, ffprobe_bin: str) -> Optional[VideoMeta]:
    """Read metadata of the first video stream of ``path`` via ffprobe.

    Args:
        path: video file to inspect.
        ffprobe_bin: name or path of the ffprobe executable.

    Returns:
        A ``VideoMeta``, or None when ffprobe is missing, fails, times out,
        prints unparsable JSON, reports no video stream, or reports a frame
        size smaller than 2x2.
    """
    if not shutil.which(ffprobe_bin):
        return None

    def _as_float(value: Any) -> float:
        # ffprobe may report missing values as "N/A" or omit them entirely;
        # treat anything non-numeric as 0.0 instead of raising.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _as_int(value: Any) -> int:
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    cmd = [
        ffprobe_bin,
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=width,height,avg_frame_rate,r_frame_rate,nb_frames,duration",
        "-show_entries",
        "format=duration",
        "-of",
        "json",
        str(path),
    ]
    try:
        p = subprocess.run(
            cmd, capture_output=True, text=True, timeout=60, check=False
        )
    except (subprocess.TimeoutExpired, OSError):
        return None
    if p.returncode != 0 or not p.stdout:
        return None
    try:
        data = json.loads(p.stdout)
    except json.JSONDecodeError:
        return None
    streams = data.get("streams") or []
    if not streams:
        return None
    st = streams[0]
    w = _as_int(st.get("width"))
    h = _as_int(st.get("height"))
    if w < 2 or h < 2:
        return None

    # prefer the average frame rate, fall back to the nominal one
    fps = _parse_fraction(str(st.get("avg_frame_rate") or ""))
    if fps <= 0:
        fps = _parse_fraction(str(st.get("r_frame_rate") or ""))

    # prefer the stream duration, fall back to the container duration
    dur_s = _as_float(st.get("duration"))
    if dur_s <= 0:
        fmt = data.get("format") or {}
        dur_s = _as_float(fmt.get("duration"))

    frame_count = _as_int(st.get("nb_frames"))

    # derive whichever of frame_count / fps is missing from the other two
    if frame_count <= 0 and dur_s > 0 and fps > 0:
        frame_count = int(round(dur_s * fps))
    if fps <= 0 and dur_s > 0 and frame_count > 0:
        fps = frame_count / dur_s
    if fps <= 0:
        # last-resort default when no frame rate can be determined
        fps = 25.0
    return VideoMeta(
        width=w,
        height=h,
        fps=float(fps),
        duration_sec=float(dur_s),
        frame_count=frame_count,
    )
def _clamp_time_sec(t_sec: float, meta: VideoMeta) -> float:
    """Clamp a timestamp into the seekable range of a video.

    Negative times become 0. When the duration is known (> 0) the result is
    also capped at ``duration - 1 / max(fps, 1)`` (never below 0), keeping
    the timestamp one frame period short of the end.
    """
    floored = max(0.0, t_sec)
    if not meta.duration_sec > 0:
        # duration unknown: only the lower bound can be enforced
        return floored
    frame_period = 1.0 / max(meta.fps, 1.0)
    latest = max(0.0, meta.duration_sec - frame_period)
    return float(min(floored, latest))
def _find_col(df: pd.DataFrame, *candidates: str) -> str | None:
    """Return the first column of ``df`` whose header matches a candidate.

    Candidates are tried in the order given; for each one the columns are
    scanned left to right and the first column whose normalised header
    contains the candidate as a substring is returned (an exact match is a
    special case of that). The original column label is returned, not the
    normalised text.

    Returns:
        The matching column label, or None when nothing matches.
    """
    # normalise every header once instead of once per candidate
    headers = [(col, _normalize_header(col)) for col in df.columns]
    for want in candidates:
        for col, header in headers:
            if want in header:
                return col
    return None
def parse_time_range(text: Any) -> tuple[float, float] | None:
    """Parse a "start-end" time-range cell into ``(start_sec, end_sec)``.

    Supported forms (the separator may be ``-``, an en/em dash, ``~`` or the
    full-width tilde):

    - ``1.23-2.23``   -> 1 min 23 s to 2 min 23 s. This is ``M.SS``: the one
      or two digits after the dot are seconds, so ``0.05-0.11`` means
      5 s to 11 s.
    - ``00:10-00:16`` -> ``mm:ss``; the full-width colon is accepted too.

    Returns:
        ``(start, end)`` in seconds, always in ascending order even when the
        cell lists the endpoints reversed; None for None / NaN / empty /
        unrecognised input.
    """
    if text is None or (isinstance(text, float) and pd.isna(text)):
        return None
    s = str(text).strip()
    if not s or s.lower() == "nan":
        return None

    # Full-width colon -> ASCII colon. Written as an escape so the
    # replacement cannot degrade into a no-op if this file is ever passed
    # through a Unicode normaliser.
    s = s.replace("\uff1a", ":")

    # hyphen, en dash, em dash, tilde, full-width tilde
    sep = "[-\u2013\u2014~\uff5e]"
    patterns = (
        # mm:ss - mm:ss
        r"^\s*(\d{1,2}):(\d{2})\s*" + sep + r"\s*(\d{1,2}):(\d{2})\s*$",
        # M.SS - M.SS (minutes.seconds)
        r"^\s*(\d+)\s*\.\s*(\d{1,2})\s*" + sep + r"\s*(\d+)\s*\.\s*(\d{1,2})\s*$",
    )
    for pattern in patterns:
        m = re.match(pattern, s)
        if m:
            min_a, sec_a, min_b, sec_b = (int(g) for g in m.groups())
            a = min_a * 60 + sec_a
            b = min_b * 60 + sec_b
            return (float(min(a, b)), float(max(a, b)))

    return None
def save_frame_jpeg(
    frame: Any,
    out_path: Path,
    jpeg_quality: int = 85,
    max_width: int = 0,
    max_height: int = 0,
) -> tuple[bool, Optional[np.ndarray]]:
    """Optionally downscale ``frame`` and write it to ``out_path`` as JPEG.

    The frame first goes through ``resize_frame_to_max`` with
    ``max_width`` / ``max_height``.

    Returns:
        ``(True, image)`` where ``image`` is the BGR array that was written,
        or ``(False, None)`` when there was no frame or writing failed.
    """
    resized = resize_frame_to_max(frame, max_width, max_height)
    if resized is None:
        return False, None

    encode_args = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    written = bool(cv2.imwrite(str(out_path), resized, encode_args))
    if not written:
        return False, None
    return True, resized
def _clip_xyxy_xyxy(
    xyxy: list[float], w: int, h: int
) -> list[float]:
    """Clip a pixel ``[x1, y1, x2, y2]`` box to a ``w`` x ``h`` image.

    The top-left corner is limited to ``[0, w - 1]`` / ``[0, h - 1]`` and the
    bottom-right corner to ``[0, w]`` / ``[0, h]``. A box that collapses on
    an axis is widened to one pixel on that axis (capped at the image edge).
    """
    def _bound(value, upper):
        return float(max(0, min(value, upper)))

    left, top, right, bottom = xyxy
    left, right = _bound(left, w - 1), _bound(right, w)
    top, bottom = _bound(top, h - 1), _bound(bottom, h)

    # keep the box at least one pixel wide / tall
    right = min(left + 1.0, float(w)) if right <= left else right
    bottom = min(top + 1.0, float(h)) if bottom <= top else bottom
    return [left, top, right, bottom]
def _is_degenerate_gray_frame(img: np.ndarray) -> bool:
    """Tell whether ``img`` looks like a flat, near-neutral-gray frame.

    A fast ffmpeg seek can yield such an invalid frame on some HEVC
    streams. A missing or empty image counts as degenerate; otherwise the
    frame is degenerate when its mean lies in [118, 138] and its standard
    deviation is below 8.
    """
    if img is None:
        return True
    if img.size == 0:
        return True

    mean_level = float(np.mean(img))
    spread = float(np.std(img))
    near_gray = 118.0 <= mean_level <= 138.0
    low_texture = spread < 8.0
    return near_gray and low_texture
+ *, + ffmpeg_bin: str = "ffmpeg", + ffprobe_bin: str = "ffprobe", + accurate_seek: bool = True, + timeout_sec: float = 600.0, +) -> np.ndarray | None: + """ + 使用 ffmpeg 解码单帧。时间戳 clamp 优先用 ffprobe,避免 OpenCV 对 HEVC 的 fps/时长偏差。 + + accurate_seek=True(默认):-ss 在 -i 之后,解码正确,长视频较慢。 + accurate_seek=False:-ss 在 -i 之前,快,少数文件仍可能异常。 + """ + if not shutil.which(ffmpeg_bin): + return None + meta = get_video_meta(video_path, ffprobe_bin) + if meta.width < 2 or meta.height < 2: + return None + t_clamped = _clamp_time_sec(t_sec, meta) + w, h = meta.width, meta.height + expected_raw = w * h * 3 + + def _run_ffmpeg(cmd: list[str]) -> tuple[Optional[bytes], Optional[str]]: + try: + p = subprocess.run( + cmd, + capture_output=True, + timeout=timeout_sec, + check=False, + ) + except subprocess.TimeoutExpired: + return None, "timeout" + err = (p.stderr or b"").decode("utf-8", errors="replace")[:800] + if p.returncode != 0: + return None, err or f"exit {p.returncode}" + if not p.stdout: + return None, err or "empty stdout" + return p.stdout, None + + def _decode_png(data: bytes) -> Optional[np.ndarray]: + arr = np.frombuffer(data, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + return img + + # 1) 精确 seek + PNG(通用) + if accurate_seek: + cmd_png = [ + ffmpeg_bin, + "-hide_banner", + "-loglevel", + "error", + "-i", + str(video_path), + "-ss", + f"{t_clamped:.6f}", + "-frames:v", + "1", + "-an", + "-f", + "image2pipe", + "-vcodec", + "png", + "-", + ] + else: + cmd_png = [ + ffmpeg_bin, + "-hide_banner", + "-loglevel", + "error", + "-ss", + f"{t_clamped:.6f}", + "-i", + str(video_path), + "-frames:v", + "1", + "-an", + "-f", + "image2pipe", + "-vcodec", + "png", + "-", + ] + out, err = _run_ffmpeg(cmd_png) + if out is not None: + img = _decode_png(out) + if img is not None and img.size > 0: + if not accurate_seek and _is_degenerate_gray_frame(img): + _log( + f"快 seek 输出疑似灰帧,改用精确 seek: {video_path.name} t={t_clamped:.2f}s" + ) + return extract_frame_ffmpeg( + video_path, + 
def extract_frame_opencv_sequential(
    video_path: Path,
    t_sec: float,
    ffprobe_bin: str = "ffprobe",
) -> Any | None:
    """Decode frames from the start of the video up to the one at ``t_sec``.

    The target frame index is computed by ``_time_to_frame_index`` from the
    metadata returned by ``get_video_meta``.

    Returns:
        The target frame, or None when the video cannot be opened or a read
        fails before the target frame is reached.
    """
    meta = get_video_meta(video_path, ffprobe_bin)
    frames_to_read = _time_to_frame_index(t_sec, meta) + 1

    cap = cv2.VideoCapture(str(video_path), cv2.CAP_FFMPEG)
    if not cap.isOpened():
        return None
    try:
        latest: Any | None = None
        while frames_to_read > 0:
            ok, latest = cap.read()
            if latest is None or not ok:
                # stream ended or decoding failed before the target frame
                return None
            frames_to_read -= 1
        return latest
    finally:
        cap.release()
def _unique_image_name(
    session_rel: str,
    row_idx: int,
    video_tag: str,
    time_raw: str,
    ext: str = ".jpg",
) -> str:
    """Build a file name for one extracted frame.

    The name consists of a sanitised tail (at most 80 characters) of
    ``session_rel``, the row index, the video tag, and the first 16 hex
    digits of the SHA-1 over all four identifying values, followed by
    ``ext``.
    """
    digest_input = "{}|{}|{}|{}".format(session_rel, row_idx, video_tag, time_raw)
    digest = hashlib.sha1(digest_input.encode("utf-8")).hexdigest()

    # runs of characters outside the allowed set collapse to one underscore
    cleaned = re.sub(r"[^\w\u4e00-\u9fff\-]+", "_", session_rel)
    prefix = cleaned[-80:]
    return "{}__r{}_{}_{}{}".format(prefix, row_idx, video_tag, digest[:16], ext)
def _limit_reached(records: list[ImageRecord], limit: int) -> bool:
    """Return True when ``limit`` > 0 and ``records`` already holds at least ``limit`` entries."""
    if not limit > 0:
        # a non-positive limit means "unlimited"
        return False
    return len(records) >= limit
int: + """处理一个叶子目录,返回成功写入的图片数量。limit>0 时最多再写入到总条数达 limit。""" + videos = list_videos(session_dir) + excels = list_excels(session_dir) + if not videos or not excels: + return 0 + + session_rel = str(session_dir.relative_to(data_root)) + n_ok = 0 + + def row_product(row: pd.Series, df: pd.DataFrame) -> tuple[str, str]: + c_name = _find_col(df, "商品名称") + c_spec = _find_col(df, "规格") + name = "" + spec = "" + if c_name is not None: + v = row.get(c_name) + if v is not None and not (isinstance(v, float) and pd.isna(v)): + name = str(v).strip() + if c_spec is not None: + v = row.get(c_spec) + if v is not None and not (isinstance(v, float) and pd.isna(v)): + spec = str(v).strip() + return normalize_haocai_class_name(name), spec + + # 两个 Excel + 两个视频:各读各表,按行与对应视频抽帧 + if len(excels) >= 2 and len(videos) >= 2: + excel_list = sorted(excels, key=_excel_pair_key) + vid_list = sorted(videos, key=_video_sort_key) + pairs = min(len(excel_list), len(vid_list), 2) + for pi in range(pairs): + df = _read_excel(excel_list[pi]) + vid = vid_list[pi] + time_col = _find_col( + df, + "视频内时间段", + "视频01内时间段", + "视频02内时间段", + ) + if time_col is None: + # 最后一列常为时间 + time_col = df.columns[-1] + for ri, (_, row) in enumerate(df.iterrows()): + if _limit_reached(records, limit): + return n_ok + tr = row.get(time_col) + pr = parse_time_range(tr) + if pr is None: + continue + t0, t1 = pr + label, size = row_product(row, df) + if not label and not size: + continue + t_mid = _sample_time_in_tear_segment( + t0, t1, mode=time_sample_mode + ) + frame = extract_frame_fn(vid, t_mid) + if frame is None: + continue + fname = _unique_image_name( + session_rel, ri, f"v{pi + 1}", str(tr) + ) + out_dir = _product_image_dir(images_out, label, size) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / fname + saved, img_out = save_frame_jpeg( + frame, + out_path, + max_width=max_width, + max_height=max_height, + ) + if saved: + bx, ds = _bbox_from_detector(bbox_detector, img_out) + _record_saved( + 
records, global_idx, sample_every, + fname, out_path, label, size, + bbox_xywhn=bx, detection_score=ds, + ) + _write_vis_if_enabled( + vis_out_root, label, size, fname, img_out, bx, ds, + ) + n_ok += 1 + if _limit_reached(records, limit): + return n_ok + return n_ok + + # 单个 Excel + if len(excels) == 1: + df = _read_excel(excels[0]) + c_v1 = _find_col(df, "视频1内时间段", "视频01内时间段") + c_v2 = _find_col(df, "视频2内时间段", "视频02内时间段") + + if len(videos) >= 2 and c_v1 is not None and c_v2 is not None: + vid_list = sorted(videos, key=_video_sort_key)[:2] + for ri, (_, row) in enumerate(df.iterrows()): + for vi, (c_time, vid) in enumerate( + zip([c_v1, c_v2], vid_list) + ): + if _limit_reached(records, limit): + return n_ok + tr = row.get(c_time) + pr = parse_time_range(tr) + if pr is None: + continue + t_mid = _sample_time_in_tear_segment( + *pr, mode=time_sample_mode + ) + frame = extract_frame_fn(vid, t_mid) + if frame is None: + continue + label, size = row_product(row, df) + fname = _unique_image_name( + session_rel, ri, f"v{vi + 1}", str(tr) + ) + out_dir = _product_image_dir(images_out, label, size) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / fname + saved, img_out = save_frame_jpeg( + frame, + out_path, + max_width=max_width, + max_height=max_height, + ) + if saved: + bx, ds = _bbox_from_detector(bbox_detector, img_out) + _record_saved( + records, global_idx, sample_every, + fname, out_path, label, size, + bbox_xywhn=bx, detection_score=ds, + ) + _write_vis_if_enabled( + vis_out_root, label, size, fname, img_out, bx, ds, + ) + n_ok += 1 + if _limit_reached(records, limit): + return n_ok + return n_ok + + # 单视频:最后一列或「视频内时间段」 + time_col = _find_col(df, "视频内时间段", "视频1内时间段") + if time_col is None: + time_col = df.columns[-1] + vid = vid_list[0] if (vid_list := sorted(videos, key=_video_sort_key)) else None + if vid is None: + return 0 + for ri, (_, row) in enumerate(df.iterrows()): + if _limit_reached(records, limit): + return n_ok + tr = 
row.get(time_col) + pr = parse_time_range(tr) + if pr is None: + continue + t_mid = _sample_time_in_tear_segment( + *pr, mode=time_sample_mode + ) + frame = extract_frame_fn(vid, t_mid) + if frame is None: + continue + label, size = row_product(row, df) + fname = _unique_image_name(session_rel, ri, "v1", str(tr)) + out_dir = _product_image_dir(images_out, label, size) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / fname + saved, img_out = save_frame_jpeg( + frame, + out_path, + max_width=max_width, + max_height=max_height, + ) + if saved: + bx, ds = _bbox_from_detector(bbox_detector, img_out) + _record_saved( + records, global_idx, sample_every, + fname, out_path, label, size, + bbox_xywhn=bx, detection_score=ds, + ) + _write_vis_if_enabled( + vis_out_root, label, size, fname, img_out, bx, ds, + ) + n_ok += 1 + if _limit_reached(records, limit): + return n_ok + return n_ok + + # 其余情况:尝试用第一个 Excel + 第一个视频 + if excels and videos: + df = _read_excel(excels[0]) + time_col = _find_col(df, "视频内时间段") or df.columns[-1] + vid = sorted(videos, key=_video_sort_key)[0] + for ri, (_, row) in enumerate(df.iterrows()): + if _limit_reached(records, limit): + return n_ok + tr = row.get(time_col) + pr = parse_time_range(tr) + if pr is None: + continue + t_mid = _sample_time_in_tear_segment( + *pr, mode=time_sample_mode + ) + frame = extract_frame_fn(vid, t_mid) + if frame is None: + continue + label, size = row_product(row, df) + fname = _unique_image_name(session_rel, ri, "v1", str(tr)) + out_dir = _product_image_dir(images_out, label, size) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / fname + saved, img_out = save_frame_jpeg( + frame, + out_path, + max_width=max_width, + max_height=max_height, + ) + if saved: + bx, ds = _bbox_from_detector(bbox_detector, img_out) + _record_saved( + records, global_idx, sample_every, + fname, out_path, label, size, + bbox_xywhn=bx, detection_score=ds, + ) + _write_vis_if_enabled( + vis_out_root, label, 
size, fname, img_out, bx, ds, + ) + n_ok += 1 + if _limit_reached(records, limit): + return n_ok + return n_ok + + +def main() -> int: + parser = argparse.ArgumentParser(description="浩材视频抽帧数据集生成") + parser.add_argument( + "--data-root", + type=str, + default="~/data/haocai", + help="数据根目录(默认 ~/data/haocai)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./haocai_dataset", + help="输出根目录(图片与 JSON 放在其下)", + ) + parser.add_argument( + "--json-name", + type=str, + default="dataset.json", + help="JSON 文件名(位于 output-dir 下)", + ) + parser.add_argument( + "--images-subdir", + type=str, + default="images", + help="图片子目录名(位于 output-dir 下)", + ) + parser.add_argument( + "--sample-every", + type=int, + default=0, + metavar="N", + help="全局按保存顺序计数,每第 N 张在 JSON 中 sample=true(0 表示全部 sample=false)", + ) + parser.add_argument( + "--limit", + type=int, + default=0, + metavar="N", + help="最多生成 N 条记录(与 JSON 条目数一致),用于试跑检查格式;0 表示不限制", + ) + parser.add_argument( + "--max-width", + type=int, + default=0, + metavar="PX", + help="输出 JPEG 最大宽度(像素),0=不限制(默认,保持原始分辨率)", + ) + parser.add_argument( + "--max-height", + type=int, + default=0, + metavar="PX", + help="输出 JPEG 最大高度(像素),0=不限制(默认)。与 --max-width 同时生效时缩放到可放入矩形内", + ) + parser.add_argument( + "--detect-bbox", + action="store_true", + help="用 Grounding DINO 检测人体并写 bbox 到 JSON(需 pip install transformers torch pillow)", + ) + parser.add_argument( + "--dino-model-id", + type=str, + default="IDEA-Research/grounding-dino-base", + metavar="ID", + help="Grounding DINO HuggingFace 模型 ID", + ) + parser.add_argument( + "--dino-prompt", + type=str, + default="person .", + metavar="TEXT", + help="Grounding DINO 检测 prompt(默认 'person .')", + ) + parser.add_argument( + "--dino-box-threshold", + type=float, + default=0.30, + metavar="F", + help="Grounding DINO box 置信度阈值(默认 0.30)", + ) + parser.add_argument( + "--dino-text-threshold", + type=float, + default=0.25, + metavar="F", + help="Grounding DINO text 置信度阈值(默认 0.25)", + ) + 
parser.add_argument( + "--save-vis", + action="store_true", + help="在 output-dir 下写入可视化图(默认子目录 vis/),与 images 同目录结构,文件名为 <原名>_vis.jpg", + ) + parser.add_argument( + "--vis-subdir", + type=str, + default="vis", + help="可视化 JPEG 所在子目录名(位于 output-dir 下,默认 vis)", + ) + parser.add_argument( + "--extract-backend", + type=str, + choices=("auto", "ffmpeg", "opencv"), + default="auto", + help="抽帧:auto=有 ffmpeg 则用 ffmpeg(推荐,HEVC 不易花屏);" + "ffmpeg=必须可用 ffmpeg;opencv=顺序解码,无 ffmpeg 时可用但较慢", + ) + parser.add_argument( + "--ffmpeg-bin", + type=str, + default="ffmpeg", + metavar="CMD", + help="ffmpeg 可执行文件名或绝对路径(默认 ffmpeg)", + ) + parser.add_argument( + "--ffprobe-bin", + type=str, + default="ffprobe", + metavar="CMD", + help="ffprobe 可执行文件名(用于时长/帧率/分辨率;默认 ffprobe)", + ) + parser.add_argument( + "--ffmpeg-fast-seek", + action="store_true", + help="快 seek:-ss 在 -i 之前,长视频抽帧快很多;默认精确 seek 从开头解码到目标时刻,故很慢", + ) + parser.add_argument( + "--sample-midpoint", + action="store_true", + help="时间段内抽帧取中点;默认取「撕」区间前半段(半段内 3/4 分位)", + ) + parser.add_argument( + "--tear-second-half", + action="store_true", + help="撕时间段内用整段后半 3/4 分位(旧默认);与默认前半段二选一", + ) + args = parser.parse_args() + + if args.sample_every < 0: + print("--sample-every 须 >= 0", file=sys.stderr) + return 2 + if args.limit < 0: + print("--limit 须 >= 0", file=sys.stderr) + return 2 + if args.max_width < 0 or args.max_height < 0: + print("--max-width / --max-height 须 >= 0", file=sys.stderr) + return 2 + bbox_detector: Optional[GroundingDinoDetector] = None + if args.detect_bbox: + try: + _log("Grounding DINO bbox detection enabled") + _log( + f"model={args.dino_model_id}, prompt={args.dino_prompt!r}, " + f"box_threshold={args.dino_box_threshold}, " + f"text_threshold={args.dino_text_threshold}" + ) + bbox_detector = GroundingDinoDetector( + model_id=args.dino_model_id, + prompt=args.dino_prompt, + box_threshold=args.dino_box_threshold, + text_threshold=args.dino_text_threshold, + ) + except Exception as e: + print( + f"启用 --detect-bbox 
失败: {type(e).__name__}: {e}\n" + "请确认已安装: pip install transformers torch pillow", + file=sys.stderr, + ) + return 2 + + data_root = _expand_root(args.data_root) + out_root = _expand_root(args.output_dir) + images_out = out_root / args.images_subdir + images_out.mkdir(parents=True, exist_ok=True) + + vis_out_root: Optional[Path] = None + if args.save_vis: + vis_out_root = out_root / args.vis_subdir + vis_out_root.mkdir(parents=True, exist_ok=True) + + records: list[ImageRecord] = [] + global_idx = [0] + total = 0 + sessions = list(iter_leaf_session_dirs(data_root)) + if not sessions: + print(f"未找到叶子会话目录(需同时含 mp4 与 xlsx): {data_root}", file=sys.stderr) + + if not shutil.which(args.ffprobe_bin): + _log( + f"未找到 {args.ffprobe_bin!r},时长/帧率将仅用 OpenCV(HEVC 可能偏差);" + "建议: conda install ffmpeg 或 apt install ffmpeg" + ) + extract_frame_fn, extract_mode = make_extract_frame_fn( + args.extract_backend, + args.ffmpeg_bin, + args.ffprobe_bin, + accurate_seek=not args.ffmpeg_fast_seek, + ) + _log(f"抽帧后端: {extract_mode}") + if args.sample_midpoint: + time_sample_mode = "midpoint" + elif args.tear_second_half: + time_sample_mode = "tear_second_half" + else: + time_sample_mode = "tear_first_half" + _log( + "时间段采样: " + + ( + "中点(--sample-midpoint)" + if time_sample_mode == "midpoint" + else ( + "撕区间后半段 3/4(--tear-second-half)" + if time_sample_mode == "tear_second_half" + else "撕区间前半段(默认,半段内 3/4 分位)" + ) + ) + ) + if extract_mode.startswith("ffmpeg") and not args.ffmpeg_fast_seek: + _log( + "精确 seek(默认)在长视频、大时间戳时很慢:每次抽帧都会从文件开头解码到目标时刻。" + "若可接受略快 seek,请加 --ffmpeg-fast-seek 加速。" + ) + + for sd in sorted(sessions): + if _limit_reached(records, args.limit): + break + n = process_session( + sd, + data_root, + images_out, + records, + global_idx, + args.sample_every, + args.limit, + args.max_width, + args.max_height, + bbox_detector, + vis_out_root, + extract_frame_fn=extract_frame_fn, + time_sample_mode=time_sample_mode, + ) + total += n + print(f"{sd.relative_to(data_root)}: {n} 张") + + 
json_path = out_root / args.json_name + payload = [asdict(r) for r in records] + json_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8" + ) + lim_note = f"(limit={args.limit})" if args.limit > 0 else "" + vis_note = ( + f",可视化目录: {vis_out_root}" + if vis_out_root is not None + else "" + ) + print( + f"共写入 {total} 张图片{lim_note},JSON 条目 {len(records)},元数据: {json_path}{vis_note}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/code/repo_root.py b/code/repo_root.py new file mode 100644 index 0000000..baa7e09 --- /dev/null +++ b/code/repo_root.py @@ -0,0 +1,6 @@ +"""仓库根目录常量:本文件必须位于含 dataset.py 的 code/ 根目录。""" +from __future__ import annotations + +from pathlib import Path + +CODE_ROOT = Path(__file__).resolve().parent diff --git a/code/video_clip_cls/build_dataset_run.log b/code/video_clip_cls/build_dataset_run.log new file mode 100644 index 0000000..e69de29 diff --git a/code/video_clip_cls/extract_videoswin_features.py b/code/video_clip_cls/extract_videoswin_features.py new file mode 100644 index 0000000..4b09b13 --- /dev/null +++ b/code/video_clip_cls/extract_videoswin_features.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +""" +Extract per-video temporal features for ActionFormer using VideoSwin (torchvision). + +Output: +1) One feature file per video: /.npy + Shape = [T, C], where: + - T: number of temporal feature steps + - C: input_dim (feature dimension, e.g. 
768 for swin3d_t) +2) A metadata json that records critical TAL config fields: + - input_dim + - feat_stride (in seconds and frames) + - feat_num_frames +""" + +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path +from typing import Dict, List, Tuple + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from torchvision.models.video import Swin3D_T_Weights, swin3d_t + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser("Extract VideoSwin features for ActionFormer") + _here = Path(__file__).resolve().parent + parser.add_argument( + "--data-root", + type=Path, + default=_here / "outputs" / "features_videoswin" / "_placeholder_in", + help="Root folder containing source mp4 files (e2e 调用时会显式传入).", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=_here / "outputs" / "features_videoswin", + help="Folder to save per-video .npy features (e2e 调用时会显式传入).", + ) + parser.add_argument( + "--meta-file", + type=Path, + default=_here / "outputs" / "features_videoswin" / "meta.json", + help="Where to save extraction metadata (e2e 调用时会显式传入).", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="cuda or cpu.", + ) + parser.add_argument( + "--clip-len", + type=int, + default=32, + help="Number of sampled frames per clip.", + ) + parser.add_argument( + "--frame-stride", + type=int, + default=2, + help="Frame interval inside one clip (in original frame units).", + ) + parser.add_argument( + "--feat-stride-frames", + type=int, + default=16, + help="Temporal hop between adjacent features (in original frames).", + ) + parser.add_argument( + "--image-size", + type=int, + default=224, + help="Spatial resize for each frame.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=8, + help="Number of clips per forward pass.", + ) + parser.add_argument( + "--decode-clip-buffer", + type=int, + default=16, + help="Number of clips decoded per 
chunk to limit RAM usage.", + ) + parser.add_argument( + "--max-videos", + type=int, + default=-1, + help="Debug option: process at most N videos; -1 means all.", + ) + parser.add_argument( + "--max-seconds", + type=float, + default=-1.0, + help="If > 0, only process the first N seconds of each video.", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip feature files that already exist.", + ) + parser.add_argument( + "--exclude-keyword", + action="append", + default=[], + help="Exclude videos whose full path contains this keyword. Can be set multiple times.", + ) + parser.add_argument( + "--https-proxy", + type=str, + default="", + help="Optional proxy, e.g. http://127.0.0.1:7897, used for weight download.", + ) + return parser.parse_args() + + +def build_model(device: torch.device) -> Tuple[nn.Module, int]: + weights = Swin3D_T_Weights.KINETICS400_V1 + model = swin3d_t(weights=weights) + # Remove classifier head; output becomes penultimate embeddings. + # input_dim is 768 for swin3d_t. + model.head = nn.Identity() + model.eval().to(device) + input_dim = 768 + return model, input_dim + + +def list_videos(data_root: Path, max_videos: int) -> List[Path]: + videos = sorted(data_root.rglob("*.mp4")) + if videos: + kept = [] + for v in videos: + v_str = str(v) + if any(k and (k in v_str) for k in GLOBAL_EXCLUDE_KEYWORDS): + continue + kept.append(v) + videos = kept + if max_videos > 0: + videos = videos[:max_videos] + return videos + + +GLOBAL_EXCLUDE_KEYWORDS: List[str] = [] + + +def decode_frame_range( + cap: cv2.VideoCapture, + current_idx: int, + end_idx: int, + needed_set: set[int], +) -> Tuple[Dict[int, np.ndarray], int]: + """ + Decode frames sequentially until end_idx and keep only needed indices. 
+ Returns: + - cached frame dict + - next frame index (one past the last decoded frame) + """ + frame_map: Dict[int, np.ndarray] = {} + + while current_idx <= end_idx: + ok, frm = cap.read() + if not ok or frm is None: + return frame_map, current_idx + if current_idx in needed_set: + frame_map[current_idx] = cv2.cvtColor(frm, cv2.COLOR_BGR2RGB) + current_idx += 1 + + return frame_map, current_idx + + +def preprocess_clip(frames: List[np.ndarray], image_size: int) -> torch.Tensor: + processed = [] + for frm in frames: + frm = cv2.resize(frm, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + frm = frm.astype(np.float32) / 255.0 + processed.append(frm) + arr = np.stack(processed, axis=0) # [T, H, W, C] + ten = torch.from_numpy(arr).permute(0, 3, 1, 2).contiguous() # [T, C, H, W] + + # Kinetics mean/std normalization + mean = torch.tensor([0.45, 0.45, 0.45], dtype=ten.dtype).view(1, 3, 1, 1) + std = torch.tensor([0.225, 0.225, 0.225], dtype=ten.dtype).view(1, 3, 1, 1) + ten = (ten - mean) / std + return ten + + +def resolve_missing_frame( + frame_map: Dict[int, np.ndarray], target_idx: int, fallback_shape: Tuple[int, int, int] +) -> np.ndarray: + """ + Return the target frame if available; otherwise use nearest decoded frame. + If none decoded, return a black frame. 
+ """ + if target_idx in frame_map: + return frame_map[target_idx] + if frame_map: + keys = sorted(frame_map.keys()) + nearest = min(keys, key=lambda k: abs(k - target_idx)) + return frame_map[nearest] + return np.zeros(fallback_shape, dtype=np.uint8) + + +@torch.no_grad() +def extract_one_video( + model: nn.Module, + video_path: Path, + clip_len: int, + frame_stride: int, + feat_stride_frames: int, + image_size: int, + batch_size: int, + decode_clip_buffer: int, + device: torch.device, + max_seconds: float, +) -> Tuple[np.ndarray, float, int]: + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Cannot open video: {video_path}") + + fps = float(cap.get(cv2.CAP_PROP_FPS)) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if fps <= 0 or total_frames <= 0: + cap.release() + raise RuntimeError(f"Invalid video meta for {video_path}: fps={fps}, frames={total_frames}") + + effective_total_frames = total_frames + if max_seconds is not None and max_seconds > 0: + effective_total_frames = max(1, min(total_frames, int(round(max_seconds * fps)))) + + clip_span = (clip_len - 1) * frame_stride + 1 + if effective_total_frames < clip_span: + start_indices = [0] + else: + start_indices = list(range(0, effective_total_frames - clip_span + 1, feat_stride_frames)) + if (effective_total_frames - clip_span) % feat_stride_frames != 0: + start_indices.append(effective_total_frames - clip_span) + + clip_indices: List[List[int]] = [] + for s in start_indices: + frame_indices = [min(s + i * frame_stride, effective_total_frames - 1) for i in range(clip_len)] + clip_indices.append(frame_indices) + if decode_clip_buffer <= 0: + decode_clip_buffer = max(batch_size, 1) + + feats: List[np.ndarray] = [] + cur_frame_idx = 0 + + # Decode and infer in chunks to keep RAM bounded. 
+ for chunk_start in range(0, len(clip_indices), decode_clip_buffer): + chunk = clip_indices[chunk_start : chunk_start + decode_clip_buffer] + needed_set = set() + for ids in chunk: + needed_set.update(ids) + max_need = max(needed_set) + + frame_map, cur_frame_idx = decode_frame_range(cap, cur_frame_idx, max_need, needed_set) + if cur_frame_idx <= max_need: + raise RuntimeError("Video decode terminated before required frames were read.") + + clip_tensors: List[torch.Tensor] = [] + fallback_shape = (image_size, image_size, 3) + for frame_ids in chunk: + frames = [resolve_missing_frame(frame_map, i, fallback_shape) for i in frame_ids] + clip_tensors.append(preprocess_clip(frames, image_size)) + + for i in range(0, len(clip_tensors), batch_size): + batch = torch.stack(clip_tensors[i : i + batch_size], dim=0) # [B, T, C, H, W] + batch = batch.permute(0, 2, 1, 3, 4).contiguous().to(device) # [B, C, T, H, W] + out = model(batch) # [B, C] + feats.append(out.detach().cpu().numpy()) + + # Release per-chunk caches promptly. + del frame_map + del clip_tensors + + cap.release() + feat_arr = np.concatenate(feats, axis=0).astype(np.float32) # [T, C] + return feat_arr, fps, total_frames + + +def main() -> None: + args = parse_args() + global GLOBAL_EXCLUDE_KEYWORDS + GLOBAL_EXCLUDE_KEYWORDS = args.exclude_keyword or [] + + # Try to reduce ffmpeg decoder logs in terminal. 
+ os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8") + + if args.https_proxy: + os.environ["https_proxy"] = args.https_proxy + os.environ["http_proxy"] = args.https_proxy + + args.output_dir.mkdir(parents=True, exist_ok=True) + args.meta_file.parent.mkdir(parents=True, exist_ok=True) + + device = torch.device(args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu") + model, input_dim = build_model(device) + + videos = list_videos(args.data_root, args.max_videos) + if not videos: + raise RuntimeError(f"No mp4 found under: {args.data_root}") + + print(f"[INFO] device={device}, videos={len(videos)}") + print( + f"[INFO] input_dim={input_dim}, feat_stride_frames={args.feat_stride_frames}, " + f"feat_num_frames={args.clip_len * args.frame_stride}" + ) + + video_metas = {} + done = 0 + for idx, video_path in enumerate(videos, start=1): + rel = video_path.relative_to(args.data_root).with_suffix("") + out_file = args.output_dir / f"{rel}.npy" + out_file.parent.mkdir(parents=True, exist_ok=True) + + if args.skip_existing and out_file.exists(): + print(f"[{idx}/{len(videos)}] skip existing: {rel}") + continue + + try: + feats, fps, total_frames = extract_one_video( + model=model, + video_path=video_path, + clip_len=args.clip_len, + frame_stride=args.frame_stride, + feat_stride_frames=args.feat_stride_frames, + image_size=args.image_size, + batch_size=args.batch_size, + decode_clip_buffer=args.decode_clip_buffer, + device=device, + max_seconds=args.max_seconds, + ) + np.save(out_file, feats) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + video_metas[str(rel)] = { + "fps": fps, + "total_frames": total_frames, + "num_features": int(feats.shape[0]), + "feature_dim": int(feats.shape[1]), + "feature_file": str(out_file), + } + done += 1 + print(f"[{idx}/{len(videos)}] ok: {rel}, feat_shape={tuple(feats.shape)}") + except Exception as exc: + print(f"[{idx}/{len(videos)}] fail: {rel}, err={exc}") + + meta = { + "extractor": 
"torchvision_swin3d_t_kinetics400", + "data_root": str(args.data_root), + "output_dir": str(args.output_dir), + # --- ActionFormer critical fields --- + "input_dim": input_dim, + "feat_stride_frames": args.feat_stride_frames, + "feat_num_frames": args.clip_len * args.frame_stride, + # NOTE: feat_stride in seconds is video-dependent if fps varies. + # For each video: feat_stride_seconds = feat_stride_frames / fps + "per_video_stride_formula": "feat_stride_seconds = feat_stride_frames / fps", + "videos_processed": done, + "videos_total": len(videos), + "videos": video_metas, + } + with args.meta_file.open("w", encoding="utf-8") as f: + json.dump(meta, f, ensure_ascii=False, indent=2) + print(f"[DONE] metadata saved to: {args.meta_file}") + + +if __name__ == "__main__": + main() diff --git a/code/video_clip_cls/infer_single_0506/input/03视频.mp4 b/code/video_clip_cls/infer_single_0506/input/03视频.mp4 new file mode 120000 index 0000000..ed966a4 --- /dev/null +++ b/code/video_clip_cls/infer_single_0506/input/03视频.mp4 @@ -0,0 +1 @@ +/home/hg02/Project/OperatationRoomMonitor/data/haocai/5月6号视频/5月6日第二次视频/03视频.mp4 \ No newline at end of file diff --git a/code/video_clip_cls/infer_single_0506/run_segments_consumable_vote.py b/code/video_clip_cls/infer_single_0506/run_segments_consumable_vote.py new file mode 100644 index 0000000..9f81f84 --- /dev/null +++ b/code/video_clip_cls/infer_single_0506/run_segments_consumable_vote.py @@ -0,0 +1,548 @@ +#!/usr/bin/env python3 +""" +仅在「时间段 txt」内跑:人手检测 → **逐帧**好/坏门控(**top1 为 good 且 top1conf>阈值**,默认阈值 0.9) +→ 仅通过的帧跑 41 类耗材分类;(可选)仅保留 **耗材 softmax 最大值 > --haocai-min-conf** 的帧; +对保留帧的标签序列做 **滑动窗口多数票平滑**,再 **`consumable` 取平滑后序列众数**。 + +**avg_softmax_*** :仅对上述「高置信耗材帧」统计;类别为 softmax 均值分布前三;置信度为三档边际 softmax 在时间上的平均。 + +不扫全片;每段从视频中按起止时间解码。 + +用法(建议在 yolo 环境): + python code/video_clip_cls/infer_single_0506/run_segments_consumable_vote.py \\ + --segments .../03视频_segments_mutual_exclusive_score_gt_0.1.txt \\ + --video .../03视频.mp4 \\ + --out 
.../03视频_segments_consumables.txt +""" + +from __future__ import annotations + +import argparse +import sys +from collections import Counter +from pathlib import Path + +import cv2 +import numpy as np +from ultralytics import YOLO + +for _repo in Path(__file__).resolve().parents: + if (_repo / "repo_root.py").is_file() and (_repo / "dataset.py").is_file(): + if str(_repo) not in sys.path: + sys.path.insert(0, str(_repo)) + break +else: + raise RuntimeError("未定位到仓库 code/ 根目录") + +from repo_root import CODE_ROOT # noqa: E402 + + +def parse_segments_txt(path: Path) -> list[tuple[int, float, float]]: + rows: list[tuple[int, float, float]] = [] + for raw in path.read_text(encoding="utf-8").splitlines(): + if not raw.strip() or raw.lower().startswith("rank"): + continue + parts = raw.split("\t") + if len(parts) < 4: + continue + rank = int(parts[0].strip()) + t0 = float(parts[1].strip()) + t1 = float(parts[2].strip()) + rows.append((rank, t0, t1)) + return rows + + +def collect_hand_boxes(det_model: YOLO, boxes) -> list[list[float]]: + names = det_model.names + out: list[list[float]] = [] + for box in boxes: + cid = int(box.cls[0]) + label = names.get(cid, "") + if label == "hand": + out.append(box.xyxy[0].tolist()) + return out + + +def pad_box( + xyxy: list[float], img_w: int, img_h: int, pad_ratio: float +) -> tuple[int, int, int, int]: + x1, y1, x2, y2 = xyxy + bw, bh = x2 - x1, y2 - y1 + px, py = bw * pad_ratio, bh * pad_ratio + return ( + max(0, int(x1 - px)), + max(0, int(y1 - py)), + min(img_w, int(x2 + px)), + min(img_h, int(y2 + py)), + ) + + +def largest_hand(hands: list[list[float]]) -> list[float]: + def area(b: list[float]) -> float: + return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1]) + + return max(hands, key=area) + + +def _float_top1conf(pr) -> float: + tc = pr.top1conf + if tc is None: + return 0.0 + if isinstance(tc, (float, int, np.floating)): + return float(tc) + return float(tc.detach().float().cpu().item()) + + +def passes_good_gate_top1_conf( + 
gb_model: YOLO, + crop: np.ndarray, + gb_names: dict, + imgsz: int, + top1_conf_must_exceed: float, +) -> bool: + """好/坏分类:predicted top1 为 good,且 top1conf 严格大于给定阈值。""" + if crop.size == 0: + return False + r = gb_model.predict(crop, imgsz=imgsz, verbose=False)[0] + pr = r.probs + if pr is None: + return False + tid = int(pr.top1) + label = str(gb_names.get(tid, "")).strip().lower() + conf = _float_top1conf(pr) + return label == "good" and conf > top1_conf_must_exceed + + +def haocai_softmax_probs( + cls_model: YOLO, crop: np.ndarray, imgsz: int, n_cls: int +) -> np.ndarray | None: + """耗材分类:返回长度 n_cls 的 softmax 概率向量(与模型 top1 一致)。""" + if crop.size == 0: + return None + r = cls_model.predict(crop, imgsz=imgsz, verbose=False)[0] + pr = r.probs + if pr is None or pr.data is None: + return None + v = pr.data.detach().float().cpu().numpy().astype(np.float64).ravel() + if v.size < n_cls: + v = np.resize(v, n_cls) + v = v[:n_cls].copy() + s = float(np.sum(v)) + if s <= 1e-12: + return None + # 若未归一化则 softmax + if abs(s - 1.0) > 0.08: + v = v - float(np.max(v)) + e = np.exp(np.clip(v, -40.0, 40.0)) + out = e / float(np.sum(e)) + return out + return v / s + + +def _cls_name(names: dict, idx: int) -> str: + return str(names.get(int(idx), str(idx))) + + +def mean_softmax_top3( + probs_list: list[np.ndarray], cls_names: dict +) -> tuple[list[str], list[float]]: + """ + 类名:多帧 softmax 按类逐维算术平均,在平均向量上取概率最大的前三类。 + + 置信度(与类名解耦):逐帧对 softmax 从高到低排序,取第 1/2/3 大的概率, + 再在各帧上对这三档分别做算术平均(「帧内边际 topk」的时间平均)。 + 返回三个槽位(不足则用空字符串与 0.0 补齐)。 + """ + names_out: list[str] = [] + probs_out: list[float] = [] + if not probs_list: + for _ in range(3): + names_out.append("") + probs_out.append(0.0) + return names_out, probs_out + stacked = np.stack(probs_list, axis=0) + p = np.mean(stacked, axis=0, dtype=np.float64) + order = np.argsort(-p) + for k in range(3): + if k < order.size: + j = int(order[k]) + names_out.append(_cls_name(cls_names, j)) + else: + names_out.append("") + # 逐帧降序 softmax,对第 1/2/3 
档做时间平均 + row_sorted = np.sort(stacked, axis=1)[:, ::-1] + n_cls = row_sorted.shape[1] + for k in range(3): + if k < n_cls: + probs_out.append(float(np.mean(row_sorted[:, k], dtype=np.float64))) + else: + probs_out.append(0.0) + return names_out, probs_out + + +def smooth_labels_majority(labels: list[str], window: int) -> list[str]: + """ + 对时间有序的类别名做平滑:对每个位置取以该位置为中心、长度为奇数 window 的邻域, + 用邻域内众数替换(打破平局时用最邻域计数最高者)。 + window<=1 时原样返回。 + """ + if window <= 1 or not labels: + return list(labels) + w = window if window % 2 == 1 else window + 1 + half = w // 2 + n = len(labels) + out: list[str] = [] + for i in range(n): + lo = max(0, i - half) + hi = min(n, i + half + 1) + chunk = labels[lo:hi] + top, _c = Counter(chunk).most_common(1)[0] + out.append(top) + return out + + +def process_segment( + cap: cv2.VideoCapture, + det: YOLO, + gb: YOLO, + cls_m: YOLO, + *, + start_sec: float, + end_sec: float, + seek_margin_sec: float, + det_conf: float, + pad_ratio: float, + imgsz_det: int, + imgsz_cls: int, + frame_stride: int, + good_top1_conf_threshold: float, + haocai_min_conf: float, + smooth_label_window: int, + gb_names: dict, + cls_names: dict, +) -> dict: + # HEVC/部分 mp4:直接 seek 到 start 易产生坏参考帧;先往回跳再顺序解码丢到起点。 + probe_from = float(max(0.0, start_sec - seek_margin_sec)) + cap.set(cv2.CAP_PROP_POS_MSEC, probe_from * 1000.0) + synced_frame: np.ndarray | None = None + synced_t: float | None = None + tol = 0.04 + while True: + ok0, grab = cap.read() + if not ok0 or grab is None: + synced_frame, synced_t = None, None + break + t0 = float(cap.get(cv2.CAP_PROP_POS_MSEC)) / 1000.0 + if t0 + tol >= start_sec: + synced_frame, synced_t = grab, t0 + break + + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + n_cls_key_max = max(int(k) for k in cls_names.keys()) + n_cls = n_cls_key_max + 1 + + n_hand_frames = 0 + # top1==good 且 top1conf>阈值的帧数(门控通过即计数,与是否成功得到 softmax 无关) + n_gate_pass = 0 + cls_labels: list[str] = [] + cls_prob_rows: 
list[np.ndarray] = [] + frames_read_in_segment = 0 + + def one_frame(fr: np.ndarray, _t_abs: float) -> None: + nonlocal frames_read_in_segment, n_hand_frames, n_gate_pass, cls_labels, cls_prob_rows + frames_read_in_segment += 1 + if frame_stride > 1 and (frames_read_in_segment - 1) % frame_stride != 0: + return + + r0 = det.predict( + fr, + conf=det_conf, + imgsz=imgsz_det, + verbose=False, + )[0] + hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else [] + if not hands: + return + + n_hand_frames += 1 + xyxy = largest_hand(hands) + x1, y1, x2, y2 = pad_box(xyxy, w, h, pad_ratio) + crop = fr[y1:y2, x1:x2] + ok_gate = passes_good_gate_top1_conf( + gb, crop, gb_names, imgsz_cls, good_top1_conf_threshold + ) + if ok_gate: + n_gate_pass += 1 + vec = haocai_softmax_probs(cls_m, crop, imgsz_cls, n_cls) + if vec is not None: + top_prob = float(np.max(vec)) + if top_prob <= haocai_min_conf: + return + cls_prob_rows.append(vec) + cls_labels.append(_cls_name(cls_names, int(np.argmax(vec)))) + + if synced_frame is not None and synced_t is not None: + if synced_t <= end_sec + 0.08: + one_frame(synced_frame, synced_t) + + while True: + ok, frame = cap.read() + if not ok or frame is None: + break + t = float(cap.get(cv2.CAP_PROP_POS_MSEC)) / 1000.0 + if t > end_sec + 0.08: + break + if t + 1e-6 < start_sec: + continue + one_frame(frame, t) + + if n_hand_frames == 0: + return { + "consumable": "(段内未检测到手部)", + "n_hand_frames": 0, + "n_gate_pass": 0, + "n_predictions": 0, + "top_vote_count": 0, + "avg_top1_cls": "", + "avg_top1_prob": "", + "avg_top2_cls": "", + "avg_top2_prob": "", + "avg_top3_cls": "", + "avg_top3_prob": "", + } + + if not cls_labels: + return { + "consumable": ( + "(无满足条件的耗材帧:好帧置信度或未过门控" + + ( + "" if haocai_min_conf <= 0.0 + else ",或耗材 top1 softmax 不大于阈值" + ) + + ")" + ), + "n_hand_frames": n_hand_frames, + "n_gate_pass": n_gate_pass, + "n_predictions": 0, + "top_vote_count": 0, + "avg_top1_cls": "", + "avg_top1_prob": "", + "avg_top2_cls": "", + 
def main() -> int:
    """Command-line entry point: classify the consumable used in each time segment.

    Parses the CLI options, validates that the segment file, the video and the
    three model weight files exist, loads the three YOLO models, runs
    ``process_segment`` on every segment returned by ``parse_segments_txt`` and
    writes one tab-separated row per segment (plus a header) to ``--out``.

    Returns:
        ``0`` on success; ``1`` if an input file or weight file is missing,
        the segment list is empty, or the video cannot be opened.
    """
    ap = argparse.ArgumentParser(
        description="手检 + 逐帧 top1=good 且 top1conf>阈值门控 + 耗材分类;段内众数"
    )
    # --- input / output locations (defaults point at the developer's layout) ---
    ap.add_argument(
        "--segments",
        type=Path,
        default=Path(__file__).resolve().parent
        / "results"
        / "03视频_segments_mutual_exclusive_score_gt_0.1.txt",
    )
    ap.add_argument(
        "--video",
        type=Path,
        default=CODE_ROOT.parent
        / "data/haocai/5月6号视频/5月6日第二次视频/03视频.mp4",
    )
    ap.add_argument(
        "--hand-model",
        type=Path,
        default=CODE_ROOT
        / "hand_detection/runs/hand_det_y11s_multiframe-better/weights/best.pt",
    )
    ap.add_argument(
        "--goodbad-model",
        type=Path,
        default=CODE_ROOT
        / "goodORbad_frame/runs/goodbad_frame_y11m_e50/weights/best.pt",
    )
    ap.add_argument(
        "--haocai-model",
        type=Path,
        default=CODE_ROOT
        / "haocai_classify/runs/haocai_cls_41cls_goodframe_lastest-0.95"
        / "weights/best.pt",
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path(__file__).resolve().parent
        / "results"
        / "03视频_segments_consumables.txt",
    )
    # --- per-frame gating thresholds (both comparisons are strict ">") ---
    ap.add_argument(
        "--good-top1-conf-threshold",
        type=float,
        default=0.90,
        dest="good_top1_conf_threshold",
        help="逐帧:仅当 top1 为 good 且 top1conf **严格大于**该值时才跑耗材分类(默认对应 top1conf>0.9)",
    )
    ap.add_argument(
        "--haocai-min-conf",
        type=float,
        default=0.0,
        metavar="P",
        help="耗材:仅 softmax 最大值 **严格大于** P 的帧计入标签与 softmax 统计(0 表示不按耗材置信度筛)",
    )
    ap.add_argument(
        "--smooth-label-window",
        type=int,
        default=1,
        metavar="W",
        help="耗材标签平滑:长度为 W 的奇数滑动窗口内多数票(W≤1 不平滑);众数取平滑后的序列",
    )
    # --- detector / classifier inference settings ---
    ap.add_argument("--det-conf", type=float, default=0.5)
    ap.add_argument("--pad-ratio", type=float, default=0.30)
    ap.add_argument("--imgsz-det", type=int, default=640)
    ap.add_argument("--imgsz-cls", type=int, default=224)
    ap.add_argument(
        "--frame-stride",
        type=int,
        default=1,
        help=">1 时代码逐帧解码但每 N 帧推理一次(省算力,结论可能略粗糙)",
    )
    ap.add_argument(
        "--seek-margin-sec",
        type=float,
        default=3.0,
        help="HEVC 等非关键帧 seek 时往回多跳若干秒再解码到段起点,减轻花屏",
    )
    args = ap.parse_args()

    # Fail fast (exit code 1) on any missing input before loading models.
    seg_path = args.segments.resolve()
    vid_path = args.video.resolve()
    if not seg_path.is_file():
        print("找不到时间段文件:", seg_path, file=sys.stderr)
        return 1
    if not vid_path.is_file():
        print("找不到视频:", vid_path, file=sys.stderr)
        return 1
    for pt, lab in (
        (args.hand_model, "hand"),
        (args.goodbad_model, "good/bad"),
        (args.haocai_model, "haocai cls"),
    ):
        if not Path(pt).is_file():
            print(f"缺少{lab} 权重:", pt, file=sys.stderr)
            return 1

    segments = parse_segments_txt(seg_path)
    if not segments:
        print("时间段为空:", seg_path, file=sys.stderr)
        return 1

    print("加载模型…", flush=True)
    det = YOLO(str(args.hand_model))
    gb = YOLO(str(args.goodbad_model))
    cls_m = YOLO(str(args.haocai_model))
    # Class-id -> name mappings are read once and passed to every segment.
    gb_names = gb.names
    cls_names = cls_m.names

    cap = cv2.VideoCapture(str(vid_path))
    if not cap.isOpened():
        print("无法打开视频:", vid_path, file=sys.stderr)
        return 1

    sep = "\t"
    # First output line is the column header; one row per segment follows.
    out_lines = [
        sep.join([
            "rank",
            "start_sec",
            "end_sec",
            "consumable",
            "n_hand_frames",
            "n_frames_top1_good_conf_gt_thresh",
            "n_consumable_predictions",
            "top_label_vote_count",
            "avg_softmax_top1_cls",
            "avg_softmax_top1_prob",
            "avg_softmax_top2_cls",
            "avg_softmax_top2_prob",
            "avg_softmax_top3_cls",
            "avg_softmax_top3_prob",
        ])
    ]

    try:
        # The single VideoCapture is shared by all segments; process_segment
        # re-seeks it for each one.
        for rank, t0, t1 in segments:
            print(f"段落 rank={rank} [{t0:.3f},{t1:.3f}]s …", flush=True)
            info = process_segment(
                cap,
                det,
                gb,
                cls_m,
                start_sec=t0,
                end_sec=t1,
                seek_margin_sec=args.seek_margin_sec,
                det_conf=args.det_conf,
                pad_ratio=args.pad_ratio,
                imgsz_det=args.imgsz_det,
                imgsz_cls=args.imgsz_cls,
                # Clamp so a zero/negative CLI value cannot disable decoding.
                frame_stride=max(1, args.frame_stride),
                good_top1_conf_threshold=args.good_top1_conf_threshold,
                haocai_min_conf=args.haocai_min_conf,
                smooth_label_window=max(1, args.smooth_label_window),
                gb_names=gb_names,
                cls_names=cls_names,
            )
            row = sep.join([
                str(rank),
                f"{t0:.6f}",
                f"{t1:.6f}",
                str(info["consumable"]),
                str(info["n_hand_frames"]),
                str(info["n_gate_pass"]),
                str(info["n_predictions"]),
                str(info["top_vote_count"]),
                info["avg_top1_cls"],
                info["avg_top1_prob"],
                info["avg_top2_cls"],
                info["avg_top2_prob"],
                info["avg_top3_cls"],
                info["avg_top3_prob"],
            ])
            out_lines.append(row)
            print(
                f" -> {info['consumable']} "
                f"(votes {info['top_vote_count']}/{info['n_predictions']}, "
                f"goodgate {info['n_gate_pass']}/{info['n_hand_frames']} hand frames)",
                flush=True,
            )
    finally:
        # Always release the decoder, even if a segment raised.
        cap.release()

    # Results are written only after every segment succeeded.
    out_path = args.out.resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(out_lines) + "\n", encoding="utf-8")
    print("已写出:", out_path, flush=True)
    return 0
def group_rows_by_gap(
    rows: list[E2eRow],
    max_gap_sec: float = 2.0,
) -> list[list[E2eRow]]:
    """Greedily group consecutive successful rows separated by a small gap.

    Rows are scanned left to right. A successful row joins the current group
    when the gap between its ``start_sec`` and the previous member's
    ``end_sec`` is strictly below ``max_gap_sec`` (minus a tiny epsilon).
    A failed row always forms a group of its own and ends the current group,
    so merging never crosses it.

    Args:
        rows: Rows in temporal order.
        max_gap_sec: Gap threshold in seconds.

    Returns:
        Groups in the same order as the input rows.
    """
    groups: list[list[E2eRow]] = []
    # Group that is still allowed to grow; None right after a failed row.
    open_group: list[E2eRow] | None = None

    for row in rows:
        if not row.is_success():
            groups.append([row])
            open_group = None
            continue

        if open_group is not None:
            gap = float(row.start_sec) - float(open_group[-1].end_sec)
            if gap < float(max_gap_sec) - _GAP_EPS:
                open_group.append(row)
                continue

        open_group = [row]
        groups.append(open_group)

    return groups
def bbox_area_xyxy(b: list[float]) -> float:
    """Return the area of an ``[x1, y1, x2, y2]`` box.

    A side whose far coordinate is not greater than its near coordinate
    contributes zero, so inverted or degenerate boxes have area 0.
    """
    left, top, right, bottom = b
    width = max(0.0, right - left)
    height = max(0.0, bottom - top)
    return width * height
class HandRoiGrouper:
    """Turn a frame's hand boxes into one or two padded ROI crops."""

    def __init__(
        self,
        merge_cfg: HandMergeConfig,
        pad_box_fn,
        pad_ratio: float,
    ) -> None:
        # merge_cfg: rules deciding whether two hands share one ROI.
        # pad_box_fn: called as pad_box_fn(xyxy, img_w, img_h, pad_ratio) and
        #             expected to return integer (x1, y1, x2, y2) slice bounds.
        self.merge_cfg = merge_cfg
        self.pad_box_fn = pad_box_fn
        self.pad_ratio = pad_ratio

    def frame_to_rois(
        self,
        frame: np.ndarray,
        hands: list[list[float]],
    ) -> list[np.ndarray]:
        """Cut the classification crops for one frame.

        One hand yields one crop. With two or more hands only the two largest
        are considered: if they should merge, a single crop of their union box
        is returned, otherwise one crop per hand. Every crop is an independent
        contiguous copy, so the full frame is not kept alive by the result.

        Args:
            frame: Full image; only ``frame.shape[:2]`` and slicing are used.
            hands: Hand boxes as ``[x1, y1, x2, y2]`` in frame pixels.

        Returns:
            Zero, one or two crops.
        """
        img_h, img_w = frame.shape[:2]
        if not hands:
            return []

        # Step 1: decide which boxes to crop.
        if len(hands) == 1:
            boxes = [hands[0]]
        else:
            diagonal = float((img_w * img_w + img_h * img_h) ** 0.5)
            first, second = two_largest_hands(hands)
            if hands_should_merge(first, second, self.merge_cfg, diagonal):
                boxes = [union_xyxy(first, second)]
            else:
                boxes = [first, second]

        # Step 2: pad each box and copy the pixels out of the frame.
        rois: list[np.ndarray] = []
        for box in boxes:
            x1, y1, x2, y2 = self.pad_box_fn(box, img_w, img_h, self.pad_ratio)
            rois.append(np.ascontiguousarray(frame[y1:y2, x1:x2].copy()))
        return rois
# noqa: E402 + aggregate_top3_votes, + mask_probs_whitelist, +) +from ultralytics import YOLO # noqa: E402 + +from pipeline.hand_roi_merge import HandRoiGrouper # noqa: E402 + +# 与 run_haocai_actionformer_consumables_e2e 段内失败 return 文案一致,供主流程重试判断 +REASON_NO_VALID_HAOCAI_FRAMES = "(无有效耗材帧:好帧/白名单/耗材置信度未全部满足)" + +collect_hand_boxes = _rsv.collect_hand_boxes +pad_box = _rsv.pad_box +_cls_name = _rsv._cls_name + + +def _float_top1conf(pr: Any) -> float: + tc = pr.top1conf + if tc is None: + return 0.0 + if isinstance(tc, (float, int, np.floating)): + return float(tc) + return float(tc.detach().float().cpu().item()) + + +def passes_good_gate_top1_conf_kw( + gb_model: YOLO, + crop: np.ndarray, + gb_names: dict, + imgsz: int, + top1_conf_must_exceed: float, + predict_kw: dict[str, Any], +) -> bool: + """与 run_segments_consumable_vote 一致,但向 predict 透传 half/device。""" + if crop.size == 0: + return False + r = gb_model.predict(crop, imgsz=imgsz, verbose=False, **predict_kw)[0] + pr = r.probs + if pr is None: + return False + tid = int(pr.top1) + label = str(gb_names.get(tid, "")).strip().lower() + conf = _float_top1conf(pr) + return label == "good" and conf > top1_conf_must_exceed + + +def aggregate_top2_votes( + pairs: list[tuple[str, float]], +) -> tuple[list[str], list[float]]: + """与 aggregate_top3 相同思想,取前二类及次数归一化置信度。""" + empty = (["", ""], [0.0, 0.0]) + if not pairs: + return empty + cnt = Counter(p[0] for p in pairs) + ranked = sorted(cnt.items(), key=lambda x: (-x[1], x[0])) + top = ranked[:2] + if not top: + return empty + total = float(sum(c for _, c in top)) + if total <= 0: + return empty + out_names: list[str] = ["", ""] + out_conf: list[float] = [0.0, 0.0] + for i, (nm, c) in enumerate(top): + out_names[i] = nm + out_conf[i] = float(c) / total + return out_names, out_conf + + +def _clip_xyxy(box: np.ndarray, img_w: int, img_h: int) -> np.ndarray: + """ + 将 xyxy 框裁剪到图像边界,并保证 x2>x1, y2>y1。 + """ + x1, y1, x2, y2 = [float(v) for v in box] + x1 = max(0.0, min(x1, 
img_w - 1.0)) + y1 = max(0.0, min(y1, img_h - 1.0)) + x2 = max(0.0, min(x2, img_w - 1.0)) + y2 = max(0.0, min(y2, img_h - 1.0)) + if x2 <= x1: + x2 = min(img_w - 1.0, x1 + 1.0) + if y2 <= y1: + y2 = min(img_h - 1.0, y1 + 1.0) + return np.array([x1, y1, x2, y2], dtype=np.float32) + + +def _fuse_hands_to_one_box(hands: list[list[float]], img_w: int, img_h: int) -> np.ndarray | None: + """ + 多手框融合为一个大框(x1,y1,x2,y2),用于段内时序平滑与短时补帧。 + """ + if not hands: + return None + arr = np.asarray(hands, dtype=np.float32) + if arr.ndim != 2 or arr.shape[1] < 4: + return None + x1 = float(np.min(arr[:, 0])) + y1 = float(np.min(arr[:, 1])) + x2 = float(np.max(arr[:, 2])) + y2 = float(np.max(arr[:, 3])) + fused = np.array([x1, y1, x2, y2], dtype=np.float32) + return _clip_xyxy(fused, img_w, img_h) + + +class FineGrainedClassifier: + """好坏帧 / 耗材 / 撕膜:薄封装 Ultralytics cls.predict,便于统一 half/device。""" + + def __init__( + self, + gb: YOLO, + cls_m: YOLO, + tear_m: YOLO, + *, + gb_names: dict, + cls_names: dict, + tear_names: dict, + imgsz_cls: int, + predict_kw: dict[str, Any], + ) -> None: + self.gb = gb + self.cls_m = cls_m + self.tear_m = tear_m + self.gb_names = gb_names + self.cls_names = cls_names + self.tear_names = tear_names + self.imgsz_cls = imgsz_cls + self.predict_kw = predict_kw + + def passes_good( + self, + crop: np.ndarray, + good_top1_conf_threshold: float, + ) -> bool: + return passes_good_gate_top1_conf_kw( + self.gb, + crop, + self.gb_names, + self.imgsz_cls, + good_top1_conf_threshold, + self.predict_kw, + ) + + def haocai_label_top_prob( + self, + crop: np.ndarray, + n_cls: int, + allowed_class_idx: frozenset[int] | None, + haocai_min_conf: float, + ) -> tuple[str, float] | None: + if crop.size == 0: + return None + r = self.cls_m.predict(crop, imgsz=self.imgsz_cls, verbose=False, **self.predict_kw)[0] + pr = r.probs + if pr is None or pr.data is None: + return None + v = pr.data.detach().float().cpu().numpy().astype(np.float64).ravel() + if v.size < n_cls: + v = 
np.resize(v, n_cls) + v = v[:n_cls].copy() + s = float(np.sum(v)) + if s <= 1e-12: + return None + if abs(s - 1.0) > 0.08: + v = v - float(np.max(v)) + e = np.exp(np.clip(v, -40.0, 40.0)) + vec_raw = e / float(np.sum(e)) + else: + vec_raw = v / s + if allowed_class_idx is not None: + vec = mask_probs_whitelist(vec_raw, allowed_class_idx, n_cls) + else: + vec = vec_raw + if vec is None: + return None + top_prob = float(np.max(vec)) + if top_prob <= haocai_min_conf: + return None + label = int(np.argmax(vec)) + return _cls_name(self.cls_names, label), top_prob + + def tear_label_top_conf(self, crop: np.ndarray) -> tuple[str, float] | None: + if crop.size == 0: + return None + r = self.tear_m.predict(crop, imgsz=self.imgsz_cls, verbose=False, **self.predict_kw)[0] + pr = r.probs + if pr is None: + return None + tid = int(pr.top1) + conf = _float_top1conf(pr) + return str(self.tear_names.get(tid, str(tid))).strip(), conf + + +def _maybe_cuda_empty_cache(every: int, frame_idx: int) -> None: + if every <= 0: + return + if frame_idx % every != 0: + return + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass + + +def process_segment_multi_hand_tear( + cap: cv2.VideoCapture, + det: YOLO, + fg: FineGrainedClassifier, + grouper: HandRoiGrouper, + *, + start_sec: float, + end_sec: float, + seek_margin_sec: float, + det_conf: float, + imgsz_det: int, + frame_stride: int, + good_top1_conf_threshold: float, + haocai_min_conf: float, + cls_names: dict, + allowed_class_idx: frozenset[int] | None, + tracking_alpha: float = 0.6, + tracking_max_lost_frames: int = 0, + empty_cache_every: int = 0, +) -> dict[str, Any]: + """ + 与 process_segment_e2e 相同 seek 策略;每帧最多两 ROI,逐 ROI做好帧+耗材+撕膜门控。 + """ + probe_from = float(max(0.0, start_sec - seek_margin_sec)) + cap.set(cv2.CAP_PROP_POS_MSEC, probe_from * 1000.0) + synced_frame: np.ndarray | None = None + synced_t: float | None = None + tol = 0.04 + while True: + ok0, grab 
= cap.read() + if not ok0 or grab is None: + synced_frame, synced_t = None, None + break + t0 = float(cap.get(cv2.CAP_PROP_POS_MSEC)) / 1000.0 + if t0 + tol >= start_sec: + synced_frame, synced_t = grab, t0 + break + + n_cls_key_max = max(int(k) for k in cls_names.keys()) + n_cls = n_cls_key_max + 1 + + n_hand_frames = 0 + n_gate_pass = 0 + # pairs_h 存放段内耗材候选 (类名, 置信度),后续会做“按置信度加权”的段内投票聚合。 + # 仅记录通过门控的样本;失败分支仍按是否为空来判定,不改变既有逻辑。 + pairs_h: list[tuple[str, float]] = [] + pairs_t: list[tuple[str, float]] = [] + frames_read_in_segment = 0 + # 追踪状态仅在单个 segment 生命周期内有效。 + prev_hand_box: np.ndarray | None = None + lost_frame_count = 0 + alpha = float(np.clip(tracking_alpha, 0.0, 1.0)) + max_lost = max(0, int(tracking_max_lost_frames)) + + def one_frame(fr: np.ndarray) -> None: + nonlocal frames_read_in_segment, n_hand_frames, n_gate_pass + nonlocal pairs_h, pairs_t, prev_hand_box, lost_frame_count + frames_read_in_segment += 1 + idx_local = frames_read_in_segment + _maybe_cuda_empty_cache(empty_cache_every, idx_local) + + if frame_stride > 1 and (frames_read_in_segment - 1) % frame_stride != 0: + return + + img_h, img_w = fr.shape[:2] + r0 = det.predict(fr, conf=det_conf, imgsz=imgsz_det, verbose=False, **fg.predict_kw)[0] + hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else [] + current_box = _fuse_hands_to_one_box(hands, img_w, img_h) + hands_for_roi: list[list[float]] + + if current_box is not None: + # EMA 对齐 x1,y1,x2,y2 四个坐标,缓解跨帧抖动。 + if prev_hand_box is not None: + smoothed_box = alpha * current_box + (1.0 - alpha) * prev_hand_box + smoothed_box = _clip_xyxy(smoothed_box, img_w, img_h) + else: + smoothed_box = current_box + prev_hand_box = smoothed_box + lost_frame_count = 0 + hands_for_roi = [smoothed_box.tolist()] + elif prev_hand_box is not None and lost_frame_count < max_lost: + # 短时漏检补帧:复用上一次平滑框继续分类链路。 + lost_frame_count += 1 + hands_for_roi = [prev_hand_box.tolist()] + else: + return + + n_hand_frames += 1 + rois = grouper.frame_to_rois(fr, 
def process_segment_multi_hand_tear_with_gate_retries(
    cap: cv2.VideoCapture,
    det: YOLO,
    fg: FineGrainedClassifier,
    grouper: HandRoiGrouper,
    *,
    start_sec: float,
    end_sec: float,
    seek_margin_sec: float,
    det_conf: float,
    imgsz_det: int,
    frame_stride: int,
    good_top1_conf_threshold: float,
    good_top1_retry_threshold: float,
    haocai_min_conf: float,
    haocai_min_conf_retry: float | None,
    cls_names: dict,
    allowed_class_idx: frozenset[int] | None,
    empty_cache_every: int = 0,
    log_fn: Callable[[str], None] | None = None,
    log_prefix: str | None = None,
    tracking_alpha: float = 0.6,
    tracking_max_lost_frames: int = 0,
) -> dict[str, Any]:
    """Run the per-segment inference, relaxing thresholds when nothing qualifies.

    The segment is processed once with the primary thresholds. While the
    result is the specific "no valid consumable frames" failure, up to two
    further attempts are made, in this order:

    1. with ``good_top1_retry_threshold`` as the good-frame threshold, if it
       is positive and lower than the primary one;
    2. with ``haocai_min_conf_retry`` as the consumable confidence threshold,
       if it is set, positive and lower than the primary one. The good-frame
       threshold in effect after step 1 is kept.

    Any other outcome (success, or a different failure reason) is returned
    as is. When both ``log_fn`` and ``log_prefix`` are given, one line is
    logged per retry.

    Returns:
        The result dict of the last attempt.
    """
    # Arguments that are identical for every attempt.
    shared = dict(
        start_sec=start_sec,
        end_sec=end_sec,
        seek_margin_sec=seek_margin_sec,
        det_conf=det_conf,
        imgsz_det=imgsz_det,
        frame_stride=frame_stride,
        tracking_alpha=tracking_alpha,
        tracking_max_lost_frames=tracking_max_lost_frames,
        cls_names=cls_names,
        allowed_class_idx=allowed_class_idx,
        empty_cache_every=empty_cache_every,
    )

    def _attempt(gate_thr: float, conf_thr: float) -> dict[str, Any]:
        return process_segment_multi_hand_tear(
            cap,
            det,
            fg,
            grouper,
            good_top1_conf_threshold=gate_thr,
            haocai_min_conf=conf_thr,
            **shared,
        )

    def _no_valid_frames(result: dict[str, Any]) -> bool:
        # Only this particular failure is worth retrying with looser gates.
        if result.get("ok"):
            return False
        return str(result.get("reason", "")) == REASON_NO_VALID_HAOCAI_FRAMES

    gate_thr = float(good_top1_conf_threshold)
    conf_thr = float(haocai_min_conf)
    info = _attempt(gate_thr, conf_thr)

    relaxed_gate = float(good_top1_retry_threshold)
    if _no_valid_frames(info) and 0 < relaxed_gate < gate_thr - 1e-12:
        if log_fn and log_prefix:
            log_fn(
                f"{log_prefix}以 good_top1_conf_threshold={relaxed_gate} 重试本段(无有效耗材帧)…"
            )
        gate_thr = relaxed_gate
        info = _attempt(gate_thr, conf_thr)

    conf_retry_usable = (
        haocai_min_conf_retry is not None
        and 1e-12 < haocai_min_conf_retry < conf_thr - 1e-12
    )
    if conf_retry_usable and _no_valid_frames(info):
        relaxed_conf = float(haocai_min_conf_retry)
        if log_fn and log_prefix:
            log_fn(
                f"{log_prefix}以 haocai_min_conf={relaxed_conf} 重试本段(无有效耗材帧)…"
            )
        info = _attempt(gate_thr, relaxed_conf)

    return info
@dataclass
class E2eRow:
    """One 12-column result row: rank, time span and the top-3 candidates.

    For each candidate ``k`` in 1..3 the row stores the product id (``idk``),
    the class name (``nk``) and the confidence already formatted as text
    (``ck``). Unused slots are empty strings.
    """

    rank: int
    start_sec: float
    end_sec: float
    id1: str
    n1: str
    c1: str
    id2: str
    n2: str
    c2: str
    id3: str
    n3: str
    c3: str

    def is_success(self) -> bool:
        """Return True when top-1 has a non-blank name and a numeric confidence."""
        if not self.n1.strip():
            return False
        try:
            float(self.c1.strip())
        except ValueError:
            return False
        else:
            return True

    def to_line12(self, rank: int) -> str:
        """Serialise the row as a tab-separated line, using ``rank`` as column 1.

        The stored ``rank`` is ignored and left untouched; times are written
        with six decimals.
        """
        cells = [
            str(rank),
            f"{self.start_sec:.6f}",
            f"{self.end_sec:.6f}",
            self.id1,
            self.n1,
            self.c1,
            self.id2,
            self.n2,
            self.c2,
            self.id3,
            self.n3,
            self.c3,
        ]
        return "\t".join(cells)
def tear_class_index(model: YOLO, class_name: str) -> int:
    """Look up the integer class id of ``class_name`` in ``model.names``.

    An exact match against the whitespace-stripped names is tried first, in
    the mapping's iteration order. If none is found the comparison is
    repeated case-insensitively; when several names differ only by case the
    last one in iteration order wins.

    Raises:
        ValueError: If no class matches either way.
    """
    names: dict[int, str] = model.names  # type: ignore[assignment]

    exact = next(
        (int(key) for key, label in names.items() if str(label).strip() == class_name),
        None,
    )
    if exact is not None:
        return exact

    # Case-insensitive fallback; later duplicates overwrite earlier ones.
    by_folded_name = {
        str(label).strip().lower(): int(key) for key, label in names.items()
    }
    hit = by_folded_name.get(class_name.lower())
    if hit is None:
        raise ValueError(f"模型中无类别「{class_name}」,names={names}")
    return hit
def one_pass_merge(
    rows: list[E2eRow],
    cap: cv2.VideoCapture,
    yolo: YOLO,
    tear_cls: int,
    *,
    head_sec: float,
    tear_prob: float,
    tear_min_frames: int,
    imgsz: int,
    predict_kw: dict[str, Any] | None,
    verbose: bool,
    det: YOLO | None = None,
    grouper: HandRoiGrouper | None = None,
    imgsz_det: int = 640,
    det_conf: float = 0.5,
) -> tuple[list[E2eRow], bool]:
    """Do one left-to-right pass that merges adjacent rows with the same top-1.

    For each adjacent pair ``(a, b)`` of successful rows with equal top-1
    name, high-confidence tearing frames are counted in the first
    ``head_sec`` seconds of ``b`` (truncated at ``b.end_sec``):

    * count ``>= tear_min_frames``: the rows are two separate uses and are
      kept apart;
    * otherwise they are replaced by ``merge_two_segments(a, b)``.

    After a pair is kept apart, ``b`` stays in play and is compared with its
    own successor. Previously the cursor jumped over ``b``, so the pair
    ``(b, c)`` was never examined; since nothing had changed, the caller's
    fixed-point loop stopped and the pair was missed for good.

    A merged row is not compared again within the same pass; the returned
    flag tells the caller to run another pass.

    Returns:
        ``(new_rows, changed)`` where ``changed`` is True if at least one
        merge happened.
    """
    out: list[E2eRow] = []
    changed = False
    i = 0
    while i < len(rows):
        a = rows[i]
        if i + 1 >= len(rows):
            # Last row has no successor to compare with.
            out.append(a)
            break
        b = rows[i + 1]
        same_top1 = (
            a.is_success()
            and b.is_success()
            and a.n1.strip() == b.n1.strip()
        )
        if not same_top1:
            out.append(a)
            i += 1
            continue

        # Gate window: the head of the following segment.
        w0 = b.start_sec
        w1 = min(b.start_sec + head_sec, b.end_sec)
        n_high = count_tearing_frames(
            cap,
            w0,
            w1,
            yolo,
            tear_cls,
            tear_prob,
            imgsz,
            predict_kw=predict_kw,
            det=det,
            grouper=grouper,
            imgsz_det=imgsz_det,
            det_conf=det_conf,
        )
        if verbose:
            mode = "hand_roi" if det is not None and grouper is not None else "full_frame"
            print(
                f"[tear_gate:{mode}] 窗口 [{w0:.3f},{w1:.3f})s(下一段起点起 head_sec={head_sec:g}s,截断至本段 end) "
                f"P(tearing)>={tear_prob} 计数={n_high} (保留两段需>={tear_min_frames})",
                flush=True,
            )
        if n_high >= tear_min_frames:
            # Keep both rows, but emit only `a` now: `b` becomes the left
            # side of the next comparison instead of being skipped.
            out.append(a)
            i += 1
        else:
            out.append(merge_two_segments(a, b))
            changed = True
            i += 2
    return out, changed
+单视频端到端:VideoSwin 特征 → ActionFormer 划段 → 分数引导边界切割+score 过滤 → +手检 + 好帧(>阈值) + 白名单裁剪 + 耗材(softmax max>阈值) → 段内在有效帧上对类名计数,取 **票数前三**, +再以这三类出现次数 **归一化** 为 top1~3 置信度(三项和为 1;不足三类则空位补 0)。 +商品 id 来自 Excel「产品编码」。 +""" + +from __future__ import annotations + +import argparse +import json +import os +import pickle +import shutil +import subprocess +import sys +import tempfile +import time +from collections import defaultdict +from pathlib import Path +from typing import Any + +import cv2 +import numpy as np +from ultralytics import YOLO + +for _repo in Path(__file__).resolve().parents: + if (_repo / "repo_root.py").is_file() and (_repo / "dataset.py").is_file(): + if str(_repo) not in sys.path: + sys.path.insert(0, str(_repo)) + break +else: + raise RuntimeError("未定位到仓库 code/ 根目录") + +from repo_root import CODE_ROOT # noqa: E402 + +# 单文件夹打包:由 run.py 设置 HAOCAI_E2E_BUNDLE=解压根目录,权重/Excel 走包内路径,ActionFormer 在 /actionformer_release +_BUNDLE_ENV = os.environ.get("HAOCAI_E2E_BUNDLE", "").strip() +_BUNDLE_ROOT: Path | None = Path(_BUNDLE_ENV).resolve() if _BUNDLE_ENV else None + +if _BUNDLE_ROOT is not None: + _DEFAULT_EXCEL = _BUNDLE_ROOT / "data" / "视频中的商品信息表.xlsx" + _DEFAULT_AF_CKPT = _BUNDLE_ROOT / "models" / "actionformer_epoch_045.pth.tar" + _DEFAULT_HAND = _BUNDLE_ROOT / "models" / "hand_detect.pt" + _DEFAULT_GOODBAD = _BUNDLE_ROOT / "models" / "goodbad_frame.pt" + _DEFAULT_HAOCAI = _BUNDLE_ROOT / "models" / "haocai_classify.pt" +else: + _DEFAULT_EXCEL = CODE_ROOT.parent / "data/haocai/视频中的商品信息表.xlsx" + _DEFAULT_AF_CKPT = ( + CODE_ROOT + / "video_clip_cls/runs/actionformer_ckpt/haocai_main_perspective_videoswin_haocai_main_perspective_videoswin/epoch_045.pth.tar" + ) + _DEFAULT_HAND = CODE_ROOT / "hand_detection/runs/hand_det_y11s_multiframe-better/weights/best.pt" + _DEFAULT_GOODBAD = CODE_ROOT / "goodORbad_frame/runs/goodbad_frame_y11m_e50/weights/best.pt" + _DEFAULT_HAOCAI = ( + CODE_ROOT / "haocai_classify/runs/haocai_cls_41cls_goodframe_lastest-0.95/weights/best.pt" + ) 
def mask_probs_whitelist(
    probs: np.ndarray,
    allowed: frozenset[int],
    n_cls: int,
) -> np.ndarray | None:
    """Keep only whitelisted classes of a probability vector and renormalize.

    Args:
        probs: Class scores; flattened to 1-D and cast to float64.
        allowed: Class indices that may keep their probability mass.
        n_cls: Length of the returned vector. A longer ``probs`` is
            truncated; a shorter one is treated as zero beyond its end.

    Returns:
        A length-``n_cls`` vector that sums to 1 and is zero outside
        ``allowed``, or ``None`` when the whitelisted mass is below 1e-12
        (nothing left to normalize).
    """
    v = np.asarray(probs, dtype=np.float64).ravel()
    if n_cls <= 0:
        return None
    out = np.zeros(n_cls, dtype=np.float64)
    # Only indices that exist in BOTH the output and the input carry mass.
    # The earlier np.resize-based padding repeated the input, which handed
    # real probabilities to classes the model never scored.
    usable = min(n_cls, v.size)
    for i in allowed:
        if 0 <= i < usable:
            out[i] = v[i]
    s = float(np.sum(out))
    if s < 1e-12:
        return None
    return out / s
"haocai_infer_single_v1", + "taxonomy": [{"nodeName": "Action", "nodeId": 0}], + "database": { + video_id: { + "subset": "val", + "duration": float(duration), + "fps": float(fps), + "annotations": [ + {"segment": [0.0, min(1.0, duration)], "label": "Action", "label_id": 0} + ], + } + }, + } + out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + +def write_infer_yaml(out_path: Path, json_file: Path, feat_folder: Path) -> None: + jf = str(json_file.resolve()) + ff = str(feat_folder.resolve()) + nf = CLIP_LEN * FRAME_STRIDE + text = f"""dataset_name: thumos +devices: [0] +train_split: ['train'] +val_split: ['val'] + +dataset: + json_file: "{jf}" + feat_folder: "{ff}" + file_prefix: null + file_ext: ".npy" + num_classes: 1 + input_dim: {INPUT_DIM} + feat_stride: {FEAT_STRIDE_FRAMES} + num_frames: {nf} + default_fps: null + downsample_rate: 1 + trunc_thresh: 0.5 + crop_ratio: [0.9, 1.0] + max_seq_len: 2304 + force_upsampling: false + +model: + fpn_type: identity + max_buffer_len_factor: 6.0 + n_mha_win_size: 19 + n_head: 4 + embd_dim: 256 + fpn_dim: 256 + head_dim: 256 + use_abs_pe: false + +loader: + batch_size: 1 + num_workers: 2 + +test_cfg: + voting_thresh: 0.75 + pre_nms_topk: 4000 + max_seg_num: 600 + min_score: 0.001 + iou_threshold: 0.1 + duration_thresh: 0.05 + nms_method: soft + nms_sigma: 0.5 + multiclass_nms: true +""" + out_path.write_text(text, encoding="utf-8") + + +def run_actionformer_eval( + *, + python_exe: str, + yaml_path: Path, + ckpt_path: Path, + copy_pkl_to: Path, +) -> None: + af_dir = _actionformer_release_dir() + eval_py = af_dir / "eval.py" + cmd = [python_exe, str(eval_py), str(yaml_path), str(ckpt_path), "--saveonly"] + log("运行 ActionFormer eval(saveonly)…") + r = subprocess.run(cmd, cwd=str(af_dir), check=False) + if r.returncode != 0: + raise RuntimeError(f"ActionFormer eval 失败,exit={r.returncode}") + src_pkl = ckpt_path.parent / "eval_results.pkl" + if not src_pkl.is_file(): + raise 
def segments_overlap(s0: float, e0: float, s1: float, e1: float) -> bool:
    """Return True when the two time spans share more than 1e-6 seconds."""
    return (min(e0, e1) - max(s0, s1)) > 1e-6


def greedy_mutual_exclusive(
    items: list[tuple[float, float, float]],
) -> list[tuple[float, float, float]]:
    """Pick non-overlapping segments greedily by descending score.

    ``items`` holds ``(t_start, t_end, score)`` tuples. Candidates are visited
    from the highest score down; one that overlaps an already kept segment is
    dropped entirely. The kept segments are returned ordered by start time.
    """
    kept: list[tuple[float, float, float]] = []
    for start, end, score in sorted(items, key=lambda seg: -seg[2]):
        clashes = False
        for kept_start, kept_end, _ in kept:
            if segments_overlap(start, end, kept_start, kept_end):
                clashes = True
                break
        if not clashes:
            kept.append((start, end, score))
    return sorted(kept, key=lambda seg: seg[0])


_INTERVAL_EPS = 1e-6
_IOU_NMS_THRESHOLD = 0.4
_HYBRID_MIN_LEN = 1.5


def segment_iou_1d(s0: float, e0: float, s1: float, e1: float) -> float:
    """Return the 1-D IoU of two time spans (0.0 if disjoint or degenerate)."""
    shared = max(0.0, min(e0, e1) - max(s0, s1))
    hull = max(e0, e1) - min(s0, s1)
    if shared <= _INTERVAL_EPS or hull <= _INTERVAL_EPS:
        return 0.0
    return shared / hull


def _subtract_interval(
    s: float, e: float, ps: float, pe: float
) -> list[tuple[float, float]]:
    """Remove blocker ``[ps, pe]`` from ``[s, e]``; return the 0-2 leftovers."""
    if min(e, pe) - max(s, ps) <= _INTERVAL_EPS:
        # No meaningful intersection: the span survives untouched.
        return [(s, e)]
    left_piece = (s, min(e, ps)) if ps - s > _INTERVAL_EPS else None
    right_piece = (max(s, pe), e) if e - pe > _INTERVAL_EPS else None
    return [piece for piece in (left_piece, right_piece) if piece is not None]
def aggregate_top3_votes(
    pairs: list[tuple[str, float]],
) -> tuple[list[str], list[float]]:
    """Reduce per-frame ``(class name, confidence)`` pairs to a top-3 summary.

    Confidences are summed per class name. The three classes with the largest
    sums are kept (ties broken by class name, ascending) and their sums are
    divided by the total of those three, so the returned confidences add up
    to 1. Unused slots hold ``""`` and ``0.0``. An empty input, or a
    non-positive total, yields three blank slots.
    """
    blank_names = ["", "", ""]
    blank_confs = [0.0, 0.0, 0.0]
    if not pairs:
        return blank_names, blank_confs

    # Confidence-weighted tally per class name.
    tally: dict[str, float] = {}
    for cls_name, frame_conf in pairs:
        tally[cls_name] = tally.get(cls_name, 0.0) + float(frame_conf)

    podium = sorted(tally.items(), key=lambda kv: (-kv[1], kv[0]))[:3]
    denom = float(sum(points for _, points in podium))
    if denom <= 0:
        return blank_names, blank_confs

    unused = 3 - len(podium)
    names = [cls_name for cls_name, _ in podium] + [""] * unused
    confs = [float(points) / denom for _, points in podium] + [0.0] * unused
    return names, confs
def main() -> int:
    """Run the single-video pipeline end to end and write the TSV result.

    Steps: parse the CLI, validate input files, extract VideoSwin features
    (subprocess), run ActionFormer eval (subprocess), keep proposals whose
    score is above ``--af-min-score`` and make them mutually exclusive, run
    the per-segment consumable inference, then write one tab-separated row
    per segment to ``--out``.

    Returns:
        0 on success; 1 when an input file is missing, the feature file was
        not produced, or the video cannot be opened.

    Raises:
        RuntimeError: propagated when a subprocess step exits non-zero.
        FileNotFoundError: propagated when ActionFormer wrote no result file.
    """
    ap = argparse.ArgumentParser(description="ActionFormer 划段 + 耗材端到端(单视频)")
    ap.add_argument("--video", type=Path, required=True, help="输入 MP4")
    ap.add_argument("--whitelist-json", type=Path, required=True, help='{"allowed_names":["..."]}')
    ap.add_argument(
        "--excel",
        type=Path,
        default=_DEFAULT_EXCEL,
        help="商品名称→产品编码",
    )
    ap.add_argument("--out", type=Path, required=True, help="输出制表符 TXT")
    ap.add_argument(
        "--work-dir",
        type=Path,
        default=None,
        help="工作目录(默认临时目录;加 --keep-work-dir 可保留)",
    )
    ap.add_argument("--keep-work-dir", action="store_true")
    ap.add_argument(
        "--actionformer-ckpt",
        type=Path,
        default=_DEFAULT_AF_CKPT,
    )
    ap.add_argument(
        "--hand-model",
        type=Path,
        default=_DEFAULT_HAND,
    )
    ap.add_argument(
        "--goodbad-model",
        type=Path,
        default=_DEFAULT_GOODBAD,
    )
    ap.add_argument(
        "--haocai-model",
        type=Path,
        default=_DEFAULT_HAOCAI,
    )
    ap.add_argument("--good-top1-conf-threshold", type=float, default=0.9)
    ap.add_argument("--haocai-min-conf", type=float, default=0.8)
    ap.add_argument("--af-min-score", type=float, default=0.1, help="划段保留 score 下限(不含等于)")
    ap.add_argument("--det-conf", type=float, default=0.5)
    ap.add_argument("--pad-ratio", type=float, default=0.30)
    ap.add_argument("--imgsz-det", type=int, default=640)
    ap.add_argument("--imgsz-cls", type=int, default=224)
    ap.add_argument("--frame-stride", type=int, default=1)
    ap.add_argument("--seek-margin-sec", type=float, default=3.0)
    ap.add_argument("--feat-batch-size", type=int, default=1)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument(
        "--python",
        type=str,
        default=sys.executable,
        help="子进程 Python(建议 conda yolo 环境的 python)",
    )
    args = ap.parse_args()

    # ---- Input validation: fail fast before any heavy work starts. ----
    video_path = args.video.resolve()
    if not video_path.is_file():
        log(f"找不到视频: {video_path}")
        return 1
    if not args.excel.is_file():
        log(f"找不到 Excel: {args.excel}")
        return 1
    if not args.whitelist_json.is_file():
        log(f"找不到白名单 JSON: {args.whitelist_json}")
        return 1
    for p, name in (
        (args.actionformer_ckpt, "ActionFormer ckpt"),
        (args.hand_model, "hand"),
        (args.goodbad_model, "goodbad"),
        (args.haocai_model, "haocai"),
    ):
        if not Path(p).is_file():
            log(f"缺少{name}: {p}")
            return 1

    # ---- Working directory: explicit, kept temp dir, or auto-cleaned. ----
    stem = video_path.stem
    tmp_ctx: tempfile.TemporaryDirectory | None = None
    if args.work_dir is not None:
        work = Path(args.work_dir).resolve()
        work.mkdir(parents=True, exist_ok=True)
    elif args.keep_work_dir:
        work = Path(tempfile.mkdtemp(prefix="haocai_e2e_"))
        log(f"工作目录(保留): {work}")
    else:
        tmp_ctx = tempfile.TemporaryDirectory(prefix="haocai_e2e_")
        work = Path(tmp_ctx.name)

    try:
        product_map = load_product_code_map(args.excel.resolve())
        allowed_names = load_whitelist_json(args.whitelist_json.resolve())

        inp = work / "input"
        feat_dir = work / "features"
        inp.mkdir(parents=True, exist_ok=True)
        feat_dir.mkdir(parents=True, exist_ok=True)

        # The extractor scans a directory, so stage the video on its own.
        single_video = inp / video_path.name
        if single_video.resolve() != video_path.resolve():
            shutil.copy2(video_path, single_video)

        meta_path = feat_dir / "meta.json"
        run_feature_extraction(
            python_exe=args.python,
            data_root=inp,
            output_dir=feat_dir,
            meta_file=meta_path,
            device=args.device,
            batch_size=max(1, args.feat_batch_size),
        )

        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        duration, fps = duration_fps_from_meta(meta, stem)
        if stem not in meta.get("videos", {}):
            # Fallback: the extractor did not record this stem, so estimate
            # duration / fps directly from the container with OpenCV.
            log("meta 中未找到 video_id=stem,使用 OpenCV 估 duration…")
            cap0 = cv2.VideoCapture(str(video_path))
            try:
                if cap0.isOpened():
                    fps = float(cap0.get(cv2.CAP_PROP_FPS)) or fps
                    nfr = int(cap0.get(cv2.CAP_PROP_FRAME_COUNT))
                    if fps > 0 and nfr > 0:
                        duration = nfr / fps
            finally:
                cap0.release()

        npy_path = feat_dir / f"{stem}.npy"
        if not npy_path.is_file():
            log(f"特征文件不存在: {npy_path}")
            return 1

        json_path = work / "infer_single.json"
        write_infer_json(json_path, stem, duration, fps)

        yaml_path = work / "infer_single.yaml"
        write_infer_yaml(yaml_path, json_path.resolve(), feat_dir.resolve())

        pkl_dest = work / "eval_results.pkl"
        run_actionformer_eval(
            python_exe=args.python,
            yaml_path=yaml_path.resolve(),
            ckpt_path=args.actionformer_ckpt.resolve(),
            copy_pkl_to=pkl_dest,
        )

        # Strictly-greater score filter, then greedy mutual exclusion.
        raw_segs = parse_actionformer_pkl(pkl_dest, stem)
        raw_segs = [(s, e, sc) for s, e, sc in raw_segs if sc > args.af_min_score]
        segs = greedy_mutual_exclusive(raw_segs)
        log(f"ActionFormer 候选 {len(raw_segs)} -> 互斥后 {len(segs)} 段(score>{args.af_min_score})")

        log("加载 YOLO 模型…")
        det = YOLO(str(args.hand_model))
        gb = YOLO(str(args.goodbad_model))
        cls_m = YOLO(str(args.haocai_model))
        gb_names = gb.names
        cls_names = cls_m.names
        allowed_idx = allowed_indices_from_json_names(allowed_names, cls_names)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            log("无法打开视频")
            return 1

        sep = "\t"
        columns = [
            "rank",
            "start_sec",
            "end_sec",
            "product_id_top1",
            "top1_name",
            "top1_conf",
            "product_id_top2",
            "top2_name",
            "top2_conf",
            "product_id_top3",
            "top3_name",
            "top3_conf",
        ]
        lines_out = [sep.join(columns)]

        try:
            for rank, (t0, t1, af_sc) in enumerate(segs, start=1):
                log(f"段落 rank={rank} [{t0:.3f},{t1:.3f}] score={af_sc:.4f} …")
                info = process_segment_e2e(
                    cap,
                    det,
                    gb,
                    cls_m,
                    start_sec=t0,
                    end_sec=t1,
                    seek_margin_sec=args.seek_margin_sec,
                    det_conf=args.det_conf,
                    pad_ratio=args.pad_ratio,
                    imgsz_det=args.imgsz_det,
                    imgsz_cls=args.imgsz_cls,
                    frame_stride=max(1, args.frame_stride),
                    good_top1_conf_threshold=args.good_top1_conf_threshold,
                    haocai_min_conf=args.haocai_min_conf,
                    gb_names=gb_names,
                    cls_names=cls_names,
                    allowed_class_idx=allowed_idx,
                )
                if not info.get("ok"):
                    # A failed segment carries its reason in the top1_name
                    # column. Pad to the header width so every row has the
                    # same number of fields (a hand-written list used to
                    # emit one field too many).
                    failed = [
                        str(rank),
                        f"{t0:.6f}",
                        f"{t1:.6f}",
                        "",
                        str(info.get("reason", "")),
                    ]
                    failed += [""] * (len(columns) - len(failed))
                    lines_out.append(sep.join(failed))
                    continue

                n1, n2, n3 = info["top_names"]
                c1, c2, c3 = info["top_confs"]
                id1 = product_map.get(n1, "") if n1 else ""
                id2 = product_map.get(n2, "") if n2 else ""
                id3 = product_map.get(n3, "") if n3 else ""
                for nm, pid in ((n1, id1), (n2, id2), (n3, id3)):
                    if nm and not pid:
                        log(f"警告: 商品表无名称「{nm}」,产品编码置空。")

                lines_out.append(
                    sep.join(
                        [
                            str(rank),
                            f"{t0:.6f}",
                            f"{t1:.6f}",
                            id1,
                            n1,
                            f"{c1:.6f}" if n1 else "",
                            id2,
                            n2,
                            f"{c2:.6f}" if n2 else "",
                            id3,
                            n3,
                            f"{c3:.6f}" if n3 else "",
                        ]
                    )
                )
        finally:
            cap.release()

        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text("\n".join(lines_out) + "\n", encoding="utf-8")
        log(f"已写出: {args.out.resolve()}")
        # Report the directory whenever it outlives this run.
        if args.work_dir is not None or args.keep_work_dir:
            log(f"工作目录: {work}")
    finally:
        if tmp_ctx is not None:
            tmp_ctx.cleanup()

    return 0
test_gap_1_5_same_group(self) -> None: + rows = [_row(1, 10.0, 20.0), _row(2, 21.5, 30.0)] + groups = group_rows_by_gap(rows, max_gap_sec=2.0) + self.assertEqual(len(groups), 1) + self.assertEqual(len(groups[0]), 2) + + def test_gap_2_not_merged(self) -> None: + rows = [_row(1, 10.0, 20.0), _row(2, 22.0, 30.0)] + groups = group_rows_by_gap(rows, max_gap_sec=2.0) + self.assertEqual(len(groups), 2) + + def test_failed_row_breaks_group(self) -> None: + rows = [ + _row(1, 10.0, 20.0), + _fail_row(2, 20.0, 25.0), + _row(3, 25.0, 30.0), + ] + groups = group_rows_by_gap(rows, max_gap_sec=2.0) + self.assertEqual(len(groups), 3) + self.assertEqual(len(groups[0]), 1) + self.assertEqual(len(groups[1]), 1) + self.assertEqual(len(groups[2]), 1) + + +class TestMergeAllByGap(unittest.TestCase): + def test_pairs_concat_matches_aggregate(self) -> None: + pairs_a = [("敷贴", 0.9), ("血路", 0.8)] + pairs_b = [("敷贴", 0.95)] + rows = [_row(1, 34.0, 39.5), _row(2, 39.5, 44.0)] + span_to_pairs = { + span_key(34.0, 39.5): pairs_a, + span_key(39.5, 44.0): pairs_b, + } + product_map = {"敷贴": "p1", "血路": "p2"} + merged = merge_all_by_gap( + rows, + span_to_pairs, + product_map, + max_gap_sec=2.0, + ) + self.assertEqual(len(merged), 1) + expected = e2e_row_from_pairs(34.0, 44.0, pairs_a + pairs_b, product_map) + self.assertEqual(merged[0].n1, expected.n1) + self.assertEqual(merged[0].c1, expected.c1) + self.assertEqual(merged[0].n2, expected.n2) + self.assertEqual(merged[0].c2, expected.c2) + names, confs = aggregate_top3_votes(pairs_a + pairs_b) + self.assertEqual(merged[0].n1, names[0]) + self.assertAlmostEqual(float(merged[0].c1), confs[0], places=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/code/video_clip_cls/scripts/test_greedy_mutual_exclusive.py b/code/video_clip_cls/scripts/test_greedy_mutual_exclusive.py new file mode 100644 index 0000000..8c72ef4 --- /dev/null +++ b/code/video_clip_cls/scripts/test_greedy_mutual_exclusive.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 
+"""greedy_mutual_exclusive 单元测试。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +_SCRIPTS = Path(__file__).resolve().parent +if str(_SCRIPTS) not in sys.path: + sys.path.insert(0, str(_SCRIPTS)) + +from run_haocai_actionformer_consumables_e2e import ( # noqa: E402 + greedy_mutual_exclusive, + segments_overlap, +) + + +def _no_overlap(segs: list[tuple[float, float, float]]) -> bool: + for i, (s0, e0, _) in enumerate(segs): + for s1, e1, _ in segs[i + 1 :]: + if segments_overlap(s0, e0, s1, e1): + return False + return True + + +class TestGreedyMutualExclusive(unittest.TestCase): + def test_overlap_discarded_keep_highest(self) -> None: + items = [(10.0, 30.0, 0.15), (20.0, 25.0, 0.30)] + out = greedy_mutual_exclusive(items) + self.assertEqual(out, [(20.0, 25.0, 0.30)]) + + def test_fully_overlapped_low_discarded(self) -> None: + items = [(15.0, 25.0, 0.10), (10.0, 30.0, 0.50)] + out = greedy_mutual_exclusive(items) + self.assertEqual(out, [(10.0, 30.0, 0.50)]) + + def test_non_overlap_all_kept(self) -> None: + items = [(10.0, 20.0, 0.8), (30.0, 40.0, 0.7), (50.0, 60.0, 0.9)] + out = greedy_mutual_exclusive(items) + self.assertEqual(len(out), 3) + self.assertTrue(_no_overlap(out)) + + def test_sorted_by_start(self) -> None: + items = [(50.0, 60.0, 0.9), (10.0, 20.0, 0.8), (30.0, 40.0, 0.7)] + out = greedy_mutual_exclusive(items) + starts = [s for s, _, _ in out] + self.assertEqual(starts, sorted(starts)) + + +if __name__ == "__main__": + unittest.main() diff --git a/code/video_clip_cls/scripts/test_hybrid_nms_and_trimming.py b/code/video_clip_cls/scripts/test_hybrid_nms_and_trimming.py new file mode 100644 index 0000000..ebbd16e --- /dev/null +++ b/code/video_clip_cls/scripts/test_hybrid_nms_and_trimming.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +"""hybrid_nms_and_trimming 单元测试。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +_SCRIPTS = 
Path(__file__).resolve().parent +if str(_SCRIPTS) not in sys.path: + sys.path.insert(0, str(_SCRIPTS)) + +from run_haocai_actionformer_consumables_e2e import ( # noqa: E402 + greedy_mutual_exclusive, + hybrid_nms_and_trimming, + segment_iou_1d, + segments_overlap, +) + + +def _no_overlap(segs: list[tuple[float, float, float]]) -> bool: + for i, (s0, e0, _) in enumerate(segs): + for s1, e1, _ in segs[i + 1 :]: + if segments_overlap(s0, e0, s1, e1): + return False + return True + + +class TestSegmentIou1d(unittest.TestCase): + def test_duplicate_high_iou(self) -> None: + self.assertAlmostEqual(segment_iou_1d(5.0, 10.0, 6.0, 9.0), 0.6) + + def test_partial_overlap_low_iou(self) -> None: + self.assertAlmostEqual(segment_iou_1d(10.0, 30.0, 20.0, 25.0), 0.25) + + def test_no_overlap_zero(self) -> None: + self.assertEqual(segment_iou_1d(10.0, 20.0, 30.0, 40.0), 0.0) + + +class TestHybridNmsAndTrimming(unittest.TestCase): + def test_iou_nms_discard_duplicate(self) -> None: + items = [(5.0, 10.0, 0.30), (6.0, 9.0, 0.20)] + out = hybrid_nms_and_trimming(items) + self.assertEqual(len(out), 1) + self.assertAlmostEqual(out[0][0], 5.0) + self.assertAlmostEqual(out[0][1], 10.0) + self.assertAlmostEqual(out[0][2], 0.30) + + def test_partial_overlap_trim_sides(self) -> None: + items = [(10.0, 30.0, 0.15), (20.0, 25.0, 0.30)] + out = hybrid_nms_and_trimming(items, min_len=0.0) + self.assertEqual(len(out), 3) + self.assertTrue(_no_overlap(out)) + starts = sorted(s for s, _, _ in out) + self.assertEqual(starts, [10.0, 20.0, 25.0]) + + def test_trim_fragment_shorter_than_min_len(self) -> None: + items = [(10.0, 30.0, 0.15), (20.0, 25.0, 0.30)] + out = hybrid_nms_and_trimming(items, min_len=1.5) + self.assertEqual(len(out), 3) + for s, e, _ in out: + self.assertGreaterEqual(e - s, 1.5 - 1e-9) + + items_short = [(10.0, 11.0, 0.15), (20.0, 25.0, 0.30)] + out_short = hybrid_nms_and_trimming(items_short, min_len=1.5) + self.assertEqual(out_short, [(20.0, 25.0, 0.30)]) + + def 
test_non_overlap_all_kept(self) -> None: + items = [(10.0, 20.0, 0.8), (30.0, 40.0, 0.7), (50.0, 60.0, 0.9)] + out = hybrid_nms_and_trimming(items) + self.assertEqual(len(out), 3) + self.assertTrue(_no_overlap(out)) + + def test_output_sorted_no_overlap(self) -> None: + items = [ + (50.0, 60.0, 0.9), + (10.0, 30.0, 0.15), + (20.0, 25.0, 0.30), + (30.0, 40.0, 0.7), + ] + out = hybrid_nms_and_trimming(items, min_len=0.0) + starts = [s for s, _, _ in out] + self.assertEqual(starts, sorted(starts)) + self.assertTrue(_no_overlap(out)) + + def test_endpoint_chain_not_merged(self) -> None: + """端点相接链不合并,但 IoU=0 时各自保留(与纯 trimming 区别)。""" + items = [ + (45.8, 48.5, 0.20), + (48.5, 52.0, 0.19), + (52.0, 57.9, 0.20), + (57.9, 63.1, 0.27), + ] + out = hybrid_nms_and_trimming(items, min_len=0.0) + self.assertEqual(len(out), 4) + self.assertTrue(_no_overlap(out)) + + def test_vs_legacy_on_partial_overlap(self) -> None: + """相邻粘连场景:hybrid 保留切割碎片,legacy 整段丢弃低分。""" + items = [(10.0, 30.0, 0.15), (20.0, 25.0, 0.30)] + hybrid = hybrid_nms_and_trimming(items, min_len=0.0) + legacy = greedy_mutual_exclusive(items) + self.assertGreater(len(hybrid), len(legacy)) + self.assertEqual(len(legacy), 1) + + def test_infusion_subset_135353_style(self) -> None: + """135353 风格:高分段包含低分段 → IoU NMS 丢弃重复。""" + items = [ + (45.826, 63.131, 0.2735), + (45.826, 48.512, 0.1988), + (48.512, 52.060, 0.1884), + (52.060, 57.868, 0.1975), + ] + out = hybrid_nms_and_trimming(items, min_len=0.0) + self.assertEqual(len(out), 1) + self.assertAlmostEqual(out[0][0], 45.826) + self.assertAlmostEqual(out[0][1], 63.131) + + +if __name__ == "__main__": + unittest.main() diff --git a/configs/default_config.yaml b/configs/default_config.yaml new file mode 100644 index 0000000..a4b0f80 --- /dev/null +++ b/configs/default_config.yaml @@ -0,0 +1,87 @@ +# 手术室耗材主流程 — 统一配置 +# 修改路径与阈值后运行: python main.py + +io: + # 输入输出(支持绝对路径或相对 pack 根目录) + video: input/sample.mp4 + excel: input/视频中的商品信息表.xlsx + out: output/result.txt + # 
null:不从 JSON 读白名单,改为从 excel 第 1 张表 C 列(商品名称)构建 + whitelist_json: null + +weights: + actionformer: weights/actionformer_epoch_045.pth.tar + hand: weights/hand_detect.pt + goodbad: weights/goodbad_frame.pt + haocai: weights/haocai_classify.pt + tear: weights/tear_classify.pt + +runtime: + # 中间结果目录;null 表示使用系统临时目录(运行结束删除) + work_dir: null + keep_work_dir: false + # null 表示子进程使用当前 Python 解释器(sys.executable) + python: null + +device: + type: cuda + half: false + +phase1: + af_min_score: 0.1 + af_min_seg_seconds: 2 + feat_batch_size: 1 + +phase2: + seek_margin_sec: 3.0 + frame_stride: 1 + det_conf: 0.6 + pad_ratio: 0.20 + imgsz_det: 640 + merge_iou_gt: 0.0 + merge_center_dist_max_px: null + merge_center_dist_max_frac_diag: null + # 段内手部追踪:检测框 EMA 平滑系数(x1,y1,x2,y2 同步平滑) + tracking_alpha: 0.6 + # 段内手部追踪:连续漏检补帧上限;0 表示不补帧(向下兼容旧行为) + tracking_max_lost_frames: 0 + +classification: + imgsz_cls: 224 + good_top1_conf_threshold: 0.9 + good_top1_retry_threshold: 0.5 + haocai_min_conf: 0.8 + # <=0 关闭第二次耗材阈值重试 + haocai_min_conf_retry: 0.5 + empty_cache_every: 0 + +tear_merge: + merge_adjacent_tear: true + tear_merge_weights: null + tear_merge_class: tearing + tear_merge_head_sec: 3.0 + tear_merge_prob: 0.9 + tear_merge_min_frames: 6 + tear_merge_verbose: false + tear_merge_full_frame: false + +gap_merge: + enabled: true + max_gap_sec: 2.0 + +output: + # true:不输出 tear_top1_name / tear_top2_name(仅 12 列,默认) + legacy_12_col_only: true + +doctor_identity: + enabled: true + checkpoint: doctor_identity_package/doctor_info.pth + labels_csv: doctor_identity_package/labels.csv + # MediaPipe Pose 相关阈值(人体检测置信度) + pose_min_detection_confidence: 0.30 + # 医生身份结果最低置信度;低于该值时在文本中标注“低置信度” + min_identity_confidence: 0.00 + # 中间窗口与采样参数 + middle_seconds: 10.0 + sample_fps: 3.0 + pad_frac: 0.15 diff --git a/docs/segment_mutual_exclusive.md b/docs/segment_mutual_exclusive.md new file mode 100644 index 0000000..a1ca8bd --- /dev/null +++ b/docs/segment_mutual_exclusive.md @@ -0,0 +1,135 @@ +# ActionFormer 
时段贪心互斥算法说明 + +> 适用版本:`5.17` / `5.21` 主流程 Phase1 +> 代码位置:`code/video_clip_cls/scripts/run_haocai_actionformer_consumables_e2e.py` +> 调用入口:`src/actionformer_utils.py` → `ActionSegmenter.build_segments()` + +--- + +## 1. 背景与目标 + +ActionFormer 对整段手术视频会输出 **大量重叠的候选时段**。Phase1 使用 `greedy_mutual_exclusive` 做贪心互斥:按 score 降序选取,**与已选段有任何重叠的候选整段丢弃**,保证输出时段两两不重叠。 + +--- + +## 2. 在流水线中的位置 + +```mermaid +flowchart LR + A[VideoSwin 特征] --> B[ActionFormer eval] + B --> C["parse_actionformer_pkl()"] + C --> D["score > af_min_score"] + D --> E["greedy_mutual_exclusive()"] + E --> F["时长 >= af_min_seg_seconds"] + F --> G[Phase2 YOLO 段内推断] +``` + +日志示例: + +```text +ActionFormer 候选 47 -> 互斥后 8 段 -> 剔除短于 2s 后 6 段(score>0.1) +``` + +--- + +## 3. 核心算法 + +```python +def greedy_mutual_exclusive(items): + sorted_items = sorted(items, key=lambda x: -x[2]) # score 降序 + picked = [] + for s, e, sc in sorted_items: + if any(segments_overlap(s, e, ps, pe) for ps, pe, _ in picked): + continue # 有重叠则整段丢弃 + picked.append((s, e, sc)) + picked.sort(key=lambda x: x[0]) + return picked +``` + +| 步骤 | 行为 | +|------|------| +| 排序 | 按 score **降序** 处理候选 | +| 互斥 | 与已选段 `inter > 1e-6` 则 **整段丢弃** | +| 输出 | 按 start 升序排列 | + +--- + +## 4. 配置项 + +```yaml +phase1: + af_min_score: 0.1 + af_min_seg_seconds: 2 +``` + +--- + +## 5. 与 Phase2 撕膜合并的区别 + +| 阶段 | 函数 | 作用 | +|------|------|------| +| Phase1 | `greedy_mutual_exclusive` | ActionFormer 候选 **零重叠** 化 | +| Phase2 | `tear_gate_merge` | 相邻段 **合并**(撕膜门控) | + +--- + +## 6. 代码与测试索引 + +| 功能 | 文件 | +|------|------| +| 互斥主逻辑 | `run_haocai_actionformer_consumables_e2e.py` | +| 流程编排 | `src/actionformer_utils.py` | +| 单元测试 | `test_greedy_mutual_exclusive.py` | + +```bash +cd code/video_clip_cls/scripts +conda run -n yolo python test_greedy_mutual_exclusive.py -v +``` + +--- + +## 7. 
Hybrid NMS & Trimming(实验性,未接入主流程) + +> 函数:`hybrid_nms_and_trimming()` +> 状态:**已实现 + 单元测试,主流程仍使用 `greedy_mutual_exclusive`** + +### 动机 + +| 方案 | 问题 | +|------|------| +| `greedy_mutual_exclusive` | 任意重叠整段丢弃,漏掉边缘粘连的相邻动作 | +| 纯 Trimming + 端点合并 | 未区分「同动作重复预测」与「相邻动作」,段数膨胀 | + +### 三道关卡 + +```mermaid +flowchart TD + A[score 降序] --> B{IoU > 0.4?} + B -->|是| C[整段丢弃] + B -->|否| D[边界切割] + D --> E{长度 >= 1.5s?} + E -->|是| F[加入 picked] + E -->|否| G[丢弃碎片] +``` + +| 关卡 | 规则 | +|------|------| +| 第一关 IoU NMS | 与已选段 `segment_iou_1d > 0.4` → 同动作分身,整段丢弃 | +| 第二关 Trimming | 部分重叠 → `_subtract_interval` 挖掉重叠区 | +| 第三关 min_len | 碎片 `>= 1.5s` 才保留 | + +默认参数(写死在代码):`iou_threshold=0.4`,`min_len=1.5`。 + +### 测试 + +```bash +cd code/video_clip_cls/scripts +conda run -n yolo python test_hybrid_nms_and_trimming.py -v +``` + +### 后续接入(一行替换) + +```python +# src/actionformer_utils.py(尚未启用) +segs = e2e.hybrid_nms_and_trimming(raw_segs, min_len=af_min_seg_seconds) +``` diff --git a/doctor_identity_package/infer_doctor_from_video.py b/doctor_identity_package/infer_doctor_from_video.py new file mode 100644 index 0000000..22f9b3b --- /dev/null +++ b/doctor_identity_package/infer_doctor_from_video.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +"""Infer doctor identity from one MP4 video. + +Pipeline: +1) Take middle N seconds from input video. +2) Run MediaPipe Pose to detect human bbox. +3) Keep the best crop (largest bbox area). +4) Run doctor ReID checkpoint classification head. +5) Output one final doctor identity. +""" + +from __future__ import annotations + +import argparse +import csv +import sys +import urllib.request +from pathlib import Path + +import cv2 +import mediapipe as mp +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +# Allow importing local training model definition when running directly. 
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the doctor-identity script.

    Returns:
        Namespace with ``video``, ``checkpoint``, ``labels_csv``,
        ``middle_seconds``, ``sample_fps``, ``pad_frac`` and ``save_crop``.
        Only ``--video`` is required; checkpoint and labels default to files
        next to this script.
    """
    cli = argparse.ArgumentParser(
        description="Input mp4 -> middle 10s pose crop -> doctor identity",
    )
    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        ("--video", {"type": Path, "required": True, "help": "input .mp4 path"}),
        (
            "--checkpoint",
            {
                "type": Path,
                "default": THIS_DIR / "doctor_info.pth",
                "help": "doctor checkpoint path (.pth)",
            },
        ),
        (
            "--labels-csv",
            {
                "type": Path,
                "default": THIS_DIR / "labels.csv",
                "help": "person_id to doctor name mapping csv",
            },
        ),
        (
            "--middle-seconds",
            {
                "type": float,
                "default": 10.0,
                "help": "window length around video center in seconds",
            },
        ),
        (
            "--sample-fps",
            {
                "type": float,
                "default": 3.0,
                "help": "sampling fps inside the middle window",
            },
        ),
        (
            "--pad-frac",
            {"type": float, "default": 0.15, "help": "bbox padding ratio"},
        ),
        (
            "--save-crop",
            {
                "type": Path,
                "default": None,
                "help": "optional path to save best cropped person image",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def bbox_from_normalized_pose_landmarks(
    w: int,
    h: int,
    landmark_list,
) -> tuple[int, int, int, int] | None:
    """Return the pixel bounding box enclosing a set of normalized landmarks.

    Args:
        w: Image width in pixels.
        h: Image height in pixels.
        landmark_list: Iterable of objects exposing ``x`` and ``y``; the values
            are scaled by ``w`` and ``h`` respectively.

    Returns:
        ``(x1, y1, x2, y2)`` truncated to ints, or ``None`` when there are no
        landmarks. Coordinates are not clamped to the image here.
    """
    if not landmark_list:
        return None
    # Materialize once so a one-shot iterable is not exhausted after the x pass.
    points = [(float(lm.x) * w, float(lm.y) * h) for lm in landmark_list]
    if not points:
        return None
    xs, ys = zip(*points)
    return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))


def expand_bbox_with_padding(
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    image_w: int,
    image_h: int,
    pad_frac: float,
) -> tuple[int, int, int, int]:
    """Grow a box around its center by ``pad_frac`` and clamp it to the image.

    Width and height are treated as at least 1 pixel before scaling by
    ``1 + pad_frac``. If the clamped box is empty, a fallback box anchored at
    the image origin (at most 1x1) is returned instead.
    """
    scale = 1.0 + pad_frac
    half_w = max(1, x2 - x1) * scale / 2.0
    half_h = max(1, y2 - y1) * scale / 2.0
    center_x = (x1 + x2) / 2.0
    center_y = (y1 + y2) / 2.0

    left = max(0, int(round(center_x - half_w)))
    top = max(0, int(round(center_y - half_h)))
    right = min(image_w, int(round(center_x + half_w)))
    bottom = min(image_h, int(round(center_y + half_h)))

    if right <= left or bottom <= top:
        return 0, 0, min(1, image_w), min(1, image_h)
    return left, top, right, bottom


def sample_middle_timestamps(duration_sec: float, middle_seconds: float, sample_fps: float) -> list[float]:
    """Return evenly spaced timestamps inside the window centered on the video.

    Args:
        duration_sec: Total video duration in seconds.
        middle_seconds: Length of the window around the video center.
        sample_fps: Number of samples per second inside the window.

    Returns:
        Timestamps in seconds starting at the window start (clamped to 0) and
        strictly below the window end (clamped to the duration) minus 1e-6.
        Empty when any argument is not positive.
    """
    if duration_sec <= 0 or middle_seconds <= 0 or sample_fps <= 0:
        return []
    center = duration_sec / 2.0
    half = middle_seconds / 2.0
    start = max(0.0, center - half)
    # The small epsilon keeps a sample that lands on the window end excluded.
    stop = min(duration_sec, center + half) - 1e-6

    timestamps: list[float] = []
    index = 0
    t = start
    while t < stop:
        timestamps.append(t)
        index += 1
        # Derive each timestamp from its index rather than accumulating a
        # float step, so rounding error does not build up across samples.
        t = start + index / sample_fps
    return timestamps
def build_label_to_pid(pid_to_label: dict) -> dict[int, str]:
    """Invert a ``person id -> class label`` mapping.

    Args:
        pid_to_label: Mapping whose values are class labels convertible to int.

    Returns:
        ``{int(label): str(person_id)}``. Entries whose label cannot be
        converted to ``int`` are skipped; for duplicate labels the last entry
        in iteration order wins.
    """
    label_to_pid: dict[int, str] = {}
    for raw_pid, label in pid_to_label.items():
        try:
            label_index = int(label)
        except (TypeError, ValueError):
            # Non-numeric or missing label: ignore this entry.
            continue
        label_to_pid[label_index] = str(raw_pid)
    return label_to_pid


def load_name_mapping(labels_csv: Path) -> dict[str, str]:
    """Load the ``person_id -> doctor name`` mapping from the labels CSV.

    Args:
        labels_csv: CSV file with at least the ``person_id`` and ``医生姓名``
            columns (read as UTF-8, BOM tolerated).

    Returns:
        Mapping from stripped person id to stripped doctor name. The first
        name seen for an id is kept. Rows with a missing or blank id or name
        are skipped. Returns an empty dict when the file does not exist.
    """
    if not labels_csv.is_file():
        return {}

    def _cell(value: object) -> str:
        # DictReader fills absent trailing fields with None; str(None) would
        # turn that into the literal text "None" and be stored as a name.
        return "" if value is None else str(value).strip()

    mapping: dict[str, str] = {}
    with labels_csv.open("r", encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            pid = _cell(row.get("person_id"))
            name = _cell(row.get("医生姓名"))
            if pid and name:
                mapping.setdefault(pid, name)
    return mapping
def main() -> int:
    """Run the script end to end and print the predicted doctor identity.

    Steps: parse arguments, make sure the pose model file is available, pick
    the best person crop from the middle window of the video, optionally save
    that crop, classify it with the checkpoint, and print the result.

    Returns:
        0 on success, 2 when the input video does not exist, 1 when any step
        raises (the error message is printed to stderr).
    """
    args = parse_args()
    if not args.video.is_file():
        print(f"[error] video not found: {args.video}", file=sys.stderr)
        return 2

    try:
        model_path = _ensure_pose_lite_model(THIS_DIR / ".mediapipe_models")
        opts = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=str(model_path)),
            running_mode=VisionRunningMode.IMAGE,
            min_pose_detection_confidence=0.3,
        )
        landmarker = PoseLandmarker.create_from_options(opts)
        try:
            best_crop = pick_best_person_crop(
                video_path=args.video,
                landmarker=landmarker,
                middle_seconds=args.middle_seconds,
                sample_fps=args.sample_fps,
                pad_frac=args.pad_frac,
            )
        finally:
            # Always release the landmarker, even when no person is found.
            landmarker.close()

        if args.save_crop is not None:
            args.save_crop.parent.mkdir(parents=True, exist_ok=True)
            # imwrite reports failure through its return value; saving the
            # crop is optional, so warn instead of aborting the inference.
            if not cv2.imwrite(str(args.save_crop), best_crop):
                print(
                    f"[warn] failed to save crop: {args.save_crop}",
                    file=sys.stderr,
                )

        raw_pid, conf = run_inference(best_crop, args.checkpoint)
        name_map = load_name_mapping(args.labels_csv)
        doctor_name = name_map.get(str(raw_pid), "")

        if doctor_name:
            print(f"doctor={doctor_name} (id={raw_pid}, conf={conf:.4f})")
        else:
            # No name known for this id: report the raw id only.
            print(f"doctor_id={raw_pid} (conf={conf:.4f})")
        return 0
    except Exception as exc:  # noqa: BLE001
        # Top-level boundary: report any failure and map it to exit code 1.
        print(f"[error] {exc}", file=sys.stderr)
        return 1
+24502_c1_s1_00030.jpg,24502,24502,钟光喜,1,30 +24502_c1_s1_00031.jpg,24502,24502,钟光喜,1,31 +24502_c1_s1_00032.jpg,24502,24502,钟光喜,1,32 +24502_c1_s1_00033.jpg,24502,24502,钟光喜,1,33 +24502_c1_s1_00034.jpg,24502,24502,钟光喜,1,34 +24502_c1_s1_00035.jpg,24502,24502,钟光喜,1,35 +24502_c1_s1_00036.jpg,24502,24502,钟光喜,1,36 +24502_c1_s1_00037.jpg,24502,24502,钟光喜,1,37 +24502_c1_s1_00038.jpg,24502,24502,钟光喜,1,38 +24502_c1_s1_00039.jpg,24502,24502,钟光喜,1,39 +24502_c1_s1_00040.jpg,24502,24502,钟光喜,1,40 +24502_c1_s1_00041.jpg,24502,24502,钟光喜,1,41 +24502_c1_s1_00042.jpg,24502,24502,钟光喜,1,42 +24502_c1_s1_00043.jpg,24502,24502,钟光喜,1,43 +24502_c1_s1_00044.jpg,24502,24502,钟光喜,1,44 +24502_c1_s1_00045.jpg,24502,24502,钟光喜,1,45 +24502_c2_s1_00046.jpg,24502,24502,钟光喜,2,46 +24502_c2_s1_00047.jpg,24502,24502,钟光喜,2,47 +24502_c2_s1_00048.jpg,24502,24502,钟光喜,2,48 +24502_c2_s1_00049.jpg,24502,24502,钟光喜,2,49 +24502_c2_s1_00050.jpg,24502,24502,钟光喜,2,50 +24502_c2_s1_00051.jpg,24502,24502,钟光喜,2,51 +24502_c2_s1_00052.jpg,24502,24502,钟光喜,2,52 +24502_c2_s1_00053.jpg,24502,24502,钟光喜,2,53 +24502_c2_s1_00054.jpg,24502,24502,钟光喜,2,54 +24502_c2_s1_00055.jpg,24502,24502,钟光喜,2,55 +24502_c2_s1_00056.jpg,24502,24502,钟光喜,2,56 +24502_c2_s1_00057.jpg,24502,24502,钟光喜,2,57 +24502_c2_s1_00058.jpg,24502,24502,钟光喜,2,58 +24502_c2_s1_00059.jpg,24502,24502,钟光喜,2,59 +24502_c2_s1_00060.jpg,24502,24502,钟光喜,2,60 +24502_c2_s1_00061.jpg,24502,24502,钟光喜,2,61 +24502_c2_s1_00062.jpg,24502,24502,钟光喜,2,62 +24502_c2_s1_00063.jpg,24502,24502,钟光喜,2,63 +24502_c2_s1_00064.jpg,24502,24502,钟光喜,2,64 +24502_c2_s1_00065.jpg,24502,24502,钟光喜,2,65 +24502_c2_s1_00066.jpg,24502,24502,钟光喜,2,66 +24502_c2_s1_00067.jpg,24502,24502,钟光喜,2,67 +24502_c2_s1_00068.jpg,24502,24502,钟光喜,2,68 +24502_c2_s1_00069.jpg,24502,24502,钟光喜,2,69 +24502_c2_s1_00070.jpg,24502,24502,钟光喜,2,70 +24502_c2_s1_00071.jpg,24502,24502,钟光喜,2,71 +24502_c2_s1_00072.jpg,24502,24502,钟光喜,2,72 +24502_c2_s1_00073.jpg,24502,24502,钟光喜,2,73 +24502_c2_s1_00074.jpg,24502,24502,钟光喜,2,74 
+24502_c2_s1_00075.jpg,24502,24502,钟光喜,2,75 +24502_c2_s1_00076.jpg,24502,24502,钟光喜,2,76 +24502_c2_s1_00077.jpg,24502,24502,钟光喜,2,77 +24502_c2_s1_00078.jpg,24502,24502,钟光喜,2,78 +24502_c2_s1_00079.jpg,24502,24502,钟光喜,2,79 +24502_c2_s1_00080.jpg,24502,24502,钟光喜,2,80 +24502_c2_s1_00081.jpg,24502,24502,钟光喜,2,81 +24502_c2_s1_00082.jpg,24502,24502,钟光喜,2,82 +24502_c2_s1_00083.jpg,24502,24502,钟光喜,2,83 +24502_c2_s1_00084.jpg,24502,24502,钟光喜,2,84 +24502_c2_s1_00085.jpg,24502,24502,钟光喜,2,85 +24502_c2_s1_00086.jpg,24502,24502,钟光喜,2,86 +24502_c2_s1_00087.jpg,24502,24502,钟光喜,2,87 +24502_c2_s1_00088.jpg,24502,24502,钟光喜,2,88 +24502_c2_s1_00089.jpg,24502,24502,钟光喜,2,89 +24502_c2_s1_00090.jpg,24502,24502,钟光喜,2,90 +24502_c2_s1_00091.jpg,24502,24502,钟光喜,2,91 +24502_c2_s1_00092.jpg,24502,24502,钟光喜,2,92 +24502_c2_s1_00093.jpg,24502,24502,钟光喜,2,93 +24502_c2_s1_00094.jpg,24502,24502,钟光喜,2,94 +24502_c2_s1_00095.jpg,24502,24502,钟光喜,2,95 +24502_c2_s1_00096.jpg,24502,24502,钟光喜,2,96 +24502_c2_s1_00097.jpg,24502,24502,钟光喜,2,97 +24502_c2_s1_00098.jpg,24502,24502,钟光喜,2,98 +24502_c2_s1_00099.jpg,24502,24502,钟光喜,2,99 +24502_c2_s1_00100.jpg,24502,24502,钟光喜,2,100 +24502_c2_s1_00101.jpg,24502,24502,钟光喜,2,101 +24502_c2_s1_00102.jpg,24502,24502,钟光喜,2,102 +24502_c3_s1_00103.jpg,24502,24502,钟光喜,3,103 +24502_c3_s1_00104.jpg,24502,24502,钟光喜,3,104 +24502_c3_s1_00105.jpg,24502,24502,钟光喜,3,105 +24502_c3_s1_00106.jpg,24502,24502,钟光喜,3,106 +24502_c3_s1_00107.jpg,24502,24502,钟光喜,3,107 +24502_c3_s1_00108.jpg,24502,24502,钟光喜,3,108 +24502_c3_s1_00109.jpg,24502,24502,钟光喜,3,109 +24502_c3_s1_00110.jpg,24502,24502,钟光喜,3,110 +24502_c3_s1_00111.jpg,24502,24502,钟光喜,3,111 +24502_c3_s1_00112.jpg,24502,24502,钟光喜,3,112 +24502_c3_s1_00113.jpg,24502,24502,钟光喜,3,113 +24502_c3_s1_00114.jpg,24502,24502,钟光喜,3,114 +24502_c3_s1_00115.jpg,24502,24502,钟光喜,3,115 +24502_c3_s1_00116.jpg,24502,24502,钟光喜,3,116 +24502_c3_s1_00117.jpg,24502,24502,钟光喜,3,117 +24502_c3_s1_00118.jpg,24502,24502,钟光喜,3,118 +24502_c3_s1_00119.jpg,24502,24502,钟光喜,3,119 
+24502_c3_s1_00120.jpg,24502,24502,钟光喜,3,120 +24502_c3_s1_00121.jpg,24502,24502,钟光喜,3,121 +24502_c3_s1_00122.jpg,24502,24502,钟光喜,3,122 +24502_c3_s1_00123.jpg,24502,24502,钟光喜,3,123 +24502_c3_s1_00124.jpg,24502,24502,钟光喜,3,124 +24502_c3_s1_00125.jpg,24502,24502,钟光喜,3,125 +24502_c3_s1_00126.jpg,24502,24502,钟光喜,3,126 +24502_c3_s1_00127.jpg,24502,24502,钟光喜,3,127 +24502_c3_s1_00128.jpg,24502,24502,钟光喜,3,128 +24502_c3_s1_00129.jpg,24502,24502,钟光喜,3,129 +24502_c3_s1_00130.jpg,24502,24502,钟光喜,3,130 +24502_c3_s1_00131.jpg,24502,24502,钟光喜,3,131 +24502_c3_s1_00132.jpg,24502,24502,钟光喜,3,132 +24502_c3_s1_00133.jpg,24502,24502,钟光喜,3,133 +24502_c3_s1_00134.jpg,24502,24502,钟光喜,3,134 +24502_c3_s1_00135.jpg,24502,24502,钟光喜,3,135 +24502_c3_s1_00136.jpg,24502,24502,钟光喜,3,136 +24502_c3_s1_00137.jpg,24502,24502,钟光喜,3,137 +24502_c3_s1_00138.jpg,24502,24502,钟光喜,3,138 +24502_c3_s1_00139.jpg,24502,24502,钟光喜,3,139 +24502_c3_s1_00140.jpg,24502,24502,钟光喜,3,140 +24502_c3_s1_00141.jpg,24502,24502,钟光喜,3,141 +24502_c3_s1_00142.jpg,24502,24502,钟光喜,3,142 +24502_c3_s1_00143.jpg,24502,24502,钟光喜,3,143 +24502_c3_s1_00144.jpg,24502,24502,钟光喜,3,144 +24502_c3_s1_00145.jpg,24502,24502,钟光喜,3,145 +24502_c3_s1_00146.jpg,24502,24502,钟光喜,3,146 +24503_c1_s1_00147.jpg,24503,24503,付玉峰,1,147 +24503_c1_s1_00148.jpg,24503,24503,付玉峰,1,148 +24503_c1_s1_00149.jpg,24503,24503,付玉峰,1,149 +24503_c1_s1_00150.jpg,24503,24503,付玉峰,1,150 +24503_c1_s1_00151.jpg,24503,24503,付玉峰,1,151 +24503_c1_s1_00152.jpg,24503,24503,付玉峰,1,152 +24503_c1_s1_00153.jpg,24503,24503,付玉峰,1,153 +24503_c1_s1_00154.jpg,24503,24503,付玉峰,1,154 +24503_c1_s1_00155.jpg,24503,24503,付玉峰,1,155 +24503_c1_s1_00156.jpg,24503,24503,付玉峰,1,156 +24503_c1_s1_00157.jpg,24503,24503,付玉峰,1,157 +24503_c1_s1_00158.jpg,24503,24503,付玉峰,1,158 +24503_c1_s1_00159.jpg,24503,24503,付玉峰,1,159 +24503_c1_s1_00160.jpg,24503,24503,付玉峰,1,160 +24503_c1_s1_00161.jpg,24503,24503,付玉峰,1,161 +24503_c1_s1_00162.jpg,24503,24503,付玉峰,1,162 +24503_c1_s1_00163.jpg,24503,24503,付玉峰,1,163 
+24503_c1_s1_00164.jpg,24503,24503,付玉峰,1,164 +24503_c1_s1_00165.jpg,24503,24503,付玉峰,1,165 +24503_c1_s1_00166.jpg,24503,24503,付玉峰,1,166 +24503_c1_s1_00167.jpg,24503,24503,付玉峰,1,167 +24503_c1_s1_00168.jpg,24503,24503,付玉峰,1,168 +24503_c1_s1_00169.jpg,24503,24503,付玉峰,1,169 +24503_c1_s1_00170.jpg,24503,24503,付玉峰,1,170 +24503_c1_s1_00171.jpg,24503,24503,付玉峰,1,171 +24503_c1_s1_00172.jpg,24503,24503,付玉峰,1,172 +24503_c1_s1_00173.jpg,24503,24503,付玉峰,1,173 +24503_c1_s1_00174.jpg,24503,24503,付玉峰,1,174 +24503_c1_s1_00175.jpg,24503,24503,付玉峰,1,175 +24503_c1_s1_00176.jpg,24503,24503,付玉峰,1,176 +24503_c1_s1_00177.jpg,24503,24503,付玉峰,1,177 +24503_c1_s1_00178.jpg,24503,24503,付玉峰,1,178 +24503_c1_s1_00179.jpg,24503,24503,付玉峰,1,179 +24503_c1_s1_00180.jpg,24503,24503,付玉峰,1,180 +24503_c1_s1_00181.jpg,24503,24503,付玉峰,1,181 +24503_c1_s1_00182.jpg,24503,24503,付玉峰,1,182 +24503_c1_s1_00183.jpg,24503,24503,付玉峰,1,183 +24503_c1_s1_00184.jpg,24503,24503,付玉峰,1,184 +24503_c1_s1_00185.jpg,24503,24503,付玉峰,1,185 +24503_c1_s1_00186.jpg,24503,24503,付玉峰,1,186 +24503_c1_s1_00187.jpg,24503,24503,付玉峰,1,187 +24503_c2_s1_00188.jpg,24503,24503,付玉峰,2,188 +24503_c2_s1_00189.jpg,24503,24503,付玉峰,2,189 +24503_c2_s1_00190.jpg,24503,24503,付玉峰,2,190 +24503_c2_s1_00191.jpg,24503,24503,付玉峰,2,191 +24503_c2_s1_00192.jpg,24503,24503,付玉峰,2,192 +24503_c2_s1_00193.jpg,24503,24503,付玉峰,2,193 +24503_c2_s1_00194.jpg,24503,24503,付玉峰,2,194 +24503_c2_s1_00195.jpg,24503,24503,付玉峰,2,195 +24503_c2_s1_00196.jpg,24503,24503,付玉峰,2,196 +24503_c2_s1_00197.jpg,24503,24503,付玉峰,2,197 +24503_c2_s1_00198.jpg,24503,24503,付玉峰,2,198 +24503_c2_s1_00199.jpg,24503,24503,付玉峰,2,199 +24503_c2_s1_00200.jpg,24503,24503,付玉峰,2,200 +24503_c2_s1_00201.jpg,24503,24503,付玉峰,2,201 +24503_c2_s1_00202.jpg,24503,24503,付玉峰,2,202 +24503_c2_s1_00203.jpg,24503,24503,付玉峰,2,203 +24503_c2_s1_00204.jpg,24503,24503,付玉峰,2,204 +24503_c2_s1_00205.jpg,24503,24503,付玉峰,2,205 +24503_c2_s1_00206.jpg,24503,24503,付玉峰,2,206 +24503_c2_s1_00207.jpg,24503,24503,付玉峰,2,207 
+24503_c2_s1_00208.jpg,24503,24503,付玉峰,2,208 +24503_c2_s1_00209.jpg,24503,24503,付玉峰,2,209 +24503_c2_s1_00210.jpg,24503,24503,付玉峰,2,210 +24503_c2_s1_00211.jpg,24503,24503,付玉峰,2,211 +24503_c2_s1_00212.jpg,24503,24503,付玉峰,2,212 +24503_c2_s1_00213.jpg,24503,24503,付玉峰,2,213 +24503_c2_s1_00214.jpg,24503,24503,付玉峰,2,214 +24503_c2_s1_00215.jpg,24503,24503,付玉峰,2,215 +24503_c2_s1_00216.jpg,24503,24503,付玉峰,2,216 +24503_c2_s1_00217.jpg,24503,24503,付玉峰,2,217 +24503_c2_s1_00218.jpg,24503,24503,付玉峰,2,218 +24503_c2_s1_00219.jpg,24503,24503,付玉峰,2,219 +24503_c2_s1_00220.jpg,24503,24503,付玉峰,2,220 +24503_c2_s1_00221.jpg,24503,24503,付玉峰,2,221 +24503_c2_s1_00222.jpg,24503,24503,付玉峰,2,222 +24503_c2_s1_00223.jpg,24503,24503,付玉峰,2,223 +24503_c2_s1_00224.jpg,24503,24503,付玉峰,2,224 +24503_c2_s1_00225.jpg,24503,24503,付玉峰,2,225 +24503_c2_s1_00226.jpg,24503,24503,付玉峰,2,226 +24503_c2_s1_00227.jpg,24503,24503,付玉峰,2,227 +24503_c2_s1_00228.jpg,24503,24503,付玉峰,2,228 +24503_c2_s1_00229.jpg,24503,24503,付玉峰,2,229 +24503_c2_s1_00230.jpg,24503,24503,付玉峰,2,230 +24503_c2_s1_00231.jpg,24503,24503,付玉峰,2,231 +24503_c2_s1_00232.jpg,24503,24503,付玉峰,2,232 +24503_c2_s1_00233.jpg,24503,24503,付玉峰,2,233 +24503_c2_s1_00234.jpg,24503,24503,付玉峰,2,234 +24503_c2_s1_00235.jpg,24503,24503,付玉峰,2,235 +24503_c2_s1_00236.jpg,24503,24503,付玉峰,2,236 +24503_c2_s1_00237.jpg,24503,24503,付玉峰,2,237 +24503_c2_s1_00238.jpg,24503,24503,付玉峰,2,238 +24503_c2_s1_00239.jpg,24503,24503,付玉峰,2,239 +24503_c2_s1_00240.jpg,24503,24503,付玉峰,2,240 +24503_c2_s1_00241.jpg,24503,24503,付玉峰,2,241 +24503_c2_s1_00242.jpg,24503,24503,付玉峰,2,242 +24503_c2_s1_00243.jpg,24503,24503,付玉峰,2,243 +24503_c2_s1_00244.jpg,24503,24503,付玉峰,2,244 +24503_c3_s1_00245.jpg,24503,24503,付玉峰,3,245 +24503_c3_s1_00246.jpg,24503,24503,付玉峰,3,246 +24503_c3_s1_00247.jpg,24503,24503,付玉峰,3,247 +24503_c3_s1_00248.jpg,24503,24503,付玉峰,3,248 +24503_c3_s1_00249.jpg,24503,24503,付玉峰,3,249 +24503_c3_s1_00250.jpg,24503,24503,付玉峰,3,250 +24503_c3_s1_00251.jpg,24503,24503,付玉峰,3,251 
+24503_c3_s1_00252.jpg,24503,24503,付玉峰,3,252 +24503_c3_s1_00253.jpg,24503,24503,付玉峰,3,253 +24503_c3_s1_00254.jpg,24503,24503,付玉峰,3,254 +24503_c3_s1_00255.jpg,24503,24503,付玉峰,3,255 +24503_c3_s1_00256.jpg,24503,24503,付玉峰,3,256 +24503_c3_s1_00257.jpg,24503,24503,付玉峰,3,257 +24503_c3_s1_00258.jpg,24503,24503,付玉峰,3,258 +24503_c3_s1_00259.jpg,24503,24503,付玉峰,3,259 +24503_c3_s1_00260.jpg,24503,24503,付玉峰,3,260 +24503_c3_s1_00261.jpg,24503,24503,付玉峰,3,261 +24503_c3_s1_00262.jpg,24503,24503,付玉峰,3,262 +24503_c3_s1_00263.jpg,24503,24503,付玉峰,3,263 +24503_c3_s1_00264.jpg,24503,24503,付玉峰,3,264 +24503_c3_s1_00265.jpg,24503,24503,付玉峰,3,265 +24503_c3_s1_00266.jpg,24503,24503,付玉峰,3,266 +24503_c3_s1_00267.jpg,24503,24503,付玉峰,3,267 +24503_c3_s1_00268.jpg,24503,24503,付玉峰,3,268 +24503_c3_s1_00269.jpg,24503,24503,付玉峰,3,269 +24503_c3_s1_00270.jpg,24503,24503,付玉峰,3,270 +24503_c3_s1_00271.jpg,24503,24503,付玉峰,3,271 +24503_c3_s1_00272.jpg,24503,24503,付玉峰,3,272 +24503_c3_s1_00273.jpg,24503,24503,付玉峰,3,273 +24503_c3_s1_00274.jpg,24503,24503,付玉峰,3,274 +24503_c3_s1_00275.jpg,24503,24503,付玉峰,3,275 +24503_c3_s1_00276.jpg,24503,24503,付玉峰,3,276 +24503_c3_s1_00277.jpg,24503,24503,付玉峰,3,277 +24503_c3_s1_00278.jpg,24503,24503,付玉峰,3,278 +24503_c3_s1_00279.jpg,24503,24503,付玉峰,3,279 +24503_c3_s1_00280.jpg,24503,24503,付玉峰,3,280 +24503_c3_s1_00281.jpg,24503,24503,付玉峰,3,281 +24503_c3_s1_00282.jpg,24503,24503,付玉峰,3,282 +24503_c3_s1_00283.jpg,24503,24503,付玉峰,3,283 +24503_c3_s1_00284.jpg,24503,24503,付玉峰,3,284 +24503_c3_s1_00285.jpg,24503,24503,付玉峰,3,285 +24503_c3_s1_00286.jpg,24503,24503,付玉峰,3,286 +24503_c3_s1_00287.jpg,24503,24503,付玉峰,3,287 +24503_c3_s1_00288.jpg,24503,24503,付玉峰,3,288 +24503_c3_s1_00289.jpg,24503,24503,付玉峰,3,289 +24503_c3_s1_00290.jpg,24503,24503,付玉峰,3,290 +24503_c3_s1_00291.jpg,24503,24503,付玉峰,3,291 +24503_c3_s1_00292.jpg,24503,24503,付玉峰,3,292 +24503_c3_s1_00293.jpg,24503,24503,付玉峰,3,293 +24503_c3_s1_00294.jpg,24503,24503,付玉峰,3,294 +24503_c3_s1_00295.jpg,24503,24503,付玉峰,3,295 
+24503_c3_s1_00296.jpg,24503,24503,付玉峰,3,296 +24503_c3_s1_00297.jpg,24503,24503,付玉峰,3,297 +24503_c3_s1_00298.jpg,24503,24503,付玉峰,3,298 +24504_c1_s1_00299.jpg,24504,24504,李树华,1,299 +24504_c1_s1_00300.jpg,24504,24504,李树华,1,300 +24504_c1_s1_00301.jpg,24504,24504,李树华,1,301 +24504_c1_s1_00302.jpg,24504,24504,李树华,1,302 +24504_c1_s1_00303.jpg,24504,24504,李树华,1,303 +24504_c1_s1_00304.jpg,24504,24504,李树华,1,304 +24504_c1_s1_00305.jpg,24504,24504,李树华,1,305 +24504_c1_s1_00306.jpg,24504,24504,李树华,1,306 +24504_c1_s1_00307.jpg,24504,24504,李树华,1,307 +24504_c1_s1_00308.jpg,24504,24504,李树华,1,308 +24504_c1_s1_00309.jpg,24504,24504,李树华,1,309 +24504_c1_s1_00310.jpg,24504,24504,李树华,1,310 +24504_c1_s1_00311.jpg,24504,24504,李树华,1,311 +24504_c1_s1_00312.jpg,24504,24504,李树华,1,312 +24504_c1_s1_00313.jpg,24504,24504,李树华,1,313 +24504_c1_s1_00314.jpg,24504,24504,李树华,1,314 +24504_c1_s1_00315.jpg,24504,24504,李树华,1,315 +24504_c1_s1_00316.jpg,24504,24504,李树华,1,316 +24504_c1_s1_00317.jpg,24504,24504,李树华,1,317 +24504_c1_s1_00318.jpg,24504,24504,李树华,1,318 +24504_c1_s1_00319.jpg,24504,24504,李树华,1,319 +24504_c1_s1_00320.jpg,24504,24504,李树华,1,320 +24504_c1_s1_00321.jpg,24504,24504,李树华,1,321 +24504_c1_s1_00322.jpg,24504,24504,李树华,1,322 +24504_c1_s1_00323.jpg,24504,24504,李树华,1,323 +24504_c1_s1_00324.jpg,24504,24504,李树华,1,324 +24504_c1_s1_00325.jpg,24504,24504,李树华,1,325 +24504_c1_s1_00326.jpg,24504,24504,李树华,1,326 +24504_c1_s1_00327.jpg,24504,24504,李树华,1,327 +24504_c1_s1_00328.jpg,24504,24504,李树华,1,328 +24504_c1_s1_00329.jpg,24504,24504,李树华,1,329 +24504_c1_s1_00330.jpg,24504,24504,李树华,1,330 +24504_c1_s1_00331.jpg,24504,24504,李树华,1,331 +24504_c1_s1_00332.jpg,24504,24504,李树华,1,332 +24504_c1_s1_00333.jpg,24504,24504,李树华,1,333 +24504_c1_s1_00334.jpg,24504,24504,李树华,1,334 +24504_c1_s1_00335.jpg,24504,24504,李树华,1,335 +24504_c1_s1_00336.jpg,24504,24504,李树华,1,336 +24504_c1_s1_00337.jpg,24504,24504,李树华,1,337 +24504_c1_s1_00338.jpg,24504,24504,李树华,1,338 +24504_c1_s1_00339.jpg,24504,24504,李树华,1,339 
+24504_c1_s1_00340.jpg,24504,24504,李树华,1,340 +24504_c2_s1_00341.jpg,24504,24504,李树华,2,341 +24504_c2_s1_00342.jpg,24504,24504,李树华,2,342 +24504_c2_s1_00343.jpg,24504,24504,李树华,2,343 +24504_c2_s1_00344.jpg,24504,24504,李树华,2,344 +24504_c2_s1_00345.jpg,24504,24504,李树华,2,345 +24504_c2_s1_00346.jpg,24504,24504,李树华,2,346 +24504_c2_s1_00347.jpg,24504,24504,李树华,2,347 +24504_c2_s1_00348.jpg,24504,24504,李树华,2,348 +24504_c2_s1_00349.jpg,24504,24504,李树华,2,349 +24504_c2_s1_00350.jpg,24504,24504,李树华,2,350 +24504_c2_s1_00351.jpg,24504,24504,李树华,2,351 +24504_c2_s1_00352.jpg,24504,24504,李树华,2,352 +24504_c2_s1_00353.jpg,24504,24504,李树华,2,353 +24504_c2_s1_00354.jpg,24504,24504,李树华,2,354 +24504_c2_s1_00355.jpg,24504,24504,李树华,2,355 +24504_c2_s1_00356.jpg,24504,24504,李树华,2,356 +24504_c2_s1_00357.jpg,24504,24504,李树华,2,357 +24504_c2_s1_00358.jpg,24504,24504,李树华,2,358 +24504_c2_s1_00359.jpg,24504,24504,李树华,2,359 +24504_c2_s1_00360.jpg,24504,24504,李树华,2,360 +24504_c2_s1_00361.jpg,24504,24504,李树华,2,361 +24504_c2_s1_00362.jpg,24504,24504,李树华,2,362 +24504_c2_s1_00363.jpg,24504,24504,李树华,2,363 +24504_c2_s1_00364.jpg,24504,24504,李树华,2,364 +24504_c2_s1_00365.jpg,24504,24504,李树华,2,365 +24504_c2_s1_00366.jpg,24504,24504,李树华,2,366 +24504_c2_s1_00367.jpg,24504,24504,李树华,2,367 +24504_c2_s1_00368.jpg,24504,24504,李树华,2,368 +24504_c2_s1_00369.jpg,24504,24504,李树华,2,369 +24504_c2_s1_00370.jpg,24504,24504,李树华,2,370 +24504_c2_s1_00371.jpg,24504,24504,李树华,2,371 +24504_c2_s1_00372.jpg,24504,24504,李树华,2,372 +24504_c2_s1_00373.jpg,24504,24504,李树华,2,373 +24504_c2_s1_00374.jpg,24504,24504,李树华,2,374 +24504_c2_s1_00375.jpg,24504,24504,李树华,2,375 +24504_c2_s1_00376.jpg,24504,24504,李树华,2,376 +24504_c2_s1_00377.jpg,24504,24504,李树华,2,377 +24504_c2_s1_00378.jpg,24504,24504,李树华,2,378 +24504_c2_s1_00379.jpg,24504,24504,李树华,2,379 +24504_c2_s1_00380.jpg,24504,24504,李树华,2,380 +24504_c2_s1_00381.jpg,24504,24504,李树华,2,381 +24504_c2_s1_00382.jpg,24504,24504,李树华,2,382 +24504_c3_s1_00383.jpg,24504,24504,李树华,3,383 
+24504_c3_s1_00384.jpg,24504,24504,李树华,3,384 +24504_c3_s1_00385.jpg,24504,24504,李树华,3,385 +24504_c3_s1_00386.jpg,24504,24504,李树华,3,386 +24504_c3_s1_00387.jpg,24504,24504,李树华,3,387 +24504_c3_s1_00388.jpg,24504,24504,李树华,3,388 +24504_c3_s1_00389.jpg,24504,24504,李树华,3,389 +24504_c3_s1_00390.jpg,24504,24504,李树华,3,390 +24504_c3_s1_00391.jpg,24504,24504,李树华,3,391 +24504_c3_s1_00392.jpg,24504,24504,李树华,3,392 +24504_c3_s1_00393.jpg,24504,24504,李树华,3,393 +24504_c3_s1_00394.jpg,24504,24504,李树华,3,394 +24504_c3_s1_00395.jpg,24504,24504,李树华,3,395 +24504_c3_s1_00396.jpg,24504,24504,李树华,3,396 +24504_c3_s1_00397.jpg,24504,24504,李树华,3,397 +24504_c3_s1_00398.jpg,24504,24504,李树华,3,398 +24504_c3_s1_00399.jpg,24504,24504,李树华,3,399 +24504_c3_s1_00400.jpg,24504,24504,李树华,3,400 +24504_c3_s1_00401.jpg,24504,24504,李树华,3,401 +24504_c3_s1_00402.jpg,24504,24504,李树华,3,402 +24504_c3_s1_00403.jpg,24504,24504,李树华,3,403 +24504_c3_s1_00404.jpg,24504,24504,李树华,3,404 +24504_c3_s1_00405.jpg,24504,24504,李树华,3,405 +24504_c3_s1_00406.jpg,24504,24504,李树华,3,406 +24504_c3_s1_00407.jpg,24504,24504,李树华,3,407 +24504_c3_s1_00408.jpg,24504,24504,李树华,3,408 +24504_c3_s1_00409.jpg,24504,24504,李树华,3,409 +24504_c3_s1_00410.jpg,24504,24504,李树华,3,410 +24504_c3_s1_00411.jpg,24504,24504,李树华,3,411 +24504_c3_s1_00412.jpg,24504,24504,李树华,3,412 +24504_c3_s1_00413.jpg,24504,24504,李树华,3,413 +24504_c3_s1_00414.jpg,24504,24504,李树华,3,414 +24504_c3_s1_00415.jpg,24504,24504,李树华,3,415 +24504_c3_s1_00416.jpg,24504,24504,李树华,3,416 +24504_c3_s1_00417.jpg,24504,24504,李树华,3,417 +24504_c3_s1_00418.jpg,24504,24504,李树华,3,418 +24504_c3_s1_00419.jpg,24504,24504,李树华,3,419 +24504_c3_s1_00420.jpg,24504,24504,李树华,3,420 +24505_c1_s1_00421.jpg,24505,24505,刘杰,1,421 +24505_c1_s1_00422.jpg,24505,24505,刘杰,1,422 +24505_c1_s1_00423.jpg,24505,24505,刘杰,1,423 +24505_c1_s1_00424.jpg,24505,24505,刘杰,1,424 +24505_c1_s1_00425.jpg,24505,24505,刘杰,1,425 +24505_c1_s1_00426.jpg,24505,24505,刘杰,1,426 +24505_c1_s1_00427.jpg,24505,24505,刘杰,1,427 
+24505_c1_s1_00428.jpg,24505,24505,刘杰,1,428 +24505_c1_s1_00429.jpg,24505,24505,刘杰,1,429 +24505_c1_s1_00430.jpg,24505,24505,刘杰,1,430 +24505_c1_s1_00431.jpg,24505,24505,刘杰,1,431 +24505_c1_s1_00432.jpg,24505,24505,刘杰,1,432 +24505_c1_s1_00433.jpg,24505,24505,刘杰,1,433 +24505_c1_s1_00434.jpg,24505,24505,刘杰,1,434 +24505_c1_s1_00435.jpg,24505,24505,刘杰,1,435 +24505_c1_s1_00436.jpg,24505,24505,刘杰,1,436 +24505_c1_s1_00437.jpg,24505,24505,刘杰,1,437 +24505_c1_s1_00438.jpg,24505,24505,刘杰,1,438 +24505_c1_s1_00439.jpg,24505,24505,刘杰,1,439 +24505_c1_s1_00440.jpg,24505,24505,刘杰,1,440 +24505_c1_s1_00441.jpg,24505,24505,刘杰,1,441 +24505_c1_s1_00442.jpg,24505,24505,刘杰,1,442 +24505_c1_s1_00443.jpg,24505,24505,刘杰,1,443 +24505_c1_s1_00444.jpg,24505,24505,刘杰,1,444 +24505_c1_s1_00445.jpg,24505,24505,刘杰,1,445 +24505_c1_s1_00446.jpg,24505,24505,刘杰,1,446 +24505_c1_s1_00447.jpg,24505,24505,刘杰,1,447 +24505_c1_s1_00448.jpg,24505,24505,刘杰,1,448 +24505_c1_s1_00449.jpg,24505,24505,刘杰,1,449 +24505_c1_s1_00450.jpg,24505,24505,刘杰,1,450 +24505_c1_s1_00451.jpg,24505,24505,刘杰,1,451 +24505_c1_s1_00452.jpg,24505,24505,刘杰,1,452 +24505_c1_s1_00453.jpg,24505,24505,刘杰,1,453 +24505_c1_s1_00454.jpg,24505,24505,刘杰,1,454 +24505_c1_s1_00455.jpg,24505,24505,刘杰,1,455 +24505_c1_s1_00456.jpg,24505,24505,刘杰,1,456 +24505_c1_s1_00457.jpg,24505,24505,刘杰,1,457 +24505_c1_s1_00458.jpg,24505,24505,刘杰,1,458 +24505_c1_s1_00459.jpg,24505,24505,刘杰,1,459 +24505_c1_s1_00460.jpg,24505,24505,刘杰,1,460 +24505_c1_s1_00461.jpg,24505,24505,刘杰,1,461 +24505_c1_s1_00462.jpg,24505,24505,刘杰,1,462 +24505_c1_s1_00463.jpg,24505,24505,刘杰,1,463 +24505_c1_s1_00464.jpg,24505,24505,刘杰,1,464 +24505_c1_s1_00465.jpg,24505,24505,刘杰,1,465 +24505_c1_s1_00466.jpg,24505,24505,刘杰,1,466 +24505_c1_s1_00467.jpg,24505,24505,刘杰,1,467 +24505_c1_s1_00468.jpg,24505,24505,刘杰,1,468 +24505_c1_s1_00469.jpg,24505,24505,刘杰,1,469 +24505_c1_s1_00470.jpg,24505,24505,刘杰,1,470 +24505_c1_s1_00471.jpg,24505,24505,刘杰,1,471 +24505_c1_s1_00472.jpg,24505,24505,刘杰,1,472 
+24505_c1_s1_00473.jpg,24505,24505,刘杰,1,473 +24505_c1_s1_00474.jpg,24505,24505,刘杰,1,474 +24505_c1_s1_00475.jpg,24505,24505,刘杰,1,475 +24505_c1_s1_00476.jpg,24505,24505,刘杰,1,476 +24505_c1_s1_00477.jpg,24505,24505,刘杰,1,477 +24505_c1_s1_00478.jpg,24505,24505,刘杰,1,478 +24505_c1_s1_00479.jpg,24505,24505,刘杰,1,479 +24505_c1_s1_00480.jpg,24505,24505,刘杰,1,480 +24505_c1_s1_00481.jpg,24505,24505,刘杰,1,481 +24505_c1_s1_00482.jpg,24505,24505,刘杰,1,482 +24505_c1_s1_00483.jpg,24505,24505,刘杰,1,483 +24505_c2_s1_00484.jpg,24505,24505,刘杰,2,484 +24505_c2_s1_00485.jpg,24505,24505,刘杰,2,485 +24505_c2_s1_00486.jpg,24505,24505,刘杰,2,486 +24505_c2_s1_00487.jpg,24505,24505,刘杰,2,487 +24505_c2_s1_00488.jpg,24505,24505,刘杰,2,488 +24505_c2_s1_00489.jpg,24505,24505,刘杰,2,489 +24505_c2_s1_00490.jpg,24505,24505,刘杰,2,490 +24505_c2_s1_00491.jpg,24505,24505,刘杰,2,491 +24505_c2_s1_00492.jpg,24505,24505,刘杰,2,492 +24505_c2_s1_00493.jpg,24505,24505,刘杰,2,493 +24505_c2_s1_00494.jpg,24505,24505,刘杰,2,494 +24505_c2_s1_00495.jpg,24505,24505,刘杰,2,495 +24505_c2_s1_00496.jpg,24505,24505,刘杰,2,496 +24505_c2_s1_00497.jpg,24505,24505,刘杰,2,497 +24505_c2_s1_00498.jpg,24505,24505,刘杰,2,498 +24505_c2_s1_00499.jpg,24505,24505,刘杰,2,499 +24505_c2_s1_00500.jpg,24505,24505,刘杰,2,500 +24505_c2_s1_00501.jpg,24505,24505,刘杰,2,501 +24505_c2_s1_00502.jpg,24505,24505,刘杰,2,502 +24505_c2_s1_00503.jpg,24505,24505,刘杰,2,503 +24505_c2_s1_00504.jpg,24505,24505,刘杰,2,504 +24505_c2_s1_00505.jpg,24505,24505,刘杰,2,505 +24505_c2_s1_00506.jpg,24505,24505,刘杰,2,506 +24505_c2_s1_00507.jpg,24505,24505,刘杰,2,507 +24505_c2_s1_00508.jpg,24505,24505,刘杰,2,508 +24505_c2_s1_00509.jpg,24505,24505,刘杰,2,509 +24505_c2_s1_00510.jpg,24505,24505,刘杰,2,510 +24505_c2_s1_00511.jpg,24505,24505,刘杰,2,511 +24505_c2_s1_00512.jpg,24505,24505,刘杰,2,512 +24505_c2_s1_00513.jpg,24505,24505,刘杰,2,513 +24505_c2_s1_00514.jpg,24505,24505,刘杰,2,514 +24505_c2_s1_00515.jpg,24505,24505,刘杰,2,515 +24505_c2_s1_00516.jpg,24505,24505,刘杰,2,516 +24505_c2_s1_00517.jpg,24505,24505,刘杰,2,517 
+24505_c2_s1_00518.jpg,24505,24505,刘杰,2,518 +24505_c2_s1_00519.jpg,24505,24505,刘杰,2,519 +24505_c2_s1_00520.jpg,24505,24505,刘杰,2,520 +24505_c2_s1_00521.jpg,24505,24505,刘杰,2,521 +24505_c2_s1_00522.jpg,24505,24505,刘杰,2,522 +24505_c2_s1_00523.jpg,24505,24505,刘杰,2,523 +24505_c2_s1_00524.jpg,24505,24505,刘杰,2,524 +24505_c2_s1_00525.jpg,24505,24505,刘杰,2,525 +24505_c2_s1_00526.jpg,24505,24505,刘杰,2,526 +24505_c2_s1_00527.jpg,24505,24505,刘杰,2,527 +24505_c2_s1_00528.jpg,24505,24505,刘杰,2,528 +24505_c2_s1_00529.jpg,24505,24505,刘杰,2,529 +24505_c2_s1_00530.jpg,24505,24505,刘杰,2,530 +24505_c2_s1_00531.jpg,24505,24505,刘杰,2,531 +24505_c2_s1_00532.jpg,24505,24505,刘杰,2,532 +24505_c2_s1_00533.jpg,24505,24505,刘杰,2,533 +24505_c2_s1_00534.jpg,24505,24505,刘杰,2,534 +24505_c2_s1_00535.jpg,24505,24505,刘杰,2,535 +24505_c2_s1_00536.jpg,24505,24505,刘杰,2,536 +24505_c2_s1_00537.jpg,24505,24505,刘杰,2,537 +24505_c2_s1_00538.jpg,24505,24505,刘杰,2,538 +24505_c2_s1_00539.jpg,24505,24505,刘杰,2,539 +24505_c2_s1_00540.jpg,24505,24505,刘杰,2,540 +24505_c2_s1_00541.jpg,24505,24505,刘杰,2,541 +24505_c2_s1_00542.jpg,24505,24505,刘杰,2,542 +24505_c2_s1_00543.jpg,24505,24505,刘杰,2,543 +24505_c2_s1_00544.jpg,24505,24505,刘杰,2,544 +24505_c2_s1_00545.jpg,24505,24505,刘杰,2,545 +24505_c2_s1_00546.jpg,24505,24505,刘杰,2,546 +24505_c2_s1_00547.jpg,24505,24505,刘杰,2,547 +24505_c2_s1_00548.jpg,24505,24505,刘杰,2,548 +24505_c2_s1_00549.jpg,24505,24505,刘杰,2,549 +24505_c2_s1_00550.jpg,24505,24505,刘杰,2,550 +24505_c2_s1_00551.jpg,24505,24505,刘杰,2,551 +24505_c2_s1_00552.jpg,24505,24505,刘杰,2,552 +24505_c2_s1_00553.jpg,24505,24505,刘杰,2,553 +24505_c2_s1_00554.jpg,24505,24505,刘杰,2,554 +24505_c2_s1_00555.jpg,24505,24505,刘杰,2,555 +24505_c3_s1_00556.jpg,24505,24505,刘杰,3,556 +24505_c3_s1_00557.jpg,24505,24505,刘杰,3,557 +24505_c3_s1_00558.jpg,24505,24505,刘杰,3,558 +24505_c3_s1_00559.jpg,24505,24505,刘杰,3,559 +24505_c3_s1_00560.jpg,24505,24505,刘杰,3,560 +24505_c3_s1_00561.jpg,24505,24505,刘杰,3,561 +24505_c3_s1_00562.jpg,24505,24505,刘杰,3,562 
+24505_c3_s1_00563.jpg,24505,24505,刘杰,3,563 +24505_c3_s1_00564.jpg,24505,24505,刘杰,3,564 +24505_c3_s1_00565.jpg,24505,24505,刘杰,3,565 +24505_c3_s1_00566.jpg,24505,24505,刘杰,3,566 +24505_c3_s1_00567.jpg,24505,24505,刘杰,3,567 +24505_c3_s1_00568.jpg,24505,24505,刘杰,3,568 +24505_c3_s1_00569.jpg,24505,24505,刘杰,3,569 +24505_c3_s1_00570.jpg,24505,24505,刘杰,3,570 +24505_c3_s1_00571.jpg,24505,24505,刘杰,3,571 +24505_c3_s1_00572.jpg,24505,24505,刘杰,3,572 +24505_c3_s1_00573.jpg,24505,24505,刘杰,3,573 +24505_c3_s1_00574.jpg,24505,24505,刘杰,3,574 +24505_c3_s1_00575.jpg,24505,24505,刘杰,3,575 +24505_c3_s1_00576.jpg,24505,24505,刘杰,3,576 +24505_c3_s1_00577.jpg,24505,24505,刘杰,3,577 +24505_c3_s1_00578.jpg,24505,24505,刘杰,3,578 +24505_c3_s1_00579.jpg,24505,24505,刘杰,3,579 +24505_c3_s1_00580.jpg,24505,24505,刘杰,3,580 +24505_c3_s1_00581.jpg,24505,24505,刘杰,3,581 +24505_c3_s1_00582.jpg,24505,24505,刘杰,3,582 +24505_c3_s1_00583.jpg,24505,24505,刘杰,3,583 +24505_c3_s1_00584.jpg,24505,24505,刘杰,3,584 +24505_c3_s1_00585.jpg,24505,24505,刘杰,3,585 +24505_c3_s1_00586.jpg,24505,24505,刘杰,3,586 +24505_c3_s1_00587.jpg,24505,24505,刘杰,3,587 +24505_c3_s1_00588.jpg,24505,24505,刘杰,3,588 +24505_c3_s1_00589.jpg,24505,24505,刘杰,3,589 +24505_c3_s1_00590.jpg,24505,24505,刘杰,3,590 +24505_c3_s1_00591.jpg,24505,24505,刘杰,3,591 +24505_c3_s1_00592.jpg,24505,24505,刘杰,3,592 +24505_c3_s1_00593.jpg,24505,24505,刘杰,3,593 +24505_c3_s1_00594.jpg,24505,24505,刘杰,3,594 +24505_c3_s1_00595.jpg,24505,24505,刘杰,3,595 +24505_c3_s1_00596.jpg,24505,24505,刘杰,3,596 +24505_c3_s1_00597.jpg,24505,24505,刘杰,3,597 +24505_c3_s1_00598.jpg,24505,24505,刘杰,3,598 +24505_c3_s1_00599.jpg,24505,24505,刘杰,3,599 +24505_c3_s1_00600.jpg,24505,24505,刘杰,3,600 +24505_c3_s1_00601.jpg,24505,24505,刘杰,3,601 +24505_c3_s1_00602.jpg,24505,24505,刘杰,3,602 +24505_c3_s1_00603.jpg,24505,24505,刘杰,3,603 +24505_c3_s1_00604.jpg,24505,24505,刘杰,3,604 +24505_c3_s1_00605.jpg,24505,24505,刘杰,3,605 +24505_c3_s1_00606.jpg,24505,24505,刘杰,3,606 +24505_c3_s1_00607.jpg,24505,24505,刘杰,3,607 
+24505_c3_s1_00608.jpg,24505,24505,刘杰,3,608 +24505_c3_s1_00609.jpg,24505,24505,刘杰,3,609 +24505_c3_s1_00610.jpg,24505,24505,刘杰,3,610 +24505_c3_s1_00611.jpg,24505,24505,刘杰,3,611 +24505_c3_s1_00612.jpg,24505,24505,刘杰,3,612 +24505_c3_s1_00613.jpg,24505,24505,刘杰,3,613 +24505_c3_s1_00614.jpg,24505,24505,刘杰,3,614 +24505_c3_s1_00615.jpg,24505,24505,刘杰,3,615 +24505_c3_s1_00616.jpg,24505,24505,刘杰,3,616 +24505_c3_s1_00617.jpg,24505,24505,刘杰,3,617 +24505_c3_s1_00618.jpg,24505,24505,刘杰,3,618 +24505_c3_s1_00619.jpg,24505,24505,刘杰,3,619 +24505_c3_s1_00620.jpg,24505,24505,刘杰,3,620 +24505_c3_s1_00621.jpg,24505,24505,刘杰,3,621 +24505_c3_s1_00622.jpg,24505,24505,刘杰,3,622 +24506_c1_s1_00623.jpg,24506,24506,黄伟斌,1,623 +24506_c1_s1_00624.jpg,24506,24506,黄伟斌,1,624 +24506_c1_s1_00625.jpg,24506,24506,黄伟斌,1,625 +24506_c1_s1_00626.jpg,24506,24506,黄伟斌,1,626 +24506_c1_s1_00627.jpg,24506,24506,黄伟斌,1,627 +24506_c1_s1_00628.jpg,24506,24506,黄伟斌,1,628 +24506_c1_s1_00629.jpg,24506,24506,黄伟斌,1,629 +24506_c1_s1_00630.jpg,24506,24506,黄伟斌,1,630 +24506_c1_s1_00631.jpg,24506,24506,黄伟斌,1,631 +24506_c1_s1_00632.jpg,24506,24506,黄伟斌,1,632 +24506_c1_s1_00633.jpg,24506,24506,黄伟斌,1,633 +24506_c1_s1_00634.jpg,24506,24506,黄伟斌,1,634 +24506_c1_s1_00635.jpg,24506,24506,黄伟斌,1,635 +24506_c1_s1_00636.jpg,24506,24506,黄伟斌,1,636 +24506_c1_s1_00637.jpg,24506,24506,黄伟斌,1,637 +24506_c1_s1_00638.jpg,24506,24506,黄伟斌,1,638 +24506_c1_s1_00639.jpg,24506,24506,黄伟斌,1,639 +24506_c1_s1_00640.jpg,24506,24506,黄伟斌,1,640 +24506_c1_s1_00641.jpg,24506,24506,黄伟斌,1,641 +24506_c1_s1_00642.jpg,24506,24506,黄伟斌,1,642 +24506_c1_s1_00643.jpg,24506,24506,黄伟斌,1,643 +24506_c1_s1_00644.jpg,24506,24506,黄伟斌,1,644 +24506_c1_s1_00645.jpg,24506,24506,黄伟斌,1,645 +24506_c1_s1_00646.jpg,24506,24506,黄伟斌,1,646 +24506_c1_s1_00647.jpg,24506,24506,黄伟斌,1,647 +24506_c1_s1_00648.jpg,24506,24506,黄伟斌,1,648 +24506_c1_s1_00649.jpg,24506,24506,黄伟斌,1,649 +24506_c1_s1_00650.jpg,24506,24506,黄伟斌,1,650 +24506_c1_s1_00651.jpg,24506,24506,黄伟斌,1,651 
+24506_c1_s1_00652.jpg,24506,24506,黄伟斌,1,652 +24506_c1_s1_00653.jpg,24506,24506,黄伟斌,1,653 +24506_c1_s1_00654.jpg,24506,24506,黄伟斌,1,654 +24506_c1_s1_00655.jpg,24506,24506,黄伟斌,1,655 +24506_c1_s1_00656.jpg,24506,24506,黄伟斌,1,656 +24506_c1_s1_00657.jpg,24506,24506,黄伟斌,1,657 +24506_c1_s1_00658.jpg,24506,24506,黄伟斌,1,658 +24506_c2_s1_00659.jpg,24506,24506,黄伟斌,2,659 +24506_c2_s1_00660.jpg,24506,24506,黄伟斌,2,660 +24506_c2_s1_00661.jpg,24506,24506,黄伟斌,2,661 +24506_c2_s1_00662.jpg,24506,24506,黄伟斌,2,662 +24506_c2_s1_00663.jpg,24506,24506,黄伟斌,2,663 +24506_c2_s1_00664.jpg,24506,24506,黄伟斌,2,664 +24506_c2_s1_00665.jpg,24506,24506,黄伟斌,2,665 +24506_c2_s1_00666.jpg,24506,24506,黄伟斌,2,666 +24506_c2_s1_00667.jpg,24506,24506,黄伟斌,2,667 +24506_c2_s1_00668.jpg,24506,24506,黄伟斌,2,668 +24506_c2_s1_00669.jpg,24506,24506,黄伟斌,2,669 +24506_c2_s1_00670.jpg,24506,24506,黄伟斌,2,670 +24506_c2_s1_00671.jpg,24506,24506,黄伟斌,2,671 +24506_c2_s1_00672.jpg,24506,24506,黄伟斌,2,672 +24506_c2_s1_00673.jpg,24506,24506,黄伟斌,2,673 +24506_c2_s1_00674.jpg,24506,24506,黄伟斌,2,674 +24506_c2_s1_00675.jpg,24506,24506,黄伟斌,2,675 +24506_c2_s1_00676.jpg,24506,24506,黄伟斌,2,676 +24506_c2_s1_00677.jpg,24506,24506,黄伟斌,2,677 +24506_c2_s1_00678.jpg,24506,24506,黄伟斌,2,678 +24506_c2_s1_00679.jpg,24506,24506,黄伟斌,2,679 +24506_c2_s1_00680.jpg,24506,24506,黄伟斌,2,680 +24506_c2_s1_00681.jpg,24506,24506,黄伟斌,2,681 +24506_c2_s1_00682.jpg,24506,24506,黄伟斌,2,682 +24506_c2_s1_00683.jpg,24506,24506,黄伟斌,2,683 +24506_c2_s1_00684.jpg,24506,24506,黄伟斌,2,684 +24506_c2_s1_00685.jpg,24506,24506,黄伟斌,2,685 +24506_c2_s1_00686.jpg,24506,24506,黄伟斌,2,686 +24506_c2_s1_00687.jpg,24506,24506,黄伟斌,2,687 +24506_c2_s1_00688.jpg,24506,24506,黄伟斌,2,688 +24506_c2_s1_00689.jpg,24506,24506,黄伟斌,2,689 +24506_c2_s1_00690.jpg,24506,24506,黄伟斌,2,690 +24506_c2_s1_00691.jpg,24506,24506,黄伟斌,2,691 +24506_c2_s1_00692.jpg,24506,24506,黄伟斌,2,692 +24506_c2_s1_00693.jpg,24506,24506,黄伟斌,2,693 +24506_c2_s1_00694.jpg,24506,24506,黄伟斌,2,694 +24506_c2_s1_00695.jpg,24506,24506,黄伟斌,2,695 
+24506_c2_s1_00696.jpg,24506,24506,黄伟斌,2,696 +24506_c2_s1_00697.jpg,24506,24506,黄伟斌,2,697 +24506_c2_s1_00698.jpg,24506,24506,黄伟斌,2,698 +24506_c2_s1_00699.jpg,24506,24506,黄伟斌,2,699 +24506_c2_s1_00700.jpg,24506,24506,黄伟斌,2,700 +24506_c2_s1_00701.jpg,24506,24506,黄伟斌,2,701 +24506_c2_s1_00702.jpg,24506,24506,黄伟斌,2,702 +24506_c2_s1_00703.jpg,24506,24506,黄伟斌,2,703 +24506_c2_s1_00704.jpg,24506,24506,黄伟斌,2,704 +24506_c2_s1_00705.jpg,24506,24506,黄伟斌,2,705 +24506_c2_s1_00706.jpg,24506,24506,黄伟斌,2,706 +24506_c2_s1_00707.jpg,24506,24506,黄伟斌,2,707 +24506_c2_s1_00708.jpg,24506,24506,黄伟斌,2,708 +24506_c2_s1_00709.jpg,24506,24506,黄伟斌,2,709 +24506_c2_s1_00710.jpg,24506,24506,黄伟斌,2,710 +24506_c2_s1_00711.jpg,24506,24506,黄伟斌,2,711 +24506_c2_s1_00712.jpg,24506,24506,黄伟斌,2,712 +24506_c2_s1_00713.jpg,24506,24506,黄伟斌,2,713 +24506_c2_s1_00714.jpg,24506,24506,黄伟斌,2,714 +24506_c2_s1_00715.jpg,24506,24506,黄伟斌,2,715 +24506_c2_s1_00716.jpg,24506,24506,黄伟斌,2,716 +24506_c2_s1_00717.jpg,24506,24506,黄伟斌,2,717 +24506_c2_s1_00718.jpg,24506,24506,黄伟斌,2,718 +24506_c3_s1_00719.jpg,24506,24506,黄伟斌,3,719 +24506_c3_s1_00720.jpg,24506,24506,黄伟斌,3,720 +24506_c3_s1_00721.jpg,24506,24506,黄伟斌,3,721 +24506_c3_s1_00722.jpg,24506,24506,黄伟斌,3,722 +24506_c3_s1_00723.jpg,24506,24506,黄伟斌,3,723 +24506_c3_s1_00724.jpg,24506,24506,黄伟斌,3,724 +24506_c3_s1_00725.jpg,24506,24506,黄伟斌,3,725 +24506_c3_s1_00726.jpg,24506,24506,黄伟斌,3,726 +24506_c3_s1_00727.jpg,24506,24506,黄伟斌,3,727 +24506_c3_s1_00728.jpg,24506,24506,黄伟斌,3,728 +24506_c3_s1_00729.jpg,24506,24506,黄伟斌,3,729 +24506_c3_s1_00730.jpg,24506,24506,黄伟斌,3,730 +24506_c3_s1_00731.jpg,24506,24506,黄伟斌,3,731 +24506_c3_s1_00732.jpg,24506,24506,黄伟斌,3,732 +24506_c3_s1_00733.jpg,24506,24506,黄伟斌,3,733 +24506_c3_s1_00734.jpg,24506,24506,黄伟斌,3,734 +24506_c3_s1_00735.jpg,24506,24506,黄伟斌,3,735 +24506_c3_s1_00736.jpg,24506,24506,黄伟斌,3,736 +24506_c3_s1_00737.jpg,24506,24506,黄伟斌,3,737 +24506_c3_s1_00738.jpg,24506,24506,黄伟斌,3,738 +24506_c3_s1_00739.jpg,24506,24506,黄伟斌,3,739 
+24506_c3_s1_00740.jpg,24506,24506,黄伟斌,3,740 +24506_c3_s1_00741.jpg,24506,24506,黄伟斌,3,741 +24506_c3_s1_00742.jpg,24506,24506,黄伟斌,3,742 +24506_c3_s1_00743.jpg,24506,24506,黄伟斌,3,743 +24506_c3_s1_00744.jpg,24506,24506,黄伟斌,3,744 +24506_c3_s1_00745.jpg,24506,24506,黄伟斌,3,745 +24506_c3_s1_00746.jpg,24506,24506,黄伟斌,3,746 +24506_c3_s1_00747.jpg,24506,24506,黄伟斌,3,747 +24506_c3_s1_00748.jpg,24506,24506,黄伟斌,3,748 +24506_c3_s1_00749.jpg,24506,24506,黄伟斌,3,749 +24506_c3_s1_00750.jpg,24506,24506,黄伟斌,3,750 +24506_c3_s1_00751.jpg,24506,24506,黄伟斌,3,751 +24506_c3_s1_00752.jpg,24506,24506,黄伟斌,3,752 +24506_c3_s1_00753.jpg,24506,24506,黄伟斌,3,753 +24506_c3_s1_00754.jpg,24506,24506,黄伟斌,3,754 +24506_c3_s1_00755.jpg,24506,24506,黄伟斌,3,755 +24506_c3_s1_00756.jpg,24506,24506,黄伟斌,3,756 +24506_c3_s1_00757.jpg,24506,24506,黄伟斌,3,757 +24506_c3_s1_00758.jpg,24506,24506,黄伟斌,3,758 +24506_c3_s1_00759.jpg,24506,24506,黄伟斌,3,759 diff --git a/doctor_identity_package/requirements.txt b/doctor_identity_package/requirements.txt new file mode 100644 index 0000000..81cc923 --- /dev/null +++ b/doctor_identity_package/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +opencv-python +mediapipe +numpy +pillow diff --git a/doctor_identity_package/train_reid_contrastive.py b/doctor_identity_package/train_reid_contrastive.py new file mode 100644 index 0000000..3c92da4 --- /dev/null +++ b/doctor_identity_package/train_reid_contrastive.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +人体身份重识别(Person ReID)对比学习训练脚本(单机单文件)。 +- Dataset:Market-1501 风格文件名解析 pid/cam; +- PK 采样器:每个 batch 内 P 个身份 × 每张 K 样本; +- 模型:ImageNet ResNet50 骨干 + GAP + 512 维嵌入 + ID 分类头; +- 损失:Batch-Hard Triplet + ID(交叉熵)联合; +""" +from __future__ import annotations + +import argparse +import random +import re +from pathlib import Path +from typing import Iterator + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch.optim.lr_scheduler import CosineAnnealingLR 
from torch.utils.data import DataLoader, Dataset, Sampler, Subset
from torchvision import models, transforms

# ---------------------------------------------------------------------------
# File-name parsing: e.g. 24502_c1_s1_00001.jpg -> pid=24502, cam_id=1
# ---------------------------------------------------------------------------

# The named groups are required: parse_market1501_style_name() reads them via
# m.group("pid") / m.group("cam").
_NAME_RE = re.compile(
    r"^(?P<pid>\d+)_c(?P<cam>\d+)_",
    flags=re.I,
)

# Lower-cased file suffixes that are treated as training images.
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".webp", ".bmp"})


def parse_market1501_style_name(stem: str) -> tuple[int | None, int | None]:
    """Extract ``(pid, cam_id)`` from a file stem (file name without suffix).

    Returns ``(None, None)`` when the stem does not start with
    ``<digits>_c<digits>_``.
    """
    m = _NAME_RE.match(stem)
    if m is None:
        return None, None
    return int(m.group("pid")), int(m.group("cam"))


class DoctorReIDDataset(Dataset):
    """Doctor ReID dataset: parses pid/cam from file names and remaps each
    raw pid to a contiguous label in ``0..num_classes-1``.

    Files are collected recursively under ``image_root`` in sorted order, so
    two instances built on the same directory index the same file at the same
    position. Files whose names do not match the naming scheme are skipped.

    Raises:
        RuntimeError: if no usable image is found under ``image_root``.
    """

    def __init__(self, image_root: Path, augment: bool) -> None:
        self.image_root = Path(image_root).resolve()
        candidates = sorted(
            p
            for p in self.image_root.rglob("*")
            if p.is_file() and p.suffix.lower() in _IMAGE_EXTS
        )

        pid_raw_list: list[int] = []
        cam_raw_list: list[int] = []
        valid_paths: list[Path] = []
        for p in candidates:
            pid_raw, cam_raw = parse_market1501_style_name(p.stem)
            if pid_raw is None:
                continue
            valid_paths.append(p)
            pid_raw_list.append(pid_raw)
            cam_raw_list.append(int(cam_raw))

        if not valid_paths:
            raise RuntimeError(f"目录下未发现有效图像: {self.image_root}")

        unique_pids = sorted(set(pid_raw_list))
        self.pid_to_label = {pid: i for i, pid in enumerate(unique_pids)}
        self.labels: list[int] = [self.pid_to_label[r] for r in pid_raw_list]
        self.cam_ids: list[int] = cam_raw_list
        self.paths = valid_paths

        # transforms.Resize takes (height, width); 256x128 (tall and narrow)
        # is the input size commonly used for person ReID.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(
                        brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02
                    ),
                    transforms.ToTensor(),
                    normalize,
                    # RandomErasing operates on tensors, so it must come
                    # after ToTensor.
                    transforms.RandomErasing(
                        p=0.5, scale=(0.02, 0.25), ratio=(0.3, 3.3)
                    ),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((256, 128)),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Return ``(image_tensor, label, cam_id)`` for sample ``idx``."""
        # The context manager closes the underlying file handle promptly
        # instead of leaving it to garbage collection.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        x = self.transform(img)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        cam = torch.tensor(self.cam_ids[idx], dtype=torch.long)
        return x, y, cam


class RandomIdentityPKSampler(Sampler[int]):
    """PK (random identity) sampler.

    Triplet-style losses need several images of the same identity inside one
    batch. Every mini-batch therefore has a fixed structure:

    - draw ``P`` distinct identities at random;
    - for each identity draw ``K`` images without replacement (with
      replacement when that identity has fewer than ``K`` images);

    giving ``batch_size = P * K``. The sampler yields a flat stream of
    indices into ``labels``; pair it with a DataLoader whose ``batch_size``
    equals ``self.batch_size`` so consecutive indices form one PK batch.

    Args:
        labels: label of every sample, in dataset order.
        p: identities per batch; must not exceed the number of identities.
        k: images per identity.
        seed: base seed, combined with the epoch set via ``set_epoch``.
        length: batches per epoch; when ``None`` it is derived from the
            dataset size and clamped to ``[32, 200]``.

    Raises:
        RuntimeError: if ``labels`` contains no identity.
        ValueError: if ``p`` exceeds the number of identities.
    """

    def __init__(
        self,
        labels: list[int],
        p: int,
        k: int,
        *,
        seed: int = 0,
        length: int | None = None,
    ) -> None:
        super().__init__()
        self.labels_np = np.array(labels)
        self.p = int(p)
        self.k = int(k)
        self.seed = seed
        self.epoch = 0

        self.identities = sorted(np.unique(self.labels_np).tolist())
        self.num_identities = len(self.identities)

        idx_by_label: dict[int, list[int]] = {}
        for i, lbl in enumerate(self.labels_np.tolist()):
            idx_by_label.setdefault(int(lbl), []).append(i)
        self.idx_by_label = {
            lbl: np.array(ix) for lbl, ix in idx_by_label.items()
        }

        if self.num_identities == 0:
            raise RuntimeError("PKSampler: 无任何身份类别")

        if self.p > self.num_identities:
            raise ValueError(
                f"P={self.p} 大于数据集身份数 {self.num_identities}。"
                " 请减小 --batch-p。"
            )

        # Number of batches per epoch (each batch holds P*K images).
        self.labels_flat = labels
        if length is None:
            # Roughly one pass over the data, clamped to a stable range.
            self.num_batches_est = max(
                32, min(200, len(self.labels_flat) // max(1, self.p * self.k))
            )
        else:
            self.num_batches_est = int(length)

        self.batch_size = self.p * self.k

    def __len__(self) -> int:
        return self.num_batches_est * self.batch_size

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch so that each epoch draws a different index stream."""
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        rng = np.random.RandomState((self.epoch * 9973 + self.seed) & 0xFFFFFFFF)
        ids_arr = np.asarray(self.identities, dtype=np.int64)

        for _ in range(self.num_batches_est):
            # Identities are drawn without replacement; the constructor
            # guarantees P <= number of identities, the fallback only matters
            # if self.p is changed afterwards.
            chosen = rng.choice(
                ids_arr, size=self.p, replace=ids_arr.size < self.p
            )
            for pid_pick in chosen.tolist():
                pool = self.idx_by_label[int(pid_pick)]
                picked = rng.choice(
                    pool, size=self.k, replace=pool.size < self.k
                )
                yield from (int(t) for t in picked)


def batch_hard_triplet_loss(
    embeddings: torch.Tensor, labels: torch.Tensor, margin: float
) -> torch.Tensor:
    """Batch-hard triplet loss (Hermans et al. style).

    For every anchor in the batch:

    - hardest positive: the same-identity sample (other than the anchor
      itself) with the **largest** Euclidean distance;
    - hardest negative: the different-identity sample with the **smallest**
      Euclidean distance.

    The per-anchor term is ``relu(d_pos - d_neg + margin)``; the result is
    the mean over anchors that have at least one positive and one negative.
    When no anchor qualifies, a zero that is still attached to the graph of
    ``embeddings`` is returned.
    """
    dist = torch.cdist(embeddings.float(), embeddings.float(), p=2.0).clamp(min=1e-8)

    lbl = labels.long()
    same = lbl.unsqueeze(1).eq(lbl.unsqueeze(0))
    not_self = ~torch.eye(same.size(0), dtype=torch.bool, device=same.device)
    pos_mask = same & not_self
    neg_mask = ~same

    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not valid.any():
        return embeddings.sum() * 0.0

    # Masked entries are pushed to -inf / +inf so they can never win the
    # max / min; rows without a positive or a negative are dropped first.
    hardest_pos = (
        dist.masked_fill(~pos_mask, float("-inf"))[valid].max(dim=1).values
    )
    hardest_neg = (
        dist.masked_fill(~neg_mask, float("inf"))[valid].min(dim=1).values
    )
    return F.relu(hardest_pos - hardest_neg + margin).mean()


class ReIDEmbedModel(nn.Module):
    """ImageNet-pretrained ResNet50 backbone (without avgpool/fc) -> GAP ->
    Linear+BN+ReLU embedding of size ``feat_dim`` -> ``num_classes`` logits."""

    def __init__(self, num_classes: int, feat_dim: int = 512) -> None:
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Keep everything before the global pooling layer. The output is
        # [B, 2048, H/32, W/32], i.e. [B, 2048, 8, 4] for 256x128 inputs.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.bottleneck = nn.Sequential(
            nn.Linear(2048, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(embedding, logits)`` for an image batch ``x``."""
        h = self.backbone(x)
        h = self.gap(h)
        h = h.view(h.size(0), -1)
        emb = self.bottleneck(h)
        logits = self.classifier(emb)
        return emb, logits


def collate_fn_pk(batch):
    """Stack ``(image, label, cam)`` samples into three batched tensors."""
    xs, ys, cams = zip(*batch, strict=True)
    return torch.stack(xs, dim=0), torch.stack(ys, dim=0), torch.stack(cams, dim=0)


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    margin: float,
    triplet_w: float,
    id_w: float,
) -> tuple[float, float]:
    """Train for one epoch with ``triplet_w * triplet + id_w * cross_entropy``.

    The triplet loss is computed on L2-normalised embeddings, the ID loss on
    the classifier logits.

    Returns:
        ``(mean_triplet_loss, mean_id_loss)``, both averaged per sample.
    """
    model.train()
    sum_t = 0.0
    sum_id = 0.0
    n = 0
    ce = nn.CrossEntropyLoss()

    for x, y, _cam in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        emb, logits = model(x)
        emb = F.normalize(emb, p=2, dim=1)
        loss_t = batch_hard_triplet_loss(emb, y, margin=margin)
        loss_id = ce(logits, y)
        loss = triplet_w * loss_t + id_w * loss_id

        optim.zero_grad()
        loss.backward()
        optim.step()

        bs = x.size(0)
        sum_t += loss_t.detach().item() * bs
        sum_id += loss_id.detach().item() * bs
        n += bs

    return sum_t / max(n, 1), sum_id / max(n, 1)


@torch.no_grad()
def estimate_id_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of the ID classifier head over ``loader``
    (``0.0`` when the loader yields no sample)."""
    model.eval()
    correct = 0
    total = 0
    for x, y, _ in loader:
        x = x.to(device)
        y = y.to(device)
        _, logits = model(x)
        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum().item())
        total += y.numel()
    return correct / max(total, 1)


def resolve_image_root(cli: str | Path) -> Path:
    """Resolve the image directory from a user-supplied path.

    Preference order: ``<cli>/doctor_picture`` when it exists, otherwise
    ``<cli>`` itself when it directly contains at least one file with a
    supported image suffix (the same suffix set the dataset accepts).

    Raises:
        FileNotFoundError: if neither location qualifies.
    """
    p = Path(cli).resolve()
    sub = p / "doctor_picture"
    if sub.is_dir():
        return sub
    if p.is_dir() and any(
        f.is_file() and f.suffix.lower() in _IMAGE_EXTS for f in p.iterdir()
    ):
        return p
    raise FileNotFoundError(f"未找到图像目录(可传 doctor_info_detect 或 doctor_picture): {cli}")


def parse_args():
    """Parse the command-line options of the training script."""
    ap = argparse.ArgumentParser(description="Doctor Re-ID PK Triplet + ID 训练")
    ap.add_argument(
        "--data-root",
        type=Path,
        default=Path(__file__).resolve().parent,
        help="含 doctor_picture 或图片根目录的路径",
    )
    ap.add_argument("--epochs", type=int, default=50, help="epoch 建议在 40–60(默认 50)")
    ap.add_argument(
        "--batch-p",
        type=int,
        default=5,
        help="PK 采样:每 batch 采样的身份数 P(将自动不大于身份总数,如默认 5 类)",
    )
    ap.add_argument(
        "--batch-k",
        type=int,
        default=8,
        help="PK 采样:每位身份抽样张数 K;batch_size=P×K",
    )
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--triplet-margin", type=float, default=0.3)
    ap.add_argument("--triplet-weight", type=float, default=1.0)
    ap.add_argument("--id-weight", type=float, default=2.0, help="ID Loss(CrossEntropy)权重")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--save",
        type=Path,
        default=Path(__file__).resolve().parent / "doctor_reid_best.pth",
        help="最佳权重路径",
    )
    return ap.parse_args()


def main() -> None:
    """Train the ReID model and keep the checkpoint with the best val accuracy."""
    args = parse_args()
    rng = random.Random(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    img_root = resolve_image_root(args.data_root)

    # Two views of the same files: augmented for training, plain for
    # validation. Both scan the directory in sorted order, so an index refers
    # to the same image in either dataset.
    ds_train_aug = DoctorReIDDataset(img_root, augment=True)
    ds_eval = DoctorReIDDataset(img_root, augment=False)

    n = len(ds_train_aug)
    perm = list(range(n))
    rng.shuffle(perm)
    n_val = max(32, int(0.1 * n))
    val_ix = sorted(perm[:n_val])
    train_ix = perm[n_val:]

    # Sampler indices are positions inside train_ix; Subset maps them back to
    # dataset indices.
    labels_tr = [ds_train_aug.labels[i] for i in train_ix]

    num_classes = len(ds_train_aug.pid_to_label)
    p_eff = min(args.batch_p, len(set(labels_tr)))

    sampler = RandomIdentityPKSampler(
        labels_tr,
        p=p_eff,
        k=args.batch_k,
        seed=args.seed,
    )
    train_loader = DataLoader(
        Subset(ds_train_aug, train_ix),
        batch_size=sampler.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn_pk,
    )
    val_loader = DataLoader(
        Subset(ds_eval, val_ix),
        batch_size=64,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=collate_fn_pk,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ReIDEmbedModel(num_classes=num_classes, feat_dim=512).to(device)

    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optim, T_max=args.epochs, eta_min=1e-6)

    # Path objects are stored as plain strings: pickled Paths are tied to the
    # OS flavour that wrote them and are rejected by
    # torch.load(weights_only=True).
    saved_args = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }

    best_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        sampler.set_epoch(epoch)
        tr_t, tr_id = train_one_epoch(
            model,
            train_loader,
            optim,
            device,
            margin=args.triplet_margin,
            triplet_w=args.triplet_weight,
            id_w=args.id_weight,
        )
        scheduler.step()
        lr_now = optim.param_groups[0]["lr"]
        val_acc = estimate_id_accuracy(model, val_loader, device)

        print(
            f"epoch {epoch:03d}/{args.epochs} | "
            f"triplet {tr_t:.4f} | id_loss_ce {tr_id:.4f} | "
            f"lr {lr_now:.6f} | val_id_acc ~ {val_acc:.4f}"
        )

        # ">=" keeps the most recent checkpoint among equally good epochs.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "num_classes": num_classes,
                    "pid_to_label": ds_train_aug.pid_to_label,
                    "args": saved_args,
                },
                args.save,
            )
            print(f"[保存] checkpoint → {args.save} (best val_id_acc {best_acc:.4f})")

    print(f"训练结束。最佳 val_id_acc≈{best_acc:.4f}, 权重: {args.save}")
+ +if __name__ == "__main__": + main() diff --git a/input/.gitkeep b/input/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/input/.gitkeep @@ -0,0 +1 @@ + diff --git a/main.py b/main.py new file mode 100644 index 0000000..2200041 --- /dev/null +++ b/main.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +"""pack/5.11 唯一入口:从 YAML 读入路径与阈值,运行手术室耗材主流程。""" +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(PACK_ROOT / "src")) + +from paths import ensure_code_on_path + +ensure_code_on_path(PACK_ROOT) + +from config import load_run_config +from orchestrator import run_pipeline + + +def main() -> int: + os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8") + ap = argparse.ArgumentParser(description="手术室耗材主流程(YAML 配置)") + ap.add_argument( + "--config", + type=Path, + default=PACK_ROOT / "configs" / "default_config.yaml", + help="配置文件路径(默认 pack 内 configs/default_config.yaml)", + ) + args = ap.parse_args() + cfg_path = args.config.resolve() + if not cfg_path.is_file(): + print("找不到配置:", cfg_path, file=sys.stderr) + return 1 + + run_cfg = load_run_config(PACK_ROOT, cfg_path) + return int(run_pipeline(run_cfg)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/main_debug.py b/main_debug.py new file mode 100644 index 0000000..954561b --- /dev/null +++ b/main_debug.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +"""Debug 入口:Excel I 列时间段替代 ActionFormer,其余 Phase2 与 main.py 一致。""" +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(PACK_ROOT / "src")) + +from paths import ensure_code_on_path + +ensure_code_on_path(PACK_ROOT) + +from config import load_run_config +from orchestrator import run_debug_pipeline + + +def main() -> int: + os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8") + ap = 
argparse.ArgumentParser( + description="手术室耗材 Debug 主流程(Excel 时间段 → Phase2,跳过 ActionFormer)" + ) + ap.add_argument("--video", type=Path, required=True, help="输入 MP4") + ap.add_argument( + "--excel", + type=Path, + required=True, + help="商品表 Excel(C 列白名单 + I 列时间段 + 产品编码)", + ) + ap.add_argument("--out", type=Path, required=True, help="输出 TSV") + ap.add_argument( + "--config", + type=Path, + default=PACK_ROOT / "configs" / "default_config.yaml", + help="继承 weights / phase2 / tear_merge / doctor 的 YAML(忽略 io 与 phase1)", + ) + ap.add_argument( + "--time-col-index", + type=int, + default=8, + help="时间段列索引,默认 8 即 Excel I 列;视频2 可用 9(J 列)", + ) + ap.add_argument( + "--min-seg-seconds", + type=float, + default=None, + help="最短段时长(秒);默认 0 表示不过滤短段", + ) + args = ap.parse_args() + + cfg_path = args.config.resolve() + if not cfg_path.is_file(): + print("找不到配置:", cfg_path, file=sys.stderr) + return 1 + + run_cfg = load_run_config(PACK_ROOT, cfg_path) + run_cfg.video = args.video.resolve() + run_cfg.excel = args.excel.resolve() + run_cfg.out = args.out.resolve() + run_cfg.excel_time_col_index = int(args.time_col_index) + if args.min_seg_seconds is not None: + run_cfg.af_min_seg_seconds = float(args.min_seg_seconds) + else: + run_cfg.af_min_seg_seconds = 0.0 + + run_cfg.merge_adjacent_tear = False + + return int(run_debug_pipeline(run_cfg)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/output/.gitkeep b/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/output/result_remuxed_sample.txt b/output/result_remuxed_sample.txt new file mode 100644 index 0000000..32496fd --- /dev/null +++ b/output/result_remuxed_sample.txt @@ -0,0 +1,8 @@ +rank start_sec end_sec product_id_top1 top1_name top1_conf product_id_top2 top2_name top2_conf product_id_top3 top3_name top3_conf +1 7.705316 13.708928 215-93-1 一次性使用无菌敷贴 0.983056 8186-02-03 00:00:00 一次性使用精密过滤输液器 带针 0.016944 +2 23.872417 34.601734 1281-39-3 密闭式防针刺伤型静脉留置针 0.983541 8186-02-03 00:00:00 
一次性使用精密过滤输液器 带针 0.016459 +3 42.936733 56.751099 8186-02-03 00:00:00 一次性使用精密过滤输液器 带针 0.927646 215-93-1 一次性使用无菌敷贴 0.047184 1518-34-17 一次性使用胃管 0.025171 +4 62.147926 72.155182 21444-1-2 一次性使用气管插管 0.499763 14780-3-5 一次性使用无菌气管插管Tracheal Tube 0.332695 8186-02-03 00:00:00 一次性使用精密过滤输液器 带针 0.167542 +5 79.344002 90.151703 7386-61-46 一次性无菌喉罩 0.886499 14730-10-10 一次性使用血液透析管路 0.068667 8186-02-03 00:00:00 一次性使用精密过滤输液器 带针 0.044834 +6 99.754677 104.440956 14780-3-5 一次性使用无菌气管插管Tracheal Tube 0.914568 215-93-1 一次性使用无菌敷贴 0.051220 739-2-1 血液净化装置的体外循环血路 0.034212 +医生信息:付玉峰 (id=24503, conf=0.8552) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6c842c2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +# pack/5.11 运行依赖(推理主流程最小集) +# 使用前请按显卡环境先安装匹配的 PyTorch: https://pytorch.org/get-started/locally/ +# +# pip install -r requirements.txt +# pip install -e code/actionformer_release/libs/utils +# +# 第二行在 pack/5.11 目录下执行(安装 ActionFormer 用 nms 扩展) + +torch>=2.0.0 +torchvision>=0.15.0 +ultralytics>=8.0.0 +opencv-python>=4.8.0 +numpy>=1.23.0 +pandas>=2.0.0 +openpyxl>=3.1.0 +PyYAML>=6.0 +mediapipe>=0.10.0 diff --git a/scripts/compare_phase1_postprocess.py b/scripts/compare_phase1_postprocess.py new file mode 100644 index 0000000..c81306a --- /dev/null +++ b/scripts/compare_phase1_postprocess.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +"""对比 Phase1 后处理:greedy_mutual_exclusive vs hybrid_nms_and_trimming。""" +from __future__ import annotations + +import argparse +import json +import shutil +import sys +import tempfile +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PACK_ROOT / "src")) + +from paths import ensure_code_on_path # noqa: E402 + +ensure_code_on_path(PACK_ROOT) + +import cv2 # noqa: E402 +import run_haocai_actionformer_consumables_e2e as e2e # noqa: E402 +from config import load_run_config # noqa: E402 +from pack_utils import log # noqa: E402 + + +def _filter_min_len( + segs: list[tuple[float, 
float, float]], min_seg: float +) -> list[tuple[float, float, float]]: + if min_seg <= 0: + return segs + return [(s, e, sc) for s, e, sc in segs if (e - s) >= min_seg - 1e-9] + + +def _fmt_segs(segs: list[tuple[float, float, float]]) -> list[str]: + return [ + f" [{s:8.3f}, {e:8.3f}] len={e - s:5.2f}s score={sc:.4f}" + for s, e, sc in segs + ] + + +def run_phase1_raw( + *, + video_path: Path, + stem: str, + work: Path, + actionformer_ckpt: Path, + af_min_score: float, + python_exe: str, + feat_batch_size: int, + device: str, +) -> list[tuple[float, float, float]]: + inp = work / "input" + feat_dir = work / "features" + inp.mkdir(parents=True, exist_ok=True) + feat_dir.mkdir(parents=True, exist_ok=True) + for stale in inp.glob("*.mp4"): + stale.unlink(missing_ok=True) + + single_video = inp / video_path.name + if single_video.resolve() != video_path.resolve(): + shutil.copy2(video_path, single_video) + + meta_path = feat_dir / "meta.json" + e2e.run_feature_extraction( + python_exe=python_exe, + data_root=inp, + output_dir=feat_dir, + meta_file=meta_path, + device=device, + batch_size=max(1, feat_batch_size), + ) + + meta = json.loads(meta_path.read_text(encoding="utf-8")) + duration, fps = e2e.duration_fps_from_meta(meta, stem) + if stem not in meta.get("videos", {}): + cap0 = cv2.VideoCapture(str(video_path)) + if cap0.isOpened(): + fps = float(cap0.get(cv2.CAP_PROP_FPS)) or fps + nfr = int(cap0.get(cv2.CAP_PROP_FRAME_COUNT)) + cap0.release() + if fps > 0 and nfr > 0: + duration = nfr / fps + + json_path = work / "infer_single.json" + e2e.write_infer_json(json_path, stem, duration, fps) + yaml_path = work / "infer_single.yaml" + e2e.write_infer_yaml(yaml_path, json_path.resolve(), feat_dir.resolve()) + + pkl_dest = work / "eval_results.pkl" + e2e.run_actionformer_eval( + python_exe=python_exe, + yaml_path=yaml_path.resolve(), + ckpt_path=actionformer_ckpt.resolve(), + copy_pkl_to=pkl_dest, + ) + + raw = e2e.parse_actionformer_pkl(pkl_dest, stem) + return [(s, e, 
sc) for s, e, sc in raw if sc > af_min_score] + + +def main() -> int: + ap = argparse.ArgumentParser(description="Phase1 后处理对比") + ap.add_argument( + "--config", + type=Path, + default=PACK_ROOT / "configs/run_20260519_151255_tracking.yaml", + ) + ap.add_argument( + "--report", + type=Path, + default=PACK_ROOT / "output/phase1_compare_20260519_151255.txt", + ) + ap.add_argument("--work-dir", type=Path, default=None, help="保留中间文件目录") + args = ap.parse_args() + + run_cfg = load_run_config(PACK_ROOT, args.config.resolve()) + video_path = Path(run_cfg.video).resolve() + if not video_path.is_file(): + log(f"找不到视频: {video_path}") + return 1 + + stem = video_path.stem + if args.work_dir is not None: + work = args.work_dir.resolve() + work.mkdir(parents=True, exist_ok=True) + else: + work = Path(tempfile.mkdtemp(prefix="phase1_compare_")) + + log(f"视频: {video_path}") + log(f"工作目录: {work}") + + raw = run_phase1_raw( + video_path=video_path, + stem=stem, + work=work, + actionformer_ckpt=Path(run_cfg.actionformer_ckpt), + af_min_score=float(run_cfg.af_min_score), + python_exe=str(run_cfg.python), + feat_batch_size=int(run_cfg.feat_batch_size), + device=str(run_cfg.device), + ) + + min_seg = float(run_cfg.af_min_seg_seconds) + legacy = _filter_min_len(e2e.greedy_mutual_exclusive(raw), min_seg) + hybrid = e2e.hybrid_nms_and_trimming(raw, min_len=min_seg) + + lines = [ + f"video: {video_path}", + f"af_min_score: {run_cfg.af_min_score}", + f"af_min_seg_seconds: {min_seg}", + f"ActionFormer 原始候选: {len(raw)}", + "", + f"greedy_mutual_exclusive: {len(legacy)} 段", + *_fmt_segs(legacy), + "", + f"hybrid_nms_and_trimming: {len(hybrid)} 段", + *_fmt_segs(hybrid), + "", + "差异摘要:", + f" 段数差 (hybrid - legacy): {len(hybrid) - len(legacy)}", + ] + + only_legacy = [] + only_hybrid = [] + for s, e, sc in legacy: + if not any(abs(s - hs) < 0.05 and abs(e - he) < 0.05 for hs, he, _ in hybrid): + only_legacy.append((s, e, sc)) + for s, e, sc in hybrid: + if not any(abs(s - ls) < 0.05 and abs(e - 
le) < 0.05 for ls, le, _ in legacy): + only_hybrid.append((s, e, sc)) + + if only_legacy: + lines.append(" 仅 legacy 有:") + lines.extend(_fmt_segs(only_legacy)) + if only_hybrid: + lines.append(" 仅 hybrid 有:") + lines.extend(_fmt_segs(only_hybrid)) + + report_path = args.report.resolve() + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + for line in lines: + print(line) + log(f"已写出: {report_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/merge_gap_adjacent_result.py b/scripts/merge_gap_adjacent_result.py new file mode 100644 index 0000000..d03744a --- /dev/null +++ b/scripts/merge_gap_adjacent_result.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +"""对已有 TSV 结果做相邻 gap 合并:组内各段重推理取 pairs_h,拼接后 aggregate top3。""" +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +PACK_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PACK_ROOT / "src")) + +from paths import ensure_code_on_path # noqa: E402 + +ensure_code_on_path(PACK_ROOT) + +import cv2 # noqa: E402 +import run_haocai_actionformer_consumables_e2e as e2e # noqa: E402 +from config import load_run_config # noqa: E402 +from pack_utils import load_allowed_names_from_excel, log # noqa: E402 +from pipeline.gap_adjacent_merge import ( # noqa: E402 + group_rows_by_gap, + merge_all_by_gap, + span_key, +) +from pipeline.hand_roi_merge import HandMergeConfig, HandRoiGrouper # noqa: E402 +from pipeline.segment_processor import ( # noqa: E402 + FineGrainedClassifier, + process_segment_multi_hand_tear_with_gate_retries, +) +from pipeline.tear_gate_merge import parse_e2e_rows_from_body_lines # noqa: E402 +from run_segments_consumable_vote import pad_box as _pad_box # noqa: E402 +from ultralytics import YOLO # noqa: E402 + + +def _read_tsv(path: Path) -> tuple[str, list[str], str | None]: + lines = path.read_text(encoding="utf-8").splitlines() + if 
not lines: + raise ValueError(f"空文件: {path}") + header = lines[0] + body: list[str] = [] + doctor_line: str | None = None + for line in lines[1:]: + if not line.strip(): + continue + if line.startswith("医生信息"): + doctor_line = line + break + body.append(line) + return header, body, doctor_line + + +def _infer_pairs_for_row( + cap: cv2.VideoCapture, + *, + det, + fg: FineGrainedClassifier, + grouper: HandRoiGrouper, + row, + args, + cls_names, + allowed_idx, + predict_kw: dict, +) -> list[tuple[str, float]]: + h_retry = args.haocai_min_conf_retry + if h_retry is not None and h_retry <= 0: + h_retry = None + elif h_retry is not None and h_retry >= float(args.haocai_min_conf) - 1e-12: + h_retry = None + info = process_segment_multi_hand_tear_with_gate_retries( + cap, + det, + fg, + grouper, + start_sec=row.start_sec, + end_sec=row.end_sec, + seek_margin_sec=args.seek_margin_sec, + det_conf=args.det_conf, + imgsz_det=args.imgsz_det, + frame_stride=max(1, args.frame_stride), + tracking_alpha=float(args.tracking_alpha), + tracking_max_lost_frames=int(args.tracking_max_lost_frames), + good_top1_conf_threshold=float(args.good_top1_conf_threshold), + good_top1_retry_threshold=float(args.good_top1_retry_threshold), + haocai_min_conf=float(args.haocai_min_conf), + haocai_min_conf_retry=h_retry, + cls_names=cls_names, + allowed_class_idx=allowed_idx, + empty_cache_every=args.empty_cache_every, + log_fn=log, + log_prefix=f"gap_merge 重推理 rank={row.rank}: ", + ) + if not info.get("ok"): + return [] + return list(info.get("pairs_h") or []) + + +def main() -> int: + ap = argparse.ArgumentParser(description="相邻 gap 段合并(pairs 拼接 + top3 归一化)") + ap.add_argument("--tsv", type=Path, required=True, help="输入 TSV 结果") + ap.add_argument("--video", type=Path, required=True, help="对应视频") + ap.add_argument("--config", type=Path, required=True, help="运行配置 YAML") + ap.add_argument("--gap-sec", type=float, default=2.0, help="相邻段 gap 上限(严格小于)") + ap.add_argument("--out", type=Path, required=True, 
help="输出 TSV") + args_cli = ap.parse_args() + + args = load_run_config(PACK_ROOT, args_cli.config.resolve()) + header, body_lines, doctor_line = _read_tsv(args_cli.tsv.resolve()) + e2e_rows = parse_e2e_rows_from_body_lines(body_lines) + groups = group_rows_by_gap(e2e_rows, max_gap_sec=args_cli.gap_sec) + + excel_path = args.excel + if args.whitelist_json and Path(args.whitelist_json).is_file(): + allowed_names = e2e.load_allowed_names_from_json(Path(args.whitelist_json)) + else: + allowed_names = load_allowed_names_from_excel(excel_path) + product_map = e2e.load_product_code_map(excel_path) + + predict_kw: dict = {"device": args.device} + if args.half: + predict_kw["half"] = True + + merge_cfg = HandMergeConfig( + merge_iou_gt=args.merge_iou_gt, + merge_center_dist_max_px=args.merge_center_dist_max_px, + merge_center_dist_max_frac_diag=args.merge_center_dist_max_frac_diag, + ) + + log("加载 YOLO 模型…") + det = YOLO(str(args.hand_model)) + gb = YOLO(str(args.goodbad_model)) + cls_m = YOLO(str(args.haocai_model)) + tear_m = YOLO(str(args.tear_model)) + cls_names = cls_m.names + fg = FineGrainedClassifier( + gb, + cls_m, + tear_m, + gb_names=gb.names, + cls_names=cls_names, + tear_names=tear_m.names, + imgsz_cls=args.imgsz_cls, + predict_kw=predict_kw, + ) + grouper = HandRoiGrouper(merge_cfg, pad_box_fn=_pad_box, pad_ratio=args.pad_ratio) + allowed_idx = e2e.allowed_indices_from_json_names(allowed_names, cls_names) + + video_path = args_cli.video.resolve() + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + log(f"无法打开视频: {video_path}") + return 1 + + span_to_pairs: dict[tuple[float, float], list[tuple[str, float]]] = {} + try: + for grp in groups: + if len(grp) <= 1: + r = grp[0] + if r.is_success(): + pairs = _infer_pairs_for_row( + cap, + det=det, + fg=fg, + grouper=grouper, + row=r, + args=args, + cls_names=cls_names, + allowed_idx=allowed_idx, + predict_kw=predict_kw, + ) + span_to_pairs[span_key(r.start_sec, r.end_sec)] = pairs + continue + for r 
#!/usr/bin/env bash
# Transcode an HEVC main-view MP4 to H.264 so that VideoSwin feature
# extraction and OpenCV can decode it. The audio track is dropped (-an).
# Usage:
#   ./scripts/remux_hevc.sh SOURCE.mp4 [output.mp4]
# When no output path is given the result is written to
#   input/remuxed/${STEM}_h264.mp4   (STEM = source file name without extension)

set -euo pipefail

ROOT="$(cd "$(dirname "$0")/.." && pwd)"
SRC="${1:?用法: remux_hevc.sh SOURCE.mp4 [output.mp4]}"
STEM="$(basename "${SRC%.*}")"
OUT="${2:-${ROOT}/input/remuxed/${STEM}_h264.mp4}"

# Fail early with a clear message instead of a long ffmpeg error dump.
if [[ ! -f "$SRC" ]]; then
  echo "[error] 找不到源视频: ${SRC}" >&2
  exit 1
fi
if ! command -v ffmpeg >/dev/null 2>&1; then
  echo "[error] 未找到 ffmpeg, 请先安装并加入 PATH" >&2
  exit 1
fi
# -y overwrites without asking, so refuse to write on top of the input file
# (-ef compares device and inode, catching differently spelled paths).
if [[ -e "$OUT" && "$SRC" -ef "$OUT" ]]; then
  echo "[error] 输出与输入是同一个文件: ${OUT}" >&2
  exit 1
fi

mkdir -p "$(dirname "$OUT")"
echo "[remux] ${SRC} -> ${OUT}"
# -nostdin keeps ffmpeg from consuming stdin when this script runs inside a
# `while read` loop or another batch driver.
ffmpeg -nostdin -y -i "$SRC" -c:v libx264 -preset ultrafast -crf 23 -an "$OUT"
echo "[done] ${OUT}"
def load_run_config(pack_root: Path, config_path: Path) -> Namespace:
    """Load a run-configuration YAML and flatten it into an ``argparse.Namespace``.

    Relative paths in the YAML are resolved against ``pack_root``; absolute
    paths are kept (and resolved). The sections ``io``, ``weights``,
    ``runtime``, ``device``, ``phase1``, ``phase2``, ``classification``,
    ``tear_merge`` and ``output`` are required. ``gap_merge`` and
    ``doctor_identity`` are optional and may be omitted or left as ``null``.

    ``classification.haocai_min_conf_retry`` must be present but may be
    ``null``, which yields ``None`` in the returned namespace.

    Args:
        pack_root: Root directory of the package, used as base for relative paths.
        config_path: Path of the YAML file to read (UTF-8).

    Returns:
        A ``Namespace`` carrying one attribute per pipeline option.

    Raises:
        ValueError: if the YAML file is empty or its top level is not a mapping.
        KeyError: if a required section or required key is missing.
    """
    pack_root = pack_root.resolve()
    data: dict[str, Any] = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    # safe_load returns None for an empty document; fail with the file name
    # instead of an opaque "'NoneType' object is not subscriptable".
    if not isinstance(data, dict):
        raise ValueError(f"配置文件为空或顶层不是映射: {config_path}")

    io = data["io"]
    w = data["weights"]
    rt = data["runtime"]
    dev = data["device"]
    p1 = data["phase1"]
    p2 = data["phase2"]
    cl = data["classification"]
    tm = data["tear_merge"]
    # "section:" with nothing under it parses as None, so `or {}` covers both
    # the missing and the explicitly-null case.
    gm = data.get("gap_merge") or {}
    outopt = data["output"]
    did = data.get("doctor_identity") or {}

    py = rt.get("python")
    python_exe = sys.executable if py is None or str(py).strip() == "" else str(py)

    whitelist_raw = io.get("whitelist_json")
    whitelist_path = _rel(pack_root, whitelist_raw) if whitelist_raw else None

    tear_w_raw = tm.get("tear_merge_weights")
    tear_w = _rel(pack_root, tear_w_raw) if tear_w_raw else None

    work_raw = rt.get("work_dir")
    work_dir = _rel(pack_root, work_raw) if work_raw else None

    # Fall back to the bundled files when the key is absent or null, so the
    # namespace never carries a None path for the doctor model.
    doctor_ckpt_raw = did.get("checkpoint") or "doctor_identity_package/doctor_info.pth"
    doctor_labels_raw = did.get("labels_csv") or "doctor_identity_package/labels.csv"

    # The key stays required (KeyError when missing), but an explicit null is
    # accepted and passed through as None rather than crashing in float().
    retry_raw = cl["haocai_min_conf_retry"]
    haocai_min_conf_retry = None if retry_raw is None else float(retry_raw)

    return Namespace(
        video=_rel(pack_root, io["video"]),
        excel=_rel(pack_root, io["excel"]),
        out=_rel(pack_root, io["out"]),
        whitelist_json=whitelist_path,
        work_dir=work_dir,
        keep_work_dir=bool(rt.get("keep_work_dir", False)),
        python=python_exe,
        actionformer_ckpt=_rel(pack_root, w["actionformer"]),
        hand_model=_rel(pack_root, w["hand"]),
        goodbad_model=_rel(pack_root, w["goodbad"]),
        haocai_model=_rel(pack_root, w["haocai"]),
        tear_model=_rel(pack_root, w["tear"]),
        device=str(dev.get("type", "cuda")),
        half=bool(dev.get("half", False)),
        af_min_score=float(p1["af_min_score"]),
        af_min_seg_seconds=float(p1["af_min_seg_seconds"]),
        feat_batch_size=int(p1.get("feat_batch_size", 1)),
        seek_margin_sec=float(p2["seek_margin_sec"]),
        frame_stride=int(p2["frame_stride"]),
        det_conf=float(p2["det_conf"]),
        pad_ratio=float(p2["pad_ratio"]),
        imgsz_det=int(p2["imgsz_det"]),
        merge_iou_gt=float(p2["merge_iou_gt"]),
        merge_center_dist_max_px=(
            float(p2["merge_center_dist_max_px"])
            if p2.get("merge_center_dist_max_px") is not None
            else None
        ),
        merge_center_dist_max_frac_diag=(
            float(p2["merge_center_dist_max_frac_diag"])
            if p2.get("merge_center_dist_max_frac_diag") is not None
            else None
        ),
        tracking_alpha=float(p2.get("tracking_alpha", 0.6)),
        tracking_max_lost_frames=int(p2.get("tracking_max_lost_frames", 0)),
        imgsz_cls=int(cl["imgsz_cls"]),
        good_top1_conf_threshold=float(cl["good_top1_conf_threshold"]),
        good_top1_retry_threshold=float(cl["good_top1_retry_threshold"]),
        haocai_min_conf=float(cl["haocai_min_conf"]),
        haocai_min_conf_retry=haocai_min_conf_retry,
        empty_cache_every=int(cl.get("empty_cache_every", 0)),
        legacy_12_col_only=bool(outopt.get("legacy_12_col_only", True)),
        merge_adjacent_tear=bool(tm.get("merge_adjacent_tear", False)),
        tear_merge_weights=tear_w,
        tear_merge_class=str(tm.get("tear_merge_class", "tearing")),
        tear_merge_head_sec=float(tm.get("tear_merge_head_sec", 3.0)),
        tear_merge_prob=float(tm.get("tear_merge_prob", 0.9)),
        tear_merge_min_frames=int(tm.get("tear_merge_min_frames", 6)),
        tear_merge_verbose=bool(tm.get("tear_merge_verbose", False)),
        tear_merge_full_frame=bool(tm.get("tear_merge_full_frame", False)),
        gap_merge_enabled=bool(gm.get("enabled", False)),
        gap_merge_max_gap_sec=float(gm.get("max_gap_sec", 2.0)),
        doctor_identity_enabled=bool(did.get("enabled", True)),
        doctor_identity_checkpoint=_rel(pack_root, doctor_ckpt_raw),
        doctor_identity_labels_csv=_rel(pack_root, doctor_labels_raw),
        doctor_identity_pose_min_detection_confidence=float(
            did.get("pose_min_detection_confidence", 0.3)
        ),
        doctor_identity_min_identity_confidence=float(did.get("min_identity_confidence", 0.0)),
        doctor_identity_middle_seconds=float(did.get("middle_seconds", 10.0)),
        doctor_identity_sample_fps=float(did.get("sample_fps", 3.0)),
        doctor_identity_pad_frac=float(did.get("pad_frac", 0.15)),
    )
# Full-width (CJK) punctuation that people type into Excel cells, mapped onto
# the ASCII separators the parser understands. The code points are written as
# escapes on purpose: if this file is ever passed through Unicode
# normalisation, literal full-width characters would be folded to ASCII and
# the mapping would silently become a no-op.
_FULLWIDTH_COLON = "\uff1a"
_CELL_NORMALIZE_TABLE = str.maketrans(
    {
        "\uff1b": ";",  # full-width semicolon
        "\uff0c": ",",  # full-width comma
        "\u3001": ",",  # ideographic comma
        "\n": ";",  # one range per line inside a cell
        _FULLWIDTH_COLON: ":",
        " ": None,  # spaces carry no meaning inside a range list
    }
)
_CHUNK_SPLIT_RE = re.compile(r"[;,]+")
_RANGE_RE = re.compile(r"^(.+?)\-(.+)$")


def parse_mm_ss_to_seconds(value: str) -> float:
    """Convert a legacy ``mm.ss`` token (or a plain integer) to seconds.

    ``"1.05"`` means 1 minute 5 seconds and yields ``65.0``. A token without
    a dot is read as an integer number of seconds.

    Raises:
        ValueError: if the token is empty, not numeric, or its seconds part
            is 60 or more.
    """
    text = str(value).strip()
    if not text:
        raise ValueError("empty time value")
    if "." in text:
        left, right = text.split(".", 1)
        minutes = int(left) if left else 0
        seconds = int(right) if right else 0
        if seconds >= 60:
            raise ValueError(f"invalid mm.ss seconds >= 60: {text}")
        return float(minutes * 60 + seconds)
    return float(int(text))


def _is_legacy_mm_dot_ss(token: str) -> bool:
    """Return True when *token* looks like ``digits.digits`` with 1-2 trailing digits.

    Tokens with three or more digits after the dot are not legacy tokens and
    are treated by the caller as decimal seconds.
    """
    if "." not in token:
        return False
    a, b = token.split(".", 1)
    if not a.isdigit() or not b.isdigit():
        return False
    return 1 <= len(b) <= 2


def parse_time_token(t: str) -> float:
    """Parse one time token into seconds.

    Accepted forms, tried in this order:
      * ``h:mm:ss`` or ``mm:ss`` (ASCII or full-width colon),
      * legacy ``mm.ss`` (see ``_is_legacy_mm_dot_ss``),
      * a plain number of seconds.

    Raises:
        ValueError: if the token is empty or matches none of the forms.
    """
    t = str(t).strip().replace(_FULLWIDTH_COLON, ":")
    if not t:
        raise ValueError("empty token")
    if ":" in t:
        parts = [float(x) for x in t.split(":")]
        if len(parts) == 3:
            return parts[0] * 3600.0 + parts[1] * 60.0 + parts[2]
        if len(parts) == 2:
            return parts[0] * 60.0 + parts[1]
        raise ValueError(f"bad colon time: {t}")
    if _is_legacy_mm_dot_ss(t):
        return parse_mm_ss_to_seconds(t)
    return float(t)


def parse_cell_to_segments_v2(cell: object) -> List[Tuple[float, float]]:
    """Parse a cell holding one or more ``start-end`` ranges into seconds.

    Ranges are separated by semicolons, commas (ASCII or full-width,
    including the ideographic comma) or newlines; spaces are ignored. Each
    bound may use any form accepted by ``parse_time_token``. Chunks that
    cannot be parsed, and ranges whose end is not after their start, are
    skipped rather than raising.

    Returns:
        The list of ``(start_sec, end_sec)`` pairs in cell order; empty for
        ``None``, NaN or blank cells.
    """
    if cell is None or (isinstance(cell, float) and pd.isna(cell)):
        return []
    text = str(cell).strip()
    if not text:
        return []
    text = text.translate(_CELL_NORMALIZE_TABLE)

    segments: List[Tuple[float, float]] = []
    for ch in _CHUNK_SPLIT_RE.split(text):
        if not ch:
            continue
        m = _RANGE_RE.match(ch)
        if not m:
            continue
        left, right = m.group(1), m.group(2)
        try:
            s = parse_time_token(left)
            e = parse_time_token(right)
        except (ValueError, TypeError):
            # Best effort: one malformed range must not discard the others.
            continue
        if e > s:
            segments.append((s, e))
    return segments
if duration is not None: + log(f"[debug] 视频时长 {duration:.3f}s,段已裁剪到 [0, duration]") + + return segs_out diff --git a/src/orchestrator.py b/src/orchestrator.py new file mode 100644 index 0000000..51c65b3 --- /dev/null +++ b/src/orchestrator.py @@ -0,0 +1,473 @@ +"""主流程编排:与仓库 main_pipeline.PipelineManager 逻辑一致,参数来自 YAML(SimpleNamespace)。""" +from __future__ import annotations + +import importlib.util +import tempfile +from argparse import Namespace +from pathlib import Path +from typing import Any + +import cv2 +import run_haocai_actionformer_consumables_e2e as e2e +from actionformer_utils import ActionSegmenter +from excel_segments import load_segments_from_excel_column_i +from pipeline.hand_roi_merge import HandMergeConfig, HandRoiGrouper +from pipeline.segment_processor import ( + FineGrainedClassifier, + process_segment_multi_hand_tear_with_gate_retries, +) +from pipeline.gap_adjacent_merge import merge_all_by_gap +from pipeline.tear_gate_merge import ( + merge_all, + parse_e2e_rows_from_body_lines, + tear_class_index, +) +from run_segments_consumable_vote import pad_box as _pad_box +from ultralytics import YOLO + +from pack_utils import load_allowed_names_from_excel, log + + +def _load_doctor_module(script_path: Path) -> Any: + spec = importlib.util.spec_from_file_location("doctor_identity_runtime", script_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"无法加载医生识别脚本: {script_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _infer_doctor_text(args: Namespace, video_path: Path) -> str: + if not bool(getattr(args, "doctor_identity_enabled", True)): + return "未启用" + + checkpoint = Path(args.doctor_identity_checkpoint).resolve() + labels_csv = Path(args.doctor_identity_labels_csv).resolve() + if not checkpoint.is_file(): + return f"识别失败(缺少权重: {checkpoint})" + if not labels_csv.is_file(): + return f"识别失败(缺少标签映射: {labels_csv})" + + pack_root = Path(__file__).resolve().parent.parent 
def _resolve_allowed_names(args: Namespace, excel_path: Path) -> list[str] | None:
    """Return the allowed consumable names for this run.

    When ``args.whitelist_json`` is set, the names come from that JSON file;
    if the file is missing, the problem is logged and ``None`` is returned.
    Without a whitelist path, the names are read from the product Excel.
    """
    whitelist = args.whitelist_json
    if whitelist is None:
        return load_allowed_names_from_excel(excel_path)
    if whitelist.is_file():
        return e2e.load_whitelist_json(whitelist.resolve())
    log(f"找不到白名单 JSON: {whitelist}")
    return None
(args.tear_model, "撕膜分类"), + ] + if require_actionformer: + checks.insert(0, (args.actionformer_ckpt, "ActionFormer ckpt")) + for p, lab in checks: + if not Path(p).is_file(): + log(f"缺少{lab}: {p}") + return False + if args.merge_adjacent_tear: + tmw = (args.tear_merge_weights or args.tear_model).resolve() + if not tmw.is_file(): + log(f"撕膜门控需要权重文件: {tmw}") + return False + return True + + +def _filter_segments_by_min_length( + segs: list[tuple[float, float, float]], min_seg_seconds: float +) -> list[tuple[float, float, float]]: + if min_seg_seconds <= 0: + return segs + return [(s, e, sc) for s, e, sc in segs if (e - s) >= min_seg_seconds - 1e-9] + + +class PipelineManager: + def __init__(self, args: Namespace) -> None: + self.args = args + + def run(self) -> int: + args = self.args + video_path = args.video.resolve() + if not video_path.is_file(): + log(f"找不到视频: {video_path}") + return 1 + excel_path = args.excel.resolve() + if not excel_path.is_file(): + log(f"找不到 Excel: {excel_path}") + return 1 + + allowed_names = _resolve_allowed_names(args, excel_path) + if allowed_names is None: + return 1 + if not _validate_phase2_weights(args, require_actionformer=True): + return 1 + + stem = video_path.stem + tmp_ctx: tempfile.TemporaryDirectory | None = None + if args.work_dir is not None: + work = Path(args.work_dir).resolve() + work.mkdir(parents=True, exist_ok=True) + elif args.keep_work_dir: + work = Path(tempfile.mkdtemp(prefix="main_pipeline_")) + log(f"工作目录(保留): {work}") + else: + tmp_ctx = tempfile.TemporaryDirectory(prefix="main_pipeline_") + work = Path(tmp_ctx.name) + + try: + product_map = e2e.load_product_code_map(excel_path) + segs = ActionSegmenter.build_segments( + video_path=video_path, + stem=stem, + work=work, + actionformer_ckpt=args.actionformer_ckpt, + af_min_score=args.af_min_score, + af_min_seg_seconds=args.af_min_seg_seconds, + python_exe=args.python, + feat_batch_size=args.feat_batch_size, + device=args.device, + ) + return 
self._run_phase2_and_write( + segs, + video_path=video_path, + excel_path=excel_path, + allowed_names=allowed_names, + product_map=product_map, + work_dir_log=work if args.work_dir is not None or args.keep_work_dir else None, + ) + finally: + if tmp_ctx is not None: + tmp_ctx.cleanup() + + def _run_phase2_and_write( + self, + segs: list[tuple[float, float, float]], + *, + video_path: Path, + excel_path: Path, + allowed_names: list[str], + product_map: dict[str, str], + work_dir_log: Path | None = None, + ) -> int: + args = self.args + + predict_kw: dict[str, Any] = {"device": args.device} + if args.half: + predict_kw["half"] = True + + merge_cfg = HandMergeConfig( + merge_iou_gt=args.merge_iou_gt, + merge_center_dist_max_px=args.merge_center_dist_max_px, + merge_center_dist_max_frac_diag=args.merge_center_dist_max_frac_diag, + ) + + log("Phase2:加载 YOLO(手 / 好坏帧 / 耗材 / 撕膜)…") + det = YOLO(str(args.hand_model)) + gb = YOLO(str(args.goodbad_model)) + cls_m = YOLO(str(args.haocai_model)) + tear_m = YOLO(str(args.tear_model)) + + gb_names = gb.names + cls_names = cls_m.names + tear_names = tear_m.names + + fg = FineGrainedClassifier( + gb, + cls_m, + tear_m, + gb_names=gb_names, + cls_names=cls_names, + tear_names=tear_names, + imgsz_cls=args.imgsz_cls, + predict_kw=predict_kw, + ) + grouper = HandRoiGrouper(merge_cfg, pad_box_fn=_pad_box, pad_ratio=args.pad_ratio) + + allowed_idx = e2e.allowed_indices_from_json_names(allowed_names, cls_names) + + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + log("无法打开视频") + return 1 + + sep = "\t" + base_cols = [ + "rank", + "start_sec", + "end_sec", + "product_id_top1", + "top1_name", + "top1_conf", + "product_id_top2", + "top2_name", + "top2_conf", + "product_id_top3", + "top3_name", + "top3_conf", + ] + ext_cols = ["tear_top1_name", "tear_top2_name"] + header = sep.join(base_cols if args.legacy_12_col_only else base_cols + ext_cols) + lines_out = [header] + span_to_cells: dict[tuple[float, float], list[str]] = {} 
+ span_to_pairs: dict[tuple[float, float], list[tuple[str, float]]] = {} + + def span_key(t0: float, t1: float) -> tuple[float, float]: + return (round(float(t0), 6), round(float(t1), 6)) + + def infer_one(rank: int, t0: float, t1: float) -> str: + h_retry = args.haocai_min_conf_retry + if h_retry is not None and h_retry <= 0: + h_retry = None + elif h_retry is not None and h_retry >= float(args.haocai_min_conf) - 1e-12: + h_retry = None + info = process_segment_multi_hand_tear_with_gate_retries( + cap, + det, + fg, + grouper, + start_sec=t0, + end_sec=t1, + seek_margin_sec=args.seek_margin_sec, + det_conf=args.det_conf, + imgsz_det=args.imgsz_det, + frame_stride=max(1, args.frame_stride), + tracking_alpha=float(args.tracking_alpha), + tracking_max_lost_frames=int(args.tracking_max_lost_frames), + good_top1_conf_threshold=float(args.good_top1_conf_threshold), + good_top1_retry_threshold=float(args.good_top1_retry_threshold), + haocai_min_conf=float(args.haocai_min_conf), + haocai_min_conf_retry=h_retry, + cls_names=cls_names, + allowed_class_idx=allowed_idx, + empty_cache_every=args.empty_cache_every, + log_fn=log, + log_prefix=f"段落 rank={rank}: ", + ) + if not info.get("ok"): + reason = str(info.get("reason", "")) + span_to_pairs[span_key(t0, t1)] = [] + row = [ + str(rank), + f"{t0:.6f}", + f"{t1:.6f}", + "", + reason, + "", + "", + "", + "", + "", + "", + "", + ] + if not args.legacy_12_col_only: + row.extend(["", ""]) + span_to_cells[span_key(t0, t1)] = row[1:] + return sep.join(row) + + n1, n2, n3 = info["top_names"] + c1, c2, c3 = info["top_confs"] + tn1, tn2 = info["tear_top_names"] + id1 = product_map.get(n1, "") if n1 else "" + id2 = product_map.get(n2, "") if n2 else "" + id3 = product_map.get(n3, "") if n3 else "" + for nm, pid in ((n1, id1), (n2, id2), (n3, id3)): + if nm and not pid: + log(f"警告: 商品表无名称「{nm}」,产品编码置空。") + + row = [ + str(rank), + f"{t0:.6f}", + f"{t1:.6f}", + id1, + n1, + f"{c1:.6f}" if n1 else "", + id2, + n2, + f"{c2:.6f}" if n2 else 
"", + id3, + n3, + f"{c3:.6f}" if n3 else "", + ] + if not args.legacy_12_col_only: + row.extend([tn1, tn2]) + span_to_cells[span_key(t0, t1)] = row[1:] + span_to_pairs[span_key(t0, t1)] = list(info.get("pairs_h") or []) + return sep.join(row) + + try: + for rank, (t0, t1, af_sc) in enumerate(segs, start=1): + log(f"段落 rank={rank} [{t0:.3f},{t1:.3f}] score={af_sc:.4f} …") + lines_out.append(infer_one(rank, t0, t1)) + + if args.merge_adjacent_tear: + log("撕膜门控:合并相邻同 top1 成功段…") + tw_path = (args.tear_merge_weights or args.tear_model).resolve() + if Path(args.tear_model).resolve() == tw_path: + tear_gate_m = tear_m + else: + tear_gate_m = YOLO(str(tw_path)) + tidx = tear_class_index(tear_gate_m, args.tear_merge_class) + body_lines = lines_out[1:] + e2e_rows = parse_e2e_rows_from_body_lines(body_lines) + mg_det = det if not args.tear_merge_full_frame else None + mg_grouper = grouper if not args.tear_merge_full_frame else None + merged_rows = merge_all( + e2e_rows, + cap, + tear_gate_m, + tidx, + head_sec=float(args.tear_merge_head_sec), + tear_prob=float(args.tear_merge_prob), + tear_min_frames=int(args.tear_merge_min_frames), + imgsz=int(args.imgsz_cls), + predict_kw=predict_kw, + verbose=bool(args.tear_merge_verbose), + det=mg_det, + grouper=mg_grouper, + imgsz_det=int(args.imgsz_det), + det_conf=float(args.det_conf), + ) + lines_out = [header] + for j, er in enumerate(merged_rows, start=1): + sk = span_key(er.start_sec, er.end_sec) + if sk in span_to_cells: + lines_out.append(sep.join([str(j)] + span_to_cells[sk])) + else: + log( + f"[tear_merge] 合并窗段全量重推理 rank={j} " + f"[{er.start_sec:.3f},{er.end_sec:.3f}]" + ) + lines_out.append(infer_one(j, er.start_sec, er.end_sec)) + + if getattr(args, "gap_merge_enabled", False): + log("相邻 gap 合并…") + body_lines = lines_out[1:] + e2e_rows = parse_e2e_rows_from_body_lines(body_lines) + gap_merged = merge_all_by_gap( + e2e_rows, + span_to_pairs, + product_map, + max_gap_sec=float(args.gap_merge_max_gap_sec), + log_fn=log, + ) 
class DebugPipelineManager(PipelineManager):
    """Debug pipeline: skip ActionFormer and take the segment list from an Excel time column."""

    def run(self) -> int:
        """Validate inputs, load segments from Excel and run phase 2.

        Returns:
            1 when an input is missing, the whitelist or weights check fails,
            or no segment survives; otherwise whatever
            ``_run_phase2_and_write`` returns.
        """
        cfg = self.args

        video = cfg.video.resolve()
        if not video.is_file():
            log(f"找不到视频: {video}")
            return 1
        excel = cfg.excel.resolve()
        if not excel.is_file():
            log(f"找不到 Excel: {excel}")
            return 1

        log("[debug] 使用 Excel 时间段,跳过 ActionFormer")
        # Tear-based merging of adjacent segments is always disabled in debug mode.
        cfg.merge_adjacent_tear = False
        log("[debug] 跳过撕膜相邻段合并(merge_adjacent_tear=false)")

        whitelist = _resolve_allowed_names(cfg, excel)
        if whitelist is None or not _validate_phase2_weights(cfg, require_actionformer=False):
            return 1

        segments = load_segments_from_excel_column_i(
            excel,
            col_index=int(getattr(cfg, "excel_time_col_index", 8)),
            video_path=video,
        )
        if not segments:
            log("Excel 未解析到任何有效时间段")
            return 1

        shortest = float(getattr(cfg, "af_min_seg_seconds", 0.0))
        segments = _filter_segments_by_min_length(segments, shortest)
        if not segments:
            log(f"最短段过滤(>={shortest:g}s)后无剩余段")
            return 1

        return self._run_phase2_and_write(
            segments,
            video_path=video,
            excel_path=excel,
            allowed_names=whitelist,
            product_map=e2e.load_product_code_map(excel),
        )


def run_pipeline(args: Namespace) -> int:
    """Run the full pipeline for *args* and return its exit status."""
    manager = PipelineManager(args)
    return manager.run()


def run_debug_pipeline(args: Namespace) -> int:
    """Run the Excel-segment debug pipeline for *args* and return its exit status."""
    manager = DebugPipelineManager(args)
    return manager.run()
# --- src/pack_utils.py ---


def log(msg: str) -> None:
    """Print *msg* prefixed with the current local time (HH:MM:SS), flushing at once."""
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)


def load_allowed_names_from_excel(excel_path: Path) -> list[str]:
    """Return the unique, non-empty names found in the third column of the first sheet.

    The first row is treated as the header. Cells are converted to ``str`` and
    stripped; blanks, NaN cells and the literal header text are skipped.
    First-occurrence order is preserved.

    Raises:
        ValueError: if the sheet has fewer than three columns.
    """
    import pandas as pd

    sheet = pd.read_excel(excel_path, sheet_name=0, header=0)
    if sheet.shape[1] < 3:
        raise ValueError(f"Excel 至少需要 C 列(第 3 列): {excel_path}")

    cleaned = (str(cell).strip() for cell in sheet.iloc[:, 2] if not pd.isna(cell))
    # dict.fromkeys de-duplicates while keeping the first-seen order.
    return list(dict.fromkeys(name for name in cleaned if name and name != "商品名称"))


# --- src/paths.py ---


def ensure_code_on_path(pack_root: Path) -> Path:
    """Put the vendored code directories at the front of ``sys.path``.

    Args:
        pack_root: root directory of the pack; ``pack_root / 'code'`` must
            contain ``repo_root.py``.

    Returns:
        The resolved code root (``pack_root / 'code'``).

    Raises:
        FileNotFoundError: if ``repo_root.py`` is missing from the code root.
    """
    code_root = (pack_root / "code").resolve()
    marker = code_root / "repo_root.py"
    if not marker.is_file():
        raise FileNotFoundError(f"缺少 vendor code 根: {code_root}")

    clip_cls = code_root / "video_clip_cls"
    # Every entry is pushed to index 0, so the last one listed ends up first.
    for entry in map(str, (clip_cls / "infer_single_0506", clip_cls / "scripts", code_root)):
        if entry not in sys.path:
            sys.path.insert(0, entry)
    return code_root
@dataclass
class SegmentRow:
    """One data row of the result file, reduced to what the visualizer needs."""

    rank: int  # value of the "rank" column
    start_sec: float  # value of the "start_sec" column
    end_sec: float  # value of the "end_sec" column
    top1_name: str  # value of the "top1_name" column, stripped


# System font files tried in order when no usable --font-path is supplied.
_FONT_CANDIDATES = [
    Path("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc"),
    Path("/usr/share/fonts/opentype/noto/NotoSerifCJK-Regular.ttc"),
    Path("/usr/share/fonts/truetype/wqy/wqy-microhei.ttc"),
]

# Columns that parse_result_txt actually reads from each row.
_REQUIRED_RESULT_COLUMNS = ("rank", "start_sec", "end_sec", "top1_name")
# Prefix of the doctor-information line in the result file.
_DOCTOR_LINE_PREFIX = "医生信息:"


def log(msg: str) -> None:
    """Print *msg* prefixed with the current local time (HH:MM:SS), flushing at once."""
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)


def parse_result_txt(path: Path) -> tuple[list[SegmentRow], str]:
    """Parse a tab-separated result file into segment rows plus the doctor line.

    The header is the first line whose stripped, lower-cased text starts with
    ``rank<TAB>``. Rows that are too short to hold the required columns, or
    whose rank/start/end cannot be converted to numbers, are skipped.

    Args:
        path: UTF-8 encoded result file.

    Returns:
        ``(rows, doctor_info)`` where *rows* is sorted by
        ``(start_sec, end_sec, rank)`` and *doctor_info* is the last stripped
        line starting with the doctor prefix ("" if there is none).

    Raises:
        ValueError: if no header line is found or a required column is missing.
    """
    lines = path.read_text(encoding="utf-8").splitlines()
    header_idx = None
    doctor_info = ""
    for i, raw in enumerate(lines):
        line = raw.strip()
        if line.startswith(_DOCTOR_LINE_PREFIX):
            doctor_info = line
        if line.lower().startswith("rank\t"):
            header_idx = i
            break
    if header_idx is None:
        raise ValueError(f"未找到结果表头: {path}")

    header = lines[header_idx].split("\t")
    # Header detection above ignores case and surrounding whitespace, so the
    # lookup must as well; exact names still take precedence when present.
    col_idx = {name.strip().lower(): idx for idx, name in enumerate(header)}
    col_idx.update({name: idx for idx, name in enumerate(header)})
    for key in _REQUIRED_RESULT_COLUMNS:
        if key not in col_idx:
            raise ValueError(f"结果文件缺少列 {key}: {path}")
    i_rank, i_start, i_end, i_name = (col_idx[key] for key in _REQUIRED_RESULT_COLUMNS)
    # Only the columns that are read must exist: a row may legitimately be
    # shorter than the header (e.g. 12-column rows under a wider header).
    need = max(i_rank, i_start, i_end, i_name) + 1

    out: list[SegmentRow] = []
    for raw in lines[header_idx + 1 :]:
        line = raw.strip()
        if not line:
            continue
        if line.startswith(_DOCTOR_LINE_PREFIX):
            doctor_info = line
            continue
        # Split the unstripped line so empty leading/trailing cells keep their index.
        parts = raw.split("\t")
        if len(parts) < need:
            continue
        try:
            rank = int(parts[i_rank].strip())
            start_sec = float(parts[i_start].strip())
            end_sec = float(parts[i_end].strip())
        except ValueError:
            continue
        out.append(
            SegmentRow(
                rank=rank,
                start_sec=start_sec,
                end_sec=end_sec,
                top1_name=parts[i_name].strip(),
            )
        )
    out.sort(key=lambda row: (row.start_sec, row.end_sec, row.rank))
    return out, doctor_info


def active_segment_at(segments: list[SegmentRow], idx_hint: int, t_sec: float) -> tuple[int, SegmentRow | None]:
    """Find the segment covering *t_sec*, scanning forward from *idx_hint*.

    Segments whose end (plus a 1e-6 tolerance) lies before *t_sec* are skipped.
    The returned index is where the scan stopped, so it can be fed back as the
    hint for the next, later timestamp.

    Returns:
        ``(index, segment)``; *segment* is ``None`` when the scan ran past the
        last segment or *t_sec* falls before the start of the segment at *index*.
    """
    i = idx_hint
    n = len(segments)
    while i < n and t_sec > segments[i].end_sec + 1e-6:
        i += 1
    if i < n:
        seg = segments[i]
        if seg.start_sec - 1e-6 <= t_sec <= seg.end_sec + 1e-6:
            return i, seg
    return i, None


def fused_box_padded(
    frame,
    hands: list[list[float]],
    grouper: HandRoiGrouper,
) -> tuple[int, int, int, int] | None:
    """Return one padded box for the detected hand(s), or ``None`` without hands.

    A single hand is padded as is. With more hands, the pair chosen by
    ``two_largest_hands`` is merged via ``union_xyxy`` and that union is padded.
    Padding uses ``grouper.pad_box_fn`` with ``grouper.pad_ratio`` and the
    frame's width and height.
    """
    if not hands:
        return None
    h, w = frame.shape[:2]
    pad_fn = grouper.pad_box_fn
    ratio = grouper.pad_ratio
    if len(hands) == 1:
        return pad_fn(hands[0], w, h, ratio)

    # Requirement: do not draw the two hands separately; with two hands,
    # merge them into a single enclosing box.
    h1, h2 = two_largest_hands(hands)
    uni = union_xyxy(h1, h2)
    return pad_fn(uni, w, h, ratio)


def load_pil_font(font_path: Path | None, font_size: int):
    """Load a Pillow font for label rendering.

    Tries *font_path* first (when given), then each entry of
    ``_FONT_CANDIDATES``, and finally Pillow's built-in default font.

    Returns:
        ``(font, path_used)``. *path_used* is ``None`` for the default font.
        ``(None, None)`` when Pillow is not importable or even the default
        font cannot be loaded.
    """
    if ImageFont is None:
        return None, None
    candidates: list[Path] = []
    if font_path is not None:
        candidates.append(font_path)
    candidates.extend(_FONT_CANDIDATES)
    for p in candidates:
        if p.is_file():
            try:
                return ImageFont.truetype(str(p), font_size), p
            except Exception:  # noqa: BLE001 - best effort: fall through to the next candidate
                continue
    try:
        return ImageFont.load_default(), None
    except Exception:  # noqa: BLE001 - best effort: caller falls back to OpenCV text
        return None, None
def draw_label_box(frame, rect: tuple[int, int, int, int], label: str, pil_font) -> None:
    """Draw *rect* on *frame* in place, with a filled caption bar just above it.

    The caption is the stripped *label* ("unknown" when empty) with tabs
    replaced by spaces. Text is rendered with Pillow when it is importable and
    *pil_font* is given; otherwise OpenCV's Hershey font is used.
    """
    left, top, right, bottom = rect
    cv2.rectangle(frame, (left, top), (right, bottom), (0, 165, 255), 2)
    caption = (label.strip() or "unknown").replace("\t", " ")

    bar_bottom = max(0, top - 4)
    max_x = frame.shape[1] - 1
    use_pillow = Image is not None and ImageDraw is not None and pil_font is not None

    if not use_pillow:
        (text_w, text_h), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.62, 2)
        bar_top = max(0, bar_bottom - text_h - 8)
        bar_right = min(max_x, left + text_w + 8)
        cv2.rectangle(frame, (left, bar_top), (bar_right, bar_bottom), (0, 165, 255), -1)
        cv2.putText(
            frame,
            caption,
            (left + 4, max(0, bar_bottom - 5)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.62,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
        )
        return

    # Pillow works in RGB: convert, draw, then copy the pixels back into the frame.
    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    bb_left, bb_top, bb_right, bb_bottom = pen.textbbox((0, 0), caption, font=pil_font)
    text_w = max(1, bb_right - bb_left)
    text_h = max(1, bb_bottom - bb_top)
    bar_top = max(0, bar_bottom - text_h - 8)
    bar_right = min(max_x, left + text_w + 8)
    pen.rectangle([(left, bar_top), (bar_right, bar_bottom)], fill=(255, 165, 0))
    pen.text((left + 4, bar_top + 2), caption, font=pil_font, fill=(0, 0, 0))
    frame[:, :, :] = cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)


def draw_bottom_right_info(frame, text: str, pil_font) -> None:
    """Draw the stripped *text* in a filled box near the bottom-right corner of *frame*.

    Does nothing when the stripped text is empty. Text is rendered with Pillow
    when it is importable and *pil_font* is given; otherwise OpenCV's Hershey
    font is used. The frame is modified in place.
    """
    info = text.strip()
    if not info:
        return

    frame_h, frame_w = frame.shape[0], frame.shape[1]
    use_pillow = Image is not None and ImageDraw is not None and pil_font is not None

    if not use_pillow:
        (text_w, text_h), _ = cv2.getTextSize(info, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2)
        margin = 8
        box_left = max(0, frame_w - text_w - margin * 2 - 10)
        box_top = max(0, frame_h - text_h - margin * 2 - 10)
        box_right = min(frame_w - 1, box_left + text_w + margin * 2)
        box_bottom = min(frame_h - 1, box_top + text_h + margin * 2)
        cv2.rectangle(frame, (box_left, box_top), (box_right, box_bottom), (0, 165, 255), -1)
        cv2.putText(
            frame,
            info,
            (box_left + margin, box_bottom - margin),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.55,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
        )
        return

    # Pillow works in RGB: convert, draw, then copy the pixels back into the frame.
    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    bb_left, bb_top, bb_right, bb_bottom = pen.textbbox((0, 0), info, font=pil_font)
    text_w = max(1, bb_right - bb_left)
    text_h = max(1, bb_bottom - bb_top)
    margin = 10
    box_left = max(0, frame_w - text_w - margin * 2 - 12)
    box_top = max(0, frame_h - text_h - margin * 2 - 12)
    box_right = min(frame_w - 1, box_left + text_w + margin * 2)
    box_bottom = min(frame_h - 1, box_top + text_h + margin * 2)
    pen.rectangle([(box_left, box_top), (box_right, box_bottom)], fill=(255, 165, 0))
    pen.text((box_left + margin, box_top + margin), info, font=pil_font, fill=(0, 0, 0))
    frame[:, :, :] = cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the visualization script."""
    ap = argparse.ArgumentParser(description="按 result.txt 时间段绘制手部融合框+耗材标签,输出 MP4。")
    ap.add_argument("--video", type=Path, default=PACK_ROOT / "input" / "sample.mp4")
    ap.add_argument("--result-txt", type=Path, default=PACK_ROOT / "output" / "result.txt")
    ap.add_argument("--hand-model", type=Path, default=PACK_ROOT / "weights" / "hand_detect.pt")
    ap.add_argument("--out-video", type=Path, default=PACK_ROOT / "output" / "result_vis.mp4")
    ap.add_argument("--det-conf", type=float, default=0.6)
    ap.add_argument("--imgsz-det", type=int, default=640)
    ap.add_argument("--pad-ratio", type=float, default=0.20)
    ap.add_argument("--merge-iou-gt", type=float, default=0.0)
    ap.add_argument("--merge-center-dist-max-px", type=float, default=None)
    ap.add_argument("--merge-center-dist-max-frac-diag", type=float, default=None)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--half", action="store_true", help="传给 YOLO predict 的 half=True")
    ap.add_argument(
        "--font-path",
        type=Path,
        default=None,
        help="中文字体文件(ttf/ttc)路径;不传则自动尝试系统常见 CJK 字体",
    )
    return ap


def main() -> int:
    """Render the fused hand box and consumable label for each result segment.

    Reads the input video frame by frame. For frames inside a parsed segment
    it runs the hand detector and draws the fused box with the segment's
    top-1 name; the doctor line (if any) is drawn on every frame. Frames are
    written to ``--out-video``.

    Returns:
        0 on success, 1 when an input is missing, no segment could be parsed,
        or the video cannot be opened or written.
    """
    os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8")
    args = _build_arg_parser().parse_args()

    video_path = args.video.resolve()
    txt_path = args.result_txt.resolve()
    model_path = args.hand_model.resolve()
    out_path = args.out_video.resolve()

    for p, name in ((video_path, "输入视频"), (txt_path, "结果txt"), (model_path, "手部权重")):
        if not p.is_file():
            print(f"缺少{name}: {p}", file=sys.stderr)
            return 1
    # Writing onto the file that is being read would destroy the input.
    if out_path == video_path:
        print(f"输出视频不能与输入视频相同: {out_path}", file=sys.stderr)
        return 1

    segs, doctor_info_text = parse_result_txt(txt_path)
    if not segs:
        print(f"未在 txt 中解析到有效时间段: {txt_path}", file=sys.stderr)
        return 1
    if doctor_info_text:
        log(f"医生信息: {doctor_info_text}")

    log(f"加载手部模型: {model_path}")
    det = YOLO(str(model_path))
    merge_cfg = HandMergeConfig(
        merge_iou_gt=float(args.merge_iou_gt),
        merge_center_dist_max_px=args.merge_center_dist_max_px,
        merge_center_dist_max_frac_diag=args.merge_center_dist_max_frac_diag,
    )
    grouper = HandRoiGrouper(merge_cfg, pad_box_fn=_pad_box, pad_ratio=float(args.pad_ratio))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"无法打开视频: {video_path}", file=sys.stderr)
        return 1

    writer = None
    frame_idx = 0
    n_drawn = 0
    # Everything after the capture is opened runs under try/finally so the
    # capture (and writer) are released even if setup raises.
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0:
            fps = 25.0  # fallback when the container reports no frame rate
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        font_size = max(18, int(h * 0.028))
        font_path = args.font_path.resolve() if args.font_path is not None else None
        pil_font, font_used = load_pil_font(font_path, font_size)
        if font_used is not None:
            log(f"标签字体: {font_used}")
        elif pil_font is not None:
            log("标签字体: Pillow 默认字体(可能不支持中文)")
        else:
            log("标签字体: 回退 OpenCV 内置字体(中文可能显示异常)")

        # Create the output directory only once the inputs are known to be usable.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w, h))
        if not writer.isOpened():
            print(f"无法创建视频写入器: {out_path}", file=sys.stderr)
            return 1

        predict_kw: dict[str, Any] = {"device": args.device}
        if bool(args.half):
            predict_kw["half"] = True

        seg_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            # The k-th frame (0-based) is displayed at k / fps, so compute the
            # timestamp before counting this frame.
            t_sec = frame_idx / fps
            frame_idx += 1

            seg_idx, seg = active_segment_at(segs, seg_idx, t_sec)
            if seg is not None:
                r0 = det.predict(
                    frame,
                    conf=float(args.det_conf),
                    imgsz=int(args.imgsz_det),
                    verbose=False,
                    **predict_kw,
                )[0]
                hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else []
                fused = fused_box_padded(frame, hands, grouper)
                if fused is not None:
                    draw_label_box(frame, fused, seg.top1_name, pil_font)
                    n_drawn += 1
            if doctor_info_text:
                draw_bottom_right_info(frame, doctor_info_text, pil_font)
            writer.write(frame)

            if frame_idx % 200 == 0:
                log(f"处理中: {frame_idx}/{max(total, 1)} 帧")
    finally:
        if writer is not None:
            writer.release()
        cap.release()

    log(f"完成: 输出 {out_path}")
    log(f"共绘制 {n_drawn} 帧融合框(总帧 {frame_idx})")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())