Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
151
FishMeasure/segmentation/train_yolo_seg.py
Executable file
151
FishMeasure/segmentation/train_yolo_seg.py
Executable file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Ultralytics YOLOv8 segmentation training script.
|
||||
|
||||
Example (using filtered dataset):
|
||||
python3 segmentation/train_yolo_seg.py \
|
||||
--data ./datasets/fish_body_seg_filtered/dataset.yaml \
|
||||
--model yolo26s-seg.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--imgsz 640 \
|
||||
--project runs/seg \
|
||||
--name fish_body_seg_$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
Example (with more options):
|
||||
python3 segmentation/train_yolo_seg.py \
|
||||
--data ./datasets/fish_body_seg_filtered/dataset.yaml \
|
||||
--model yolov8s-seg.pt \
|
||||
--epochs 300 \
|
||||
--batch 32 \
|
||||
--imgsz 640 \
|
||||
--device 0 \
|
||||
--workers 8 \
|
||||
--patience 50 \
|
||||
--pretrained \
|
||||
--cache \
|
||||
--project runs/seg \
|
||||
--name fish_body_seg_yolov8s_$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
Dependency:
|
||||
pip install ultralytics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Ultralytics YOLOv8-seg training")
|
||||
p.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
default="./datasets/fish_body_seg_filtered/dataset.yaml",
|
||||
help="dataset.yaml path (default: ./datasets/fish_body_seg_filtered/dataset.yaml)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="yolo26l-seg.pt",
|
||||
help="model weights/arch, e.g. yolov8n-seg.pt/yolov8s-seg.pt or your .pt",
|
||||
)
|
||||
p.add_argument("--epochs", type=int, default=100)
|
||||
p.add_argument("--batch", type=int, default=16)
|
||||
p.add_argument("--imgsz", type=int, default=640)
|
||||
p.add_argument("--device", type=str, default="", help="CUDA device like '0' or '0,1'. Empty=auto")
|
||||
p.add_argument("--project", type=str, default="runs/seg", help="output project dir")
|
||||
p.add_argument("--name", type=str, default="", help="run name (default: model + timestamp)")
|
||||
p.add_argument("--workers", type=int, default=8)
|
||||
p.add_argument("--patience", type=int, default=50)
|
||||
p.add_argument("--lr0", type=float, default=0.01)
|
||||
p.add_argument("--pretrained", action="store_true", help="use pretrained weights")
|
||||
p.add_argument("--cache", action="store_true")
|
||||
p.add_argument("--seed", type=int, default=0)
|
||||
p.add_argument("--exist-ok", action="store_true")
|
||||
p.add_argument("--resume", action="store_true")
|
||||
p.add_argument("--export", action="store_true", help="export ONNX/TorchScript after training")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
except Exception as e:
|
||||
print("[error] ultralytics not found. Install with: pip install ultralytics")
|
||||
print(f"details: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.exists(args.data):
|
||||
print(f"[error] dataset yaml not found: {args.data}")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.name:
|
||||
model_stem = os.path.splitext(os.path.basename(args.model))[0]
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
args.name = f"{model_stem}_{timestamp}"
|
||||
|
||||
os.makedirs(args.project, exist_ok=True)
|
||||
|
||||
print("======== YOLOv8-seg Train ========")
|
||||
print(f"data : {args.data}")
|
||||
print(f"model : {args.model}")
|
||||
print(f"epochs : {args.epochs}")
|
||||
print(f"batch : {args.batch}")
|
||||
print(f"imgsz : {args.imgsz}")
|
||||
print(f"device : {args.device or 'auto'}")
|
||||
print(f"project : {args.project}")
|
||||
print(f"name : {args.name}")
|
||||
print("=================================")
|
||||
|
||||
model = YOLO(args.model)
|
||||
model.train(
|
||||
data=args.data,
|
||||
epochs=args.epochs,
|
||||
imgsz=args.imgsz,
|
||||
batch=args.batch,
|
||||
device=args.device if args.device else None,
|
||||
project=args.project,
|
||||
name=args.name,
|
||||
pretrained=args.pretrained,
|
||||
cache=args.cache,
|
||||
workers=args.workers,
|
||||
patience=args.patience,
|
||||
lr0=args.lr0,
|
||||
seed=args.seed,
|
||||
exist_ok=args.exist_ok,
|
||||
resume=args.resume,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
save_dir = os.path.join(args.project, args.name)
|
||||
best_pt = os.path.join(save_dir, "weights", "best.pt")
|
||||
last_pt = os.path.join(save_dir, "weights", "last.pt")
|
||||
print("\n======== Train done ========")
|
||||
print(f"save_dir : {save_dir}")
|
||||
if os.path.exists(best_pt):
|
||||
print(f"best.pt : {best_pt}")
|
||||
if os.path.exists(last_pt):
|
||||
print(f"last.pt : {last_pt}")
|
||||
|
||||
if args.export and os.path.exists(best_pt):
|
||||
try:
|
||||
exp = YOLO(best_pt)
|
||||
onnx_path = exp.export(format="onnx", imgsz=args.imgsz)
|
||||
ts_path = exp.export(format="torchscript", imgsz=args.imgsz)
|
||||
print(f"export onnx : {onnx_path}")
|
||||
print(f"export torchscript: {ts_path}")
|
||||
except Exception as e:
|
||||
print(f"[warn] export failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user