Files
FishServer/FishMeasure/detect_refbox/train_refbox_yolo.py

163 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用 Ultralytics YOLO 训练参考物ref单类检测模型。
默认权重:仓库根目录 yolo26n.pt
默认数据:先运行 prepare_refbox_dataset.py 生成 detect_refbox/dataset/refbox.yaml
小样本建议:较轻增强、适中 epoch、较小 batch。
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from datetime import datetime
from pathlib import Path
def repo_root() -> Path:
    """Return the base directory that the CLI path defaults are built on.

    This is the grandparent of this script's resolved location, i.e. the
    directory that contains this script's parent directory (symlinks in
    ``__file__`` are resolved first).
    """
    this_script = Path(__file__).resolve()
    return this_script.parents[1]
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Path defaults (initial weights, dataset YAML, prepare output and the
    Ultralytics project directory) are anchored at ``repo_root()``. The
    ``--source`` default is an absolute, machine-specific directory.

    Returns:
        The parsed argument namespace.
    """
    base = repo_root()
    refbox_dir = base / "detect_refbox"

    parser = argparse.ArgumentParser(description="训练 ref 框 YOLOyolo26n")
    add = parser.add_argument

    # Model weights and dataset description consumed by training.
    add("--model", type=Path, default=base / "yolo26n.pt", help="初始权重路径")
    add(
        "--data",
        type=Path,
        default=refbox_dir / "dataset" / "refbox.yaml",
        help="数据集 YAML",
    )

    # Options forwarded to the dataset-preparation script (used with --prepare).
    add(
        "--source",
        type=Path,
        default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
        help="原始 JSON 图像根目录(仅 --prepare 时)",
    )
    add(
        "--dataset-out",
        type=Path,
        default=refbox_dir / "dataset",
        help="prepare 输出目录",
    )
    add("--prepare", action="store_true", help="训练前先执行数据集准备")
    add("--val-ratio", type=float, default=0.2, help="验证比例prepare")
    add("--copy-images", action="store_true", help="prepare 时复制图像")

    # Training hyper-parameters and run bookkeeping.
    add("--epochs", type=int, default=200)
    add("--batch", type=int, default=8)
    add("--imgsz", type=int, default=640)
    add("--device", type=str, default="", help="如 0 或 cpu空则自动")
    add(
        "--project",
        type=Path,
        default=refbox_dir / "runs",
        help="Ultralytics project 目录",
    )
    add("--name", type=str, default="", help="运行名,默认带时间戳")
    add("--workers", type=int, default=4)
    add("--patience", type=int, default=80)
    add("--seed", type=int, default=42)
    add("--exist-ok", action="store_true")

    return parser.parse_args()
def run_prepare(args: argparse.Namespace) -> int:
    """Run the sibling ``prepare_refbox_dataset.py`` in a child interpreter.

    The child is started with the same Python executable as this process and
    receives ``--source``, ``--out`` (from ``args.dataset_out``),
    ``--val-ratio`` and ``--seed``. It also receives ``--copy-images`` when
    that flag was set.

    Args:
        args: Parsed CLI namespace. Reads ``source``, ``dataset_out``,
            ``val_ratio``, ``seed`` and ``copy_images``.

    Returns:
        The child process's exit status (0 means success).
    """
    script = Path(__file__).resolve().with_name("prepare_refbox_dataset.py")

    forwarded = (
        ("--source", args.source),
        ("--out", args.dataset_out),
        ("--val-ratio", args.val_ratio),
        ("--seed", args.seed),
    )
    command = [sys.executable, str(script)]
    for flag, value in forwarded:
        command.extend((flag, str(value)))
    if args.copy_images:
        command.append("--copy-images")

    completed = subprocess.run(command)
    return completed.returncode
def main() -> int:
    """Entry point: optionally prepare the dataset, then train the detector.

    Steps:
      1. Parse CLI arguments. With ``--prepare``, run the dataset preparation
         script first and return its exit status if it is non-zero.
      2. Check that the dataset YAML and the initial weights file exist.
      3. Import ``ultralytics`` lazily so a missing install yields a short
         error message instead of a traceback.
      4. Train with light augmentation, then print the weights directory of
         the run that was actually written.

    Returns:
        Process exit status: 0 on success, 1 when an input file or the
        ``ultralytics`` package is missing, otherwise the preparation
        script's non-zero exit status.
    """
    args = parse_args()

    if args.prepare:
        rc = run_prepare(args)
        if rc != 0:
            return rc

    data_yaml = args.data.expanduser().resolve()
    if not data_yaml.is_file():
        print(
            f"[错误] 未找到 {data_yaml}。请先运行:\n"
            f" python3 {Path(__file__).parent / 'prepare_refbox_dataset.py'} "
            f"--source {args.source}",
            file=sys.stderr,
        )
        return 1

    model_path = args.model.expanduser().resolve()
    if not model_path.is_file():
        print(f"[错误] 未找到权重: {model_path}", file=sys.stderr)
        return 1

    try:
        from ultralytics import YOLO
    except ImportError as e:
        print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1

    name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    args.project = args.project.expanduser().resolve()
    args.project.mkdir(parents=True, exist_ok=True)

    # Small dataset: turn off the strong mosaic / mixup / copy-paste
    # augmentations and keep the remaining geometric and colour jitter mild,
    # to reduce the risk of overfitting. Because mosaic is already 0,
    # close_mosaic (the end-of-training mosaic cut-off) is also set to 0.
    model = YOLO(str(model_path))
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device if args.device else None,  # None -> auto-select
        project=str(args.project),
        name=name,
        workers=args.workers,
        patience=args.patience,
        seed=args.seed,
        exist_ok=args.exist_ok,
        verbose=True,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        degrees=5.0,
        translate=0.08,
        scale=0.25,
        shear=0.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
        hsv_h=0.01,
        hsv_s=0.4,
        hsv_v=0.3,
        close_mosaic=0,
    )

    # Without exist_ok, Ultralytics appends a numeric suffix (name2, name3, ...)
    # when <project>/<name> already exists, so the requested path may not be
    # where the weights went. Report the trainer's resolved save_dir, and fall
    # back to the requested path only if the trainer does not expose one.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", None) or (args.project / name))
    print(f"\n完成。权重目录: {save_dir / 'weights'}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    sys.exit(main())