265 lines
8.3 KiB
Python
265 lines
8.3 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Dataset preparation script for YOLO training
|
|
- Reads train/xxx/images/ and val/xxx/images/ from source directory
|
|
- Converts labelme JSON format to YOLO format
|
|
- Generates YOLO-compatible dataset structure
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import shutil
|
|
import argparse
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Optional
|
|
from PIL import Image
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the dataset preparation script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with the string
        attributes ``source_dir``, ``output_dir`` and ``class_name``.
    """
    # One (flag, default, help) row per option; every option is a plain string.
    option_table = (
        (
            "--source_dir",
            "/home/ubuntu/data/fish/detection/svo_batch/data",
            "Source directory containing train/ and val/ subdirectories",
        ),
        (
            "--output_dir",
            "./yolo_dataset",
            "Output directory for YOLO dataset",
        ),
        (
            "--class_name",
            "fish",
            "Class name for YOLO labels (default: fish)",
        ),
    )

    cli = argparse.ArgumentParser(
        description="Prepare YOLO dataset from labelme JSON files"
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    return cli.parse_args()
|
|
|
|
|
|
def find_matching_pairs(images_dir: Path) -> List[Tuple[Path, Path]]:
    """Find labelme JSON files that have a PNG image with the same base name.

    Args:
        images_dir: Directory that holds the ``*.json`` annotations next to
            their ``*.png`` images.

    Returns:
        List of ``(json_path, png_path)`` tuples, sorted by JSON path so the
        result does not depend on filesystem enumeration order. JSON files
        without a matching PNG are reported with a warning and left out.
    """
    pairs = []

    # glob() yields entries in arbitrary order; sorting keeps runs reproducible.
    for json_file in sorted(images_dir.glob("*.json")):
        # The image is expected to share the JSON file's base name.
        png_file = json_file.with_suffix(".png")
        if png_file.is_file():
            pairs.append((json_file, png_file))
        else:
            print(f"[WARNING] No matching PNG for {json_file.name}, skipping")

    return pairs
|
|
|
|
|
|
def _clamp(value: float, upper: float) -> float:
    """Limit ``value`` to the closed interval [0, upper]."""
    return min(max(value, 0), upper)


def labelme_rectangle_to_yolo(points: List[List[float]], img_width: int, img_height: int) -> Optional[Tuple[float, float, float, float]]:
    """Convert a labelme rectangle to a normalized YOLO bounding box.

    Args:
        points: Two opposite corners ``[[x1, y1], [x2, y2]]`` in pixels; the
            corners may be given in any order.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        ``(x_center, y_center, width, height)``, each normalized to [0, 1],
        or ``None`` when the input is malformed, the image size is not
        positive, or the box has no area inside the image.
    """
    if len(points) != 2:
        return None

    # A non-positive image size would make the normalization divide by zero
    # or flip signs.
    if img_width <= 0 or img_height <= 0:
        return None

    try:
        (x1, y1), (x2, y2) = points
    except (TypeError, ValueError):
        # A corner that is not an (x, y) pair.
        return None

    # Clip the corners to the image so the emitted box never extends past the
    # image border; annotations that overhang an edge are kept (trimmed)
    # instead of producing out-of-image extents.
    x_lo, x_hi = sorted((_clamp(x1, img_width), _clamp(x2, img_width)))
    y_lo, y_hi = sorted((_clamp(y1, img_height), _clamp(y2, img_height)))

    width = x_hi - x_lo
    height = y_hi - y_lo

    # Zero area after clipping: degenerate, or entirely outside the image.
    if width <= 0 or height <= 0:
        return None

    x_center = x_lo + width / 2.0
    y_center = y_lo + height / 2.0

    # With clipped corners every normalized value is within [0, 1].
    return (
        x_center / img_width,
        y_center / img_height,
        width / img_width,
        height / img_height,
    )
|
|
|
|
|
|
def _shape_to_rect_points(shape):
    """Return ``[[x_min, y_min], [x_max, y_max]]``-style corners for a labelme shape, or None if unsupported."""
    if not isinstance(shape, dict):
        return None

    shape_type = shape.get('shape_type', '')
    points = shape.get('points') or []

    if shape_type == 'rectangle' and len(points) == 2:
        return points

    if shape_type == 'polygon' and len(points) >= 3:
        # A polygon is reduced to its axis-aligned bounding box.
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        return [[min(xs), min(ys)], [max(xs), max(ys)]]

    return None


def convert_labelme_to_yolo(json_path: Path, img_path: Path, class_id: int) -> List[str]:
    """Convert one labelme JSON annotation file to YOLO label lines.

    Rectangles (two points) are converted directly; polygons (three or more
    points) are converted via their bounding box. Every other shape type is
    ignored.

    Args:
        json_path: Path to the labelme JSON file.
        img_path: Path to the annotated image; it is opened only to read its
            pixel size for normalization.
        class_id: Class index written at the start of every label line.

    Returns:
        List of strings ``"class_id x_center y_center width height"`` with
        six decimals per coordinate. The list is empty when the JSON or the
        image cannot be read, or when no shape yields a valid box.
    """
    # Deliberately broad: one unreadable file must not abort the whole run.
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            labelme_data = json.load(f)
    except Exception as e:
        print(f"[ERROR] Failed to read {json_path}: {e}")
        return []

    # Valid JSON whose root is not an object carries no 'shapes' to convert.
    if not isinstance(labelme_data, dict):
        print(f"[ERROR] Failed to read {json_path}: top-level JSON value is not an object")
        return []

    # Get image dimensions
    try:
        with Image.open(img_path) as img:
            img_width, img_height = img.size
    except Exception as e:
        print(f"[ERROR] Failed to read image {img_path}: {e}")
        return []

    yolo_lines = []

    # `or []` also covers an explicit `"shapes": null`.
    for shape in labelme_data.get('shapes') or []:
        rect_points = _shape_to_rect_points(shape)
        if rect_points is None:
            continue

        yolo_bbox = labelme_rectangle_to_yolo(rect_points, img_width, img_height)
        if yolo_bbox is None:
            continue

        x_center, y_center, width, height = yolo_bbox
        yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

    return yolo_lines
|
|
|
|
|
|
def _unique_stem(stem: str, prefix: str, used_stems: set) -> str:
    """Return ``stem``, or a ``prefix``-qualified variant if ``stem`` is already in ``used_stems``."""
    if stem not in used_stems:
        return stem

    candidate = f"{prefix}_{stem}"
    counter = 1
    while candidate in used_stems:
        candidate = f"{prefix}_{stem}_{counter}"
        counter += 1
    return candidate


def process_split(source_dir: Path, split: str, output_dir: Path, class_id: int):
    """Convert one dataset split (e.g. "train" or "val") into YOLO layout.

    Every ``<source_dir>/<split>/<subfolder>/images`` directory is scanned for
    JSON/PNG pairs. For each pair that yields at least one valid label, the
    image is copied to ``<output_dir>/images/<split>`` and a label file with
    the same base name is written to ``<output_dir>/labels/<split>``. Pairs
    without valid labels are not copied.

    All subfolders are flattened into one output directory. When a file name
    was already used by an earlier subfolder in this run, the output name is
    prefixed with the current subfolder name so the earlier image and label
    are not overwritten.

    Args:
        source_dir: Root directory containing the split directories.
        split: Name of the split directory to process.
        output_dir: Root of the YOLO dataset being generated.
        class_id: Class index passed on to the label conversion.

    Returns:
        Number of images written for this split; 0 when the split directory
        does not exist.
    """
    split_dir = source_dir / split
    if not split_dir.exists():
        print(f"[WARNING] {split_dir} does not exist, skipping")
        return 0

    # Create output directories
    output_images_dir = output_dir / "images" / split
    output_labels_dir = output_dir / "labels" / split
    output_images_dir.mkdir(parents=True, exist_ok=True)
    output_labels_dir.mkdir(parents=True, exist_ok=True)

    processed_count = 0
    skipped_count = 0
    # Output base names already written during this call.
    used_stems = set()

    # Sorted so that the processing order is the same on every run.
    for subfolder in sorted(split_dir.iterdir()):
        if not subfolder.is_dir():
            continue

        images_dir = subfolder / "images"
        if not images_dir.exists():
            continue

        pairs = find_matching_pairs(images_dir)
        if not pairs:
            print(f"[SKIP] No JSON-PNG pairs found in {subfolder.name}")
            skipped_count += 1
            continue

        for json_path, img_path in pairs:
            yolo_lines = convert_labelme_to_yolo(json_path, img_path, class_id)
            if not yolo_lines:
                print(f"[SKIP] No valid labels in {json_path.name}")
                continue

            # Image and label must share one base name so they stay paired.
            stem = _unique_stem(img_path.stem, subfolder.name, used_stems)
            used_stems.add(stem)

            # Copy image to output directory
            dst_img = output_images_dir / f"{stem}{img_path.suffix}"
            shutil.copy2(img_path, dst_img)

            # Write YOLO label file
            dst_label = output_labels_dir / f"{stem}.txt"
            with open(dst_label, 'w', encoding='utf-8') as f:
                f.write('\n'.join(yolo_lines) + '\n')

            processed_count += 1

    print(f"[{split.upper()}] Processed {processed_count} images, skipped {skipped_count} subfolders")
    return processed_count
|
|
|
|
|
|
def generate_dataset_yaml(output_dir: Path, class_name: str):
    """Write ``dataset.yaml`` into ``output_dir``.

    The file records the resolved dataset root, the relative train and val
    image directories, and a one-element class name list.

    Args:
        output_dir: Dataset root; must already exist.
        class_name: Name of the single class in the dataset.
    """
    yaml_path = output_dir / "dataset.yaml"

    yaml_lines = [
        f"path: {output_dir.resolve()}",
        "train: images/train",
        "val: images/val",
        f"names: [{class_name}]",
    ]
    yaml_path.write_text("\n".join(yaml_lines) + "\n", encoding="utf-8")

    print(f"[OK] Generated dataset.yaml: {yaml_path}")
|
|
|
|
|
|
def main():
    """Run the conversion: parse options, process both splits, write dataset.yaml.

    Prints an error and returns early when the source directory is missing;
    otherwise prints a summary with the per-split image counts.
    """
    args = parse_args()
    source_dir, output_dir = Path(args.source_dir), Path(args.output_dir)

    if not source_dir.exists():
        print(f"[ERROR] Source directory does not exist: {source_dir}")
        return

    # Make sure the dataset root exists before anything is written into it.
    output_dir.mkdir(parents=True, exist_ok=True)

    separator = "-" * 60
    for banner_line in (
        f"Source directory: {source_dir}",
        f"Output directory: {output_dir}",
        f"Class name: {args.class_name}",
        separator,
    ):
        print(banner_line)

    # Single-class dataset: the only class gets index 0.
    class_id = 0
    counts = {
        split: process_split(source_dir, split, output_dir, class_id)
        for split in ("train", "val")
    }

    generate_dataset_yaml(output_dir, args.class_name)

    print(separator)
    print("[SUMMARY]")
    print(f" Train images: {counts['train']}")
    print(f" Val images: {counts['val']}")
    print(f" Total images: {counts['train'] + counts['val']}")
    print(f" Dataset ready at: {output_dir}")
|
|
|
|
|
|
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|