#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train a single-class reference-object ("ref") detector with Ultralytics YOLO.

Default weights: yolo26n.pt in the repository root.
Default data: run prepare_refbox_dataset.py first (or pass --prepare) to
produce detect_refbox/dataset/refbox.yaml.
Small-sample advice: mild augmentation, a moderate number of epochs and a
small batch size.
"""
from __future__ import annotations

import argparse
import subprocess
import sys
from datetime import datetime
from pathlib import Path

# File name of the dataset YAML inside the --dataset-out directory. This
# matches the original default --data path (<dataset-out>/refbox.yaml).
DATA_YAML_NAME = "refbox.yaml"

# Augmentation overrides passed to YOLO.train(). Per the original author's
# note: on small data, strong mosaic/mixup (and copy-paste) are switched off
# to reduce overfitting risk, and the remaining jitter is kept mild.
SMALL_DATA_AUGMENT = {
    "mosaic": 0.0,
    "mixup": 0.0,
    "copy_paste": 0.0,
    "degrees": 5.0,
    "translate": 0.08,
    "scale": 0.25,
    "shear": 0.0,
    "perspective": 0.0,
    "flipud": 0.0,
    "fliplr": 0.5,
    "hsv_h": 0.01,
    "hsv_s": 0.4,
    "hsv_v": 0.3,
    "close_mosaic": 0,
}


def repo_root() -> Path:
    """Return the parent of the directory that contains this script."""
    return Path(__file__).resolve().parents[1]


def _prepare_script() -> Path:
    """Return the path of prepare_refbox_dataset.py next to this script."""
    return Path(__file__).resolve().parent / "prepare_refbox_dataset.py"


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed namespace. ``data`` is always a ``Path``: when ``--data``
        is not given it defaults to ``<dataset-out>/refbox.yaml``, so that
        ``--prepare --dataset-out X`` trains on the dataset it just built.
    """
    root = repo_root()
    p = argparse.ArgumentParser(description="训练 ref 框 YOLO(yolo26n)")
    p.add_argument(
        "--model",
        type=Path,
        default=root / "yolo26n.pt",
        help="初始权重路径",
    )
    p.add_argument(
        "--data",
        type=Path,
        default=None,
        help="数据集 YAML(默认:<dataset-out>/refbox.yaml)",
    )
    p.add_argument(
        "--source",
        type=Path,
        default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
        help="原始 JSON 图像根目录(仅 --prepare 时)",
    )
    p.add_argument(
        "--dataset-out",
        type=Path,
        default=root / "detect_refbox" / "dataset",
        help="prepare 输出目录",
    )
    p.add_argument("--prepare", action="store_true", help="训练前先执行数据集准备")
    p.add_argument("--val-ratio", type=float, default=0.2, help="验证比例(prepare)")
    p.add_argument("--copy-images", action="store_true", help="prepare 时复制图像")
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--imgsz", type=int, default=640)
    p.add_argument("--device", type=str, default="", help="如 0 或 cpu,空则自动")
    p.add_argument(
        "--project",
        type=Path,
        default=root / "detect_refbox" / "runs",
        help="Ultralytics project 目录",
    )
    p.add_argument("--name", type=str, default="", help="运行名,默认带时间戳")
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--patience", type=int, default=80)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--exist-ok", action="store_true")

    args = p.parse_args(argv)
    if args.data is None:
        # Keep --data in sync with --dataset-out instead of a fixed path, so a
        # custom prepare output directory is actually used for training.
        args.data = args.dataset_out / DATA_YAML_NAME
    return args


def run_prepare(args: argparse.Namespace) -> int:
    """Run prepare_refbox_dataset.py in a child interpreter.

    Forwards --source/--out/--val-ratio/--seed (and --copy-images when set)
    from *args*.

    Returns:
        The child process's exit code.
    """
    cmd = [
        sys.executable,
        str(_prepare_script()),
        "--source", str(args.source),
        "--out", str(args.dataset_out),
        "--val-ratio", str(args.val_ratio),
        "--seed", str(args.seed),
    ]
    if args.copy_images:
        cmd.append("--copy-images")
    return subprocess.run(cmd, check=False).returncode


def _resolve_save_dir(model: object, fallback: Path) -> Path:
    """Return the run directory the trainer actually wrote to.

    Ultralytics may increment the run name (e.g. ``name2``) when the
    directory already exists and exist_ok is False, so ``project / name``
    is not reliable. Falls back to *fallback* if the trainer does not expose
    ``save_dir``.
    """
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None)
    return Path(save_dir) if save_dir else fallback


def main() -> int:
    """CLI entry point: optionally prepare data, then train.

    Returns:
        The process exit code: 0 on success, the prepare script's non-zero
        code if preparation failed, or 1 on missing inputs or a missing
        ultralytics package.
    """
    args = parse_args()

    if args.prepare:
        rc = run_prepare(args)
        if rc != 0:
            return rc

    data_yaml = args.data.expanduser().resolve()
    if not data_yaml.is_file():
        print(
            f"[错误] 未找到 {data_yaml}。请先运行:\n"
            f"  python3 {_prepare_script()} "
            f"--source {args.source} --out {data_yaml.parent}",
            file=sys.stderr,
        )
        return 1

    model_path = args.model.expanduser().resolve()
    if not model_path.is_file():
        print(f"[错误] 未找到权重: {model_path}", file=sys.stderr)
        return 1

    # Imported lazily so --help and the input checks work without ultralytics.
    try:
        from ultralytics import YOLO
    except ImportError as e:
        print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1

    name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    args.project = args.project.expanduser().resolve()
    args.project.mkdir(parents=True, exist_ok=True)

    model = YOLO(str(model_path))
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        # An empty --device string means "let Ultralytics choose".
        device=args.device or None,
        project=str(args.project),
        name=name,
        workers=args.workers,
        patience=args.patience,
        seed=args.seed,
        exist_ok=args.exist_ok,
        verbose=True,
        **SMALL_DATA_AUGMENT,
    )

    save_dir = _resolve_save_dir(model, args.project / name)
    print(f"\n完成。权重目录: {save_dir / 'weights'}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())