#!/usr/bin/env python3
# -*- coding: utf-8 -*-
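"""
Train a single-class detection model for the reference object (ref) with Ultralytics YOLO.

Default weights: yolo26n.pt in the repository root.
Default data: run prepare_refbox_dataset.py first to generate detect_refbox/dataset/refbox.yaml.

Recommendations for small datasets: lighter augmentation, a moderate number of epochs, and a smaller batch size.
"""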
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import subprocess
|
||
import sys
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
|
||
def repo_root() -> Path:
    """Return the repository root directory.

    The root is taken to be the second ancestor of this file's resolved
    path, i.e. the parent of the directory that contains this script.
    """
    this_file = Path(__file__).resolve()
    ancestors = this_file.parents
    return ancestors[1]
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``.

    Default locations for the initial weights, the dataset YAML, the prepare
    output directory and the Ultralytics project directory are all derived
    from ``repo_root()``.

    Returns:
        The parsed command-line options as an ``argparse.Namespace``.
    """
    base_dir = repo_root()
    refbox_dir = base_dir / "detect_refbox"
    dataset_dir = refbox_dir / "dataset"

    # Options are declared as (flag, add_argument keyword arguments) pairs and
    # registered in this order, which is also the order shown by --help.
    option_table = [
        ("--model", {"type": Path, "default": base_dir / "yolo26n.pt", "help": "初始权重路径"}),
        ("--data", {"type": Path, "default": dataset_dir / "refbox.yaml", "help": "数据集 YAML"}),
        (
            "--source",
            {
                "type": Path,
                "default": Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
                "help": "原始 JSON 图像根目录(仅 --prepare 时)",
            },
        ),
        ("--dataset-out", {"type": Path, "default": dataset_dir, "help": "prepare 输出目录"}),
        ("--prepare", {"action": "store_true", "help": "训练前先执行数据集准备"}),
        ("--val-ratio", {"type": float, "default": 0.2, "help": "验证比例(prepare)"}),
        ("--copy-images", {"action": "store_true", "help": "prepare 时复制图像"}),
        ("--epochs", {"type": int, "default": 200}),
        ("--batch", {"type": int, "default": 8}),
        ("--imgsz", {"type": int, "default": 640}),
        ("--device", {"type": str, "default": "", "help": "如 0 或 cpu,空则自动"}),
        ("--project", {"type": Path, "default": refbox_dir / "runs", "help": "Ultralytics project 目录"}),
        ("--name", {"type": str, "default": "", "help": "运行名,默认带时间戳"}),
        ("--workers", {"type": int, "default": 4}),
        ("--patience", {"type": int, "default": 80}),
        ("--seed", {"type": int, "default": 42}),
        ("--exist-ok", {"action": "store_true"}),
    ]

    parser = argparse.ArgumentParser(description="训练 ref 框 YOLO(yolo26n)")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||
|
||
|
||
def run_prepare(args: argparse.Namespace) -> int:
    """Run ``prepare_refbox_dataset.py`` (located next to this file) in a subprocess.

    The script is launched with the current interpreter and receives the
    ``source``, ``dataset_out``, ``val_ratio`` and ``seed`` values from
    ``args``; ``--copy-images`` is forwarded only when ``args.copy_images``
    is truthy.

    Args:
        args: Parsed command-line options providing the attributes above.

    Returns:
        The exit code of the subprocess.
    """
    script = Path(__file__).resolve().parent / "prepare_refbox_dataset.py"

    # Each forwarded option is a flag followed by its stringified value.
    flag_values = (
        ("--source", args.source),
        ("--out", args.dataset_out),
        ("--val-ratio", args.val_ratio),
        ("--seed", args.seed),
    )

    command = [sys.executable, str(script)]
    for flag, value in flag_values:
        command += [flag, str(value)]
    if args.copy_images:
        command.append("--copy-images")

    return subprocess.call(command)
|
||
|
||
|
||
def main() -> int:
    """Optionally prepare the dataset, then train the ref-box YOLO detector.

    Steps:
      1. Parse the command line; when ``--prepare`` is given, run the prepare
         script first and stop with its exit code if it fails.
      2. Check that the dataset YAML and the initial weights exist.
      3. Import ``ultralytics`` lazily so a missing package yields a readable
         error instead of a traceback at import time.
      4. Train with light augmentation settings.
      5. Print the directory that holds the resulting weights.

    Returns:
        ``0`` on success, ``1`` when the dataset YAML, the weights file or the
        ``ultralytics`` package is missing, or the non-zero exit code of the
        prepare step.
    """
    args = parse_args()

    if args.prepare:
        rc = run_prepare(args)
        if rc != 0:
            return rc

    data_yaml = args.data.expanduser().resolve()
    if not data_yaml.is_file():
        print(
            f"[错误] 未找到 {data_yaml}。请先运行:\n"
            f" python3 {Path(__file__).parent / 'prepare_refbox_dataset.py'} "
            f"--source {args.source}",
            file=sys.stderr,
        )
        return 1

    model_path = args.model.expanduser().resolve()
    if not model_path.is_file():
        print(f"[错误] 未找到权重: {model_path}", file=sys.stderr)
        return 1

    try:
        from ultralytics import YOLO
    except ImportError as e:
        print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1

    # Fall back to a timestamped run name when none was given.
    name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    project_dir = args.project.expanduser().resolve()
    project_dir.mkdir(parents=True, exist_ok=True)

    # Small dataset: turn off the strong mosaic/mixup augmentations to reduce
    # the risk of overfitting.
    model = YOLO(str(model_path))
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        # An empty --device string means "let Ultralytics choose".
        device=args.device if args.device else None,
        project=str(project_dir),
        name=name,
        workers=args.workers,
        patience=args.patience,
        seed=args.seed,
        exist_ok=args.exist_ok,
        verbose=True,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        degrees=5.0,
        translate=0.08,
        scale=0.25,
        shear=0.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
        hsv_h=0.01,
        hsv_s=0.4,
        hsv_v=0.3,
        close_mosaic=0,
    )

    # Ultralytics may append a numeric suffix to the run directory when it
    # already exists and exist_ok is false, so report the directory the trainer
    # actually used; fall back to project/name if it is not exposed.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", None) or project_dir / name)
    print(f"\n完成。权重目录: {save_dir / 'weights'}")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|