Files
FishServer/FishMeasure/segmentation/train_yolo_seg.py

152 lines
4.8 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train an Ultralytics YOLO instance-segmentation model (e.g. YOLOv8-seg or
YOLO26-seg weights) on a dataset described by a dataset.yaml file.

If --name is omitted, the run is named "<model stem>_<YYYYmmdd_HHMMSS>".
Pass --export to also export best.pt to ONNX and TorchScript after training.

Example (filtered dataset):
python3 segmentation/train_yolo_seg.py \
--data ./datasets/fish_body_seg_filtered/dataset.yaml \
--model yolo26s-seg.pt \
--epochs 100 \
--batch 16 \
--imgsz 640 \
--project runs/seg \
--name fish_body_seg_$(date +%Y%m%d_%H%M%S)

Example (more options):
python3 segmentation/train_yolo_seg.py \
--data ./datasets/fish_body_seg_filtered/dataset.yaml \
--model yolov8s-seg.pt \
--epochs 300 \
--batch 32 \
--imgsz 640 \
--device 0 \
--workers 8 \
--patience 50 \
--pretrained \
--cache \
--project runs/seg \
--name fish_body_seg_yolov8s_$(date +%Y%m%d_%H%M%S)

Dependency:
pip install ultralytics
"""
from __future__ import annotations
import argparse
import os
import sys
from datetime import datetime
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the training options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as a plain
            ``parse_args()`` call does; passing an explicit list lets tests
            and other scripts reuse this parser.

    Returns:
        An ``argparse.Namespace`` whose attributes mirror the long option
        names, with dashes turned into underscores (``--exist-ok`` ->
        ``exist_ok``). ``device`` and ``name`` default to the empty string.
    """
    p = argparse.ArgumentParser(description="Ultralytics YOLOv8-seg training")
    p.add_argument(
        "--data",
        type=str,
        default="./datasets/fish_body_seg_filtered/dataset.yaml",
        help="dataset.yaml path (default: ./datasets/fish_body_seg_filtered/dataset.yaml)",
    )
    p.add_argument(
        "--model",
        type=str,
        default="yolo26l-seg.pt",
        help="model weights/arch, e.g. yolov8n-seg.pt/yolov8s-seg.pt or your .pt",
    )
    p.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    p.add_argument(
        "--batch",
        type=int,
        default=16,
        help="batch size (Ultralytics also accepts -1 for AutoBatch)",
    )
    p.add_argument("--imgsz", type=int, default=640, help="training image size in pixels")
    p.add_argument("--device", type=str, default="", help="CUDA device like '0' or '0,1'. Empty=auto")
    p.add_argument("--project", type=str, default="runs/seg", help="output project dir")
    p.add_argument("--name", type=str, default="", help="run name (default: model + timestamp)")
    p.add_argument("--workers", type=int, default=8, help="number of dataloader workers")
    p.add_argument(
        "--patience",
        type=int,
        default=50,
        help="epochs without validation improvement before early stopping",
    )
    p.add_argument("--lr0", type=float, default=0.01, help="initial learning rate")
    # NOTE(review): being store_true, this defaults to False and is forwarded
    # as `pretrained=...`, so Ultralytics' own default (pretrained=True) is
    # overridden unless the flag is given. When --model is a .pt checkpoint,
    # its weights are presumably still loaded by YOLO(); confirm this default
    # is intended.
    p.add_argument("--pretrained", action="store_true", help="use pretrained weights")
    p.add_argument("--cache", action="store_true", help="cache dataset images (cache=True)")
    p.add_argument("--seed", type=int, default=0, help="random seed")
    p.add_argument(
        "--exist-ok",
        action="store_true",
        help="reuse an existing project/name directory instead of creating a new one",
    )
    p.add_argument(
        "--resume",
        action="store_true",
        help="resume an interrupted run (point --model at that run's last.pt)",
    )
    p.add_argument("--export", action="store_true", help="export ONNX/TorchScript after training")
    return p.parse_args(argv)
