163 lines
4.7 KiB
Python
163 lines
4.7 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
使用 Ultralytics YOLO 训练参考物(ref)单类检测模型。
|
|||
|
|
|
|||
|
|
默认权重:仓库根目录 yolo26n.pt
|
|||
|
|
默认数据:先运行 prepare_refbox_dataset.py 生成 detect_refbox/dataset/refbox.yaml
|
|||
|
|
|
|||
|
|
小样本建议:较轻增强、适中 epoch、较小 batch。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import subprocess
|
|||
|
|
import sys
|
|||
|
|
from datetime import datetime
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def repo_root() -> Path:
    """Return the repository root directory.

    The root is ``parents[1]`` of this file's resolved path, i.e. the
    directory one level above the folder that contains this script.
    """
    script_path = Path(__file__).resolve()
    return script_path.parents[1]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the training options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which matches the previous zero-argument
            behavior; passing a list allows the parser to be driven from
            other code or from tests.

    Returns:
        The parsed namespace. The ``model``, ``data``, ``source``,
        ``dataset_out`` and ``project`` options are ``pathlib.Path`` objects.
    """
    root = repo_root()
    p = argparse.ArgumentParser(description="训练 ref 框 YOLO(yolo26n)")

    # Input/output locations; repository-relative defaults hang off repo_root().
    p.add_argument(
        "--model",
        type=Path,
        default=root / "yolo26n.pt",
        help="初始权重路径",
    )
    p.add_argument(
        "--data",
        type=Path,
        default=root / "detect_refbox" / "dataset" / "refbox.yaml",
        help="数据集 YAML",
    )
    p.add_argument(
        "--source",
        type=Path,
        default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
        help="原始 JSON 图像根目录(仅 --prepare 时)",
    )
    p.add_argument(
        "--dataset-out",
        type=Path,
        default=root / "detect_refbox" / "dataset",
        help="prepare 输出目录",
    )

    # Options forwarded to the dataset preparation step.
    p.add_argument("--prepare", action="store_true", help="训练前先执行数据集准备")
    p.add_argument("--val-ratio", type=float, default=0.2, help="验证比例(prepare)")
    p.add_argument("--copy-images", action="store_true", help="prepare 时复制图像")

    # Training hyper-parameters passed to the YOLO trainer.
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--imgsz", type=int, default=640)
    p.add_argument("--device", type=str, default="", help="如 0 或 cpu,空则自动")
    p.add_argument(
        "--project",
        type=Path,
        default=root / "detect_refbox" / "runs",
        help="Ultralytics project 目录",
    )
    p.add_argument("--name", type=str, default="", help="运行名,默认带时间戳")
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--patience", type=int, default=80)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--exist-ok", action="store_true")
    return p.parse_args(argv)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_prepare(args: argparse.Namespace) -> int:
    """Run ``prepare_refbox_dataset.py`` (located next to this file) as a child process.

    The script is started with the current interpreter and receives
    ``--source``, ``--out``, ``--val-ratio`` and ``--seed`` taken from
    ``args``, plus ``--copy-images`` when ``args.copy_images`` is set.

    Returns:
        The exit code of the child process.
    """
    script = Path(__file__).resolve().parent / "prepare_refbox_dataset.py"
    # Insertion order of the mapping fixes the order of the flags on the command line.
    options = {
        "--source": args.source,
        "--out": args.dataset_out,
        "--val-ratio": args.val_ratio,
        "--seed": args.seed,
    }
    command = [sys.executable, str(script)]
    for flag, value in options.items():
        command += [flag, str(value)]
    if args.copy_images:
        command.append("--copy-images")
    return subprocess.call(command)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main() -> int:
    """Train the single-class refbox detector with Ultralytics YOLO.

    Steps: optionally run dataset preparation (``--prepare``), check that the
    dataset YAML and the initial weights exist, import Ultralytics lazily,
    then call ``YOLO.train`` with the augmentation settings below.

    Returns:
        Process exit code: ``0`` on success, the preparation script's non-zero
        code if it failed, or ``1`` when the dataset YAML, the weights file or
        the ``ultralytics`` package is missing.
    """
    args = parse_args()
    if args.prepare:
        rc = run_prepare(args)
        if rc != 0:
            return rc

    data_yaml = args.data.expanduser().resolve()
    if not data_yaml.is_file():
        print(
            f"[错误] 未找到 {data_yaml}。请先运行:\n"
            f" python3 {Path(__file__).parent / 'prepare_refbox_dataset.py'} "
            f"--source {args.source}",
            file=sys.stderr,
        )
        return 1

    model_path = args.model.expanduser().resolve()
    if not model_path.is_file():
        print(f"[错误] 未找到权重: {model_path}", file=sys.stderr)
        return 1

    # Imported here rather than at module level so that a missing package
    # produces the install hint below instead of a traceback at startup.
    try:
        from ultralytics import YOLO
    except ImportError as e:
        print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1

    name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    project_dir = args.project.expanduser().resolve()
    project_dir.mkdir(parents=True, exist_ok=True)

    # Small dataset: disable the strong mosaic/mixup augmentations to reduce
    # the risk of overfitting.
    model = YOLO(str(model_path))
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device or None,  # empty string means "let Ultralytics choose"
        project=str(project_dir),
        name=name,
        workers=args.workers,
        patience=args.patience,
        seed=args.seed,
        exist_ok=args.exist_ok,
        verbose=True,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        degrees=5.0,
        translate=0.08,
        scale=0.25,
        shear=0.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
        hsv_h=0.01,
        hsv_s=0.4,
        hsv_v=0.3,
        close_mosaic=0,
    )

    # Report the directory the trainer actually wrote to. When project/name
    # already exists and exist_ok is false, Ultralytics is documented to pick
    # an incremented run directory, so project/name alone can be wrong.
    # Fall back to project/name if the trainer does not expose save_dir.
    trainer = getattr(model, "trainer", None)
    actual_dir = getattr(trainer, "save_dir", None)
    save_dir = Path(actual_dir) if actual_dir else project_dir / name
    print(f"\n完成。权重目录: {save_dir / 'weights'}")
    return 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|