Files

163 lines
4.7 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
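Train a single-class detector for the reference object ("ref") with
Ultralytics YOLO.

Default weights: yolo26n.pt in the repository root.
Default data: run prepare_refbox_dataset.py first to generate
detect_refbox/dataset/refbox.yaml.
For small datasets, prefer light augmentation, a moderate number of epochs,
and a small batch size.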
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from datetime import datetime
from pathlib import Path
def repo_root() -> Path:
    """Return the directory two levels above this file.

    The path of this script is resolved to an absolute path first; the
    result is the parent of the directory that contains the script.
    """
    script_path = Path(__file__).resolve()
    return script_path.parents[1]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the training options.

    Path-valued defaults (weights, dataset YAML, prepare output directory,
    Ultralytics project directory) are anchored at ``repo_root()``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the function did before this
            parameter existed; passing a list lets callers and tests drive
            the parser without touching ``sys.argv``.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    root = repo_root()
    refbox_dir = root / "detect_refbox"

    p = argparse.ArgumentParser(description="训练 ref 框 YOLOyolo26n")

    # Inputs for training.
    p.add_argument(
        "--model",
        type=Path,
        default=root / "yolo26n.pt",
        help="初始权重路径",
    )
    p.add_argument(
        "--data",
        type=Path,
        default=refbox_dir / "dataset" / "refbox.yaml",
        help="数据集 YAML",
    )

    # Options that are forwarded to the dataset preparation script.
    p.add_argument(
        "--source",
        type=Path,
        default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
        help="原始 JSON 图像根目录(仅 --prepare 时)",
    )
    p.add_argument(
        "--dataset-out",
        type=Path,
        default=refbox_dir / "dataset",
        help="prepare 输出目录",
    )
    p.add_argument("--prepare", action="store_true", help="训练前先执行数据集准备")
    p.add_argument("--val-ratio", type=float, default=0.2, help="验证比例prepare")
    p.add_argument("--copy-images", action="store_true", help="prepare 时复制图像")

    # Training hyper-parameters and run bookkeeping.
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--imgsz", type=int, default=640)
    p.add_argument("--device", type=str, default="", help="如 0 或 cpu空则自动")
    p.add_argument(
        "--project",
        type=Path,
        default=refbox_dir / "runs",
        help="Ultralytics project 目录",
    )
    p.add_argument("--name", type=str, default="", help="运行名,默认带时间戳")
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--patience", type=int, default=80)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--exist-ok", action="store_true")

    return p.parse_args(argv)
def run_prepare(args: argparse.Namespace) -> int:
    """Run ``prepare_refbox_dataset.py`` (next to this file) as a subprocess.

    The script is started with the current Python interpreter and receives
    ``--source``, ``--out``, ``--val-ratio`` and ``--seed`` taken from
    ``args``, plus ``--copy-images`` when ``args.copy_images`` is set.

    Args:
        args: Parsed options; reads ``source``, ``dataset_out``,
            ``val_ratio``, ``seed`` and ``copy_images``.

    Returns:
        The exit code of the subprocess.
    """
    script = Path(__file__).resolve().parent / "prepare_refbox_dataset.py"

    # Insertion order of the mapping fixes the order of flags on the command line.
    valued_flags = {
        "--source": args.source,
        "--out": args.dataset_out,
        "--val-ratio": args.val_ratio,
        "--seed": args.seed,
    }

    command = [sys.executable, str(script)]
    for flag, value in valued_flags.items():
        command += [flag, str(value)]
    if args.copy_images:
        command.append("--copy-images")

    return subprocess.call(command)
def main() -> int:
    """Optionally prepare the dataset, then train the ref-box detector.

    Steps, in order:
      1. Parse the command-line options.
      2. With ``--prepare``, run the preparation script and stop if it fails.
      3. Check that the dataset YAML and the initial weights exist.
      4. Import Ultralytics (lazily, so a missing package gives a clear
         message instead of an import-time traceback).
      5. Train and print the directory that holds the resulting weights.

    Returns:
        0 on success; 1 when the dataset YAML, the weights file or the
        ``ultralytics`` package is missing; otherwise the non-zero exit code
        of the preparation script.
    """
    args = parse_args()

    if args.prepare:
        rc = run_prepare(args)
        if rc != 0:
            return rc

    data_yaml = args.data.expanduser().resolve()
    if not data_yaml.is_file():
        print(
            f"[错误] 未找到 {data_yaml}。请先运行:\n"
            f" python3 {Path(__file__).parent / 'prepare_refbox_dataset.py'} "
            f"--source {args.source}",
            file=sys.stderr,
        )
        return 1

    model_path = args.model.expanduser().resolve()
    if not model_path.is_file():
        print(f"[错误] 未找到权重: {model_path}", file=sys.stderr)
        return 1

    try:
        from ultralytics import YOLO
    except ImportError as e:
        print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1

    # Fall back to a timestamped run name when none was given.
    name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    project_dir = args.project.expanduser().resolve()
    project_dir.mkdir(parents=True, exist_ok=True)

    # Small dataset: disable strong mosaic/mixup to reduce the risk of
    # overfitting; only mild geometric and colour jitter is kept.
    model = YOLO(str(model_path))
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device or None,  # empty string -> let Ultralytics choose
        project=str(project_dir),
        name=name,
        workers=args.workers,
        patience=args.patience,
        seed=args.seed,
        exist_ok=args.exist_ok,
        verbose=True,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        degrees=5.0,
        translate=0.08,
        scale=0.25,
        shear=0.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
        hsv_h=0.01,
        hsv_s=0.4,
        hsv_v=0.3,
        close_mosaic=0,
    )

    # When project/name already exists and exist_ok is false, Ultralytics may
    # write to a suffixed directory instead, so report the directory the
    # trainer actually used and only fall back to project/name if the trainer
    # does not expose one.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", None) or (project_dir / name))
    print(f"\n完成。权重目录: {save_dir / 'weights'}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())