def _resolve_save_dir(model: object, args: argparse.Namespace) -> str:
    """Return the directory the finished training run was actually written to.

    Ultralytics does not always use ``project/name`` verbatim:
    - it creates a new, suffixed directory when that path already exists and
      ``exist_ok`` is False;
    - when resuming, it continues in the checkpoint's own run directory
      rather than the freshly generated ``--name``.

    The trainer's ``save_dir`` is therefore preferred. ``project/name`` is
    used only as a fallback when the model exposes no trainer/save_dir.
    """
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None)
    if save_dir:
        return str(save_dir)
    return os.path.join(args.project, args.name)


def main() -> None:
    """Parse CLI options, train a YOLO segmentation model, and report outputs.

    Steps:
    1. Import ultralytics.
    2. Check that the dataset yaml exists.
    3. Fill in a default run name ("<model stem>_<timestamp>") and create
       ``--project``.
    4. Print the configuration and run ``YOLO(model).train(...)``.
    5. Print where best.pt / last.pt were written.
    6. With ``--export``, export best.pt to ONNX and TorchScript.

    Exits with status 1 if ultralytics cannot be imported or the dataset
    yaml path does not exist.
    """
    args = parse_args()

    # Imported lazily so argument parsing (including --help) happens before
    # the heavy dependency is required.
    try:
        from ultralytics import YOLO
    except ImportError as e:
        # Only a missing/broken install is reported like this; any other
        # import-time failure propagates with its full traceback.
        print(
            "[error] ultralytics not found. Install with: pip install ultralytics",
            file=sys.stderr,
        )
        print(f"details: {e}", file=sys.stderr)
        sys.exit(1)

    # NOTE(review): this also rejects Ultralytics built-in dataset names
    # (e.g. "coco8-seg.yaml") that the library could resolve on its own;
    # relax the check if those are ever needed.
    if not os.path.exists(args.data):
        print(f"[error] dataset yaml not found: {args.data}", file=sys.stderr)
        sys.exit(1)

    if not args.name:
        model_stem = os.path.splitext(os.path.basename(args.model))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.name = f"{model_stem}_{timestamp}"
    os.makedirs(args.project, exist_ok=True)

    print("======== YOLOv8-seg Train ========")
    print(f"data : {args.data}")
    print(f"model : {args.model}")
    print(f"epochs : {args.epochs}")
    print(f"batch : {args.batch}")
    print(f"imgsz : {args.imgsz}")
    print(f"device : {args.device or 'auto'}")
    print(f"project : {args.project}")
    print(f"name : {args.name}")
    print("=================================")

    model = YOLO(args.model)
    model.train(
        data=args.data,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        # An empty --device means "let Ultralytics choose".
        device=args.device or None,
        project=args.project,
        name=args.name,
        pretrained=args.pretrained,
        cache=args.cache,
        workers=args.workers,
        patience=args.patience,
        lr0=args.lr0,
        seed=args.seed,
        exist_ok=args.exist_ok,
        resume=args.resume,
        verbose=True,
    )

    # Ask the trainer where it wrote the run instead of rebuilding
    # project/name, which is wrong after a rename or a resume.
    save_dir = _resolve_save_dir(model, args)
    weights_dir = os.path.join(save_dir, "weights")
    best_pt = os.path.join(weights_dir, "best.pt")
    last_pt = os.path.join(weights_dir, "last.pt")
    print("\n======== Train done ========")
    print(f"save_dir : {save_dir}")
    if os.path.exists(best_pt):
        print(f"best.pt : {best_pt}")
    if os.path.exists(last_pt):
        print(f"last.pt : {last_pt}")

    if not args.export:
        return
    if not os.path.exists(best_pt):
        print(f"[warn] --export requested but best.pt not found: {best_pt}", file=sys.stderr)
        return
    # Export is best-effort: a failure must not mask a completed training
    # run, so it is reported as a warning instead of being raised.
    try:
        exp = YOLO(best_pt)
        onnx_path = exp.export(format="onnx", imgsz=args.imgsz)
        ts_path = exp.export(format="torchscript", imgsz=args.imgsz)
    except Exception as e:
        print(f"[warn] export failed: {e}", file=sys.stderr)
        return
    print(f"export onnx : {onnx_path}")
    print(f"export torchscript: {ts_path}")


if __name__ == "__main__":
    main()