Files
FishServer/FishMeasure/dataset/dataset.py
2026-05-06 15:59:38 +08:00

265 lines
8.3 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataset preparation script for YOLO training
- Reads train/xxx/images/ and val/xxx/images/ from source directory
- Converts labelme JSON format to YOLO format
- Generates YOLO-compatible dataset structure
"""
import os
import json
import shutil
import argparse
from pathlib import Path
from typing import List, Tuple, Optional
from PIL import Image
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse the command-line options of the dataset preparation script.

    Args:
        argv: Argument strings to parse instead of the process command line.
            When None (the default) argparse falls back to sys.argv[1:],
            which is what a plain ``parse_args()`` call has always done.

    Returns:
        Namespace with the string attributes ``source_dir``, ``output_dir``
        and ``class_name``.
    """
    parser = argparse.ArgumentParser(
        description="Prepare YOLO dataset from labelme JSON files"
    )
    parser.add_argument(
        "--source_dir",
        type=str,
        default="/home/ubuntu/data/fish/detection/svo_batch/data",
        help="Source directory containing train/ and val/ subdirectories",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./yolo_dataset",
        help="Output directory for YOLO dataset",
    )
    parser.add_argument(
        "--class_name",
        type=str,
        default="fish",
        help="Class name for YOLO labels (default: fish)",
    )
    # Passing argv through lets callers (and tests) drive the parser without
    # touching the global sys.argv.
    return parser.parse_args(argv)
def find_matching_pairs(images_dir: Path) -> List[Tuple[Path, Path]]:
    """
    Find all JSON files in a directory that have a PNG with the same base name.

    Args:
        images_dir: Directory to scan (not searched recursively).

    Returns:
        List of (json_path, png_path) tuples, sorted by JSON path. JSON files
        without a matching ``.png`` file are reported with a warning and
        left out.
    """
    pairs = []
    # Path.glob yields entries in filesystem-dependent order; sorting keeps
    # the result (and the warnings printed below) reproducible between runs.
    for json_file in sorted(images_dir.glob("*.json")):
        # The image is expected next to the annotation, same stem, .png suffix
        png_file = json_file.with_suffix(".png")
        if png_file.is_file():
            pairs.append((json_file, png_file))
        else:
            print(f"[WARNING] No matching PNG for {json_file.name}, skipping")
    return pairs
def labelme_rectangle_to_yolo(points: List[List[float]], img_width: int, img_height: int) -> Optional[Tuple[float, float, float, float]]:
    """
    Convert labelme rectangle format to YOLO format.

    Labelme rectangle: [[x1, y1], [x2, y2]] (two opposite corners in pixels,
    accepted in any order).
    YOLO format: (x_center, y_center, width, height) normalized [0, 1]

    The rectangle is clipped to the image before it is normalized, so a box
    that overhangs an image edge is kept as its visible part instead of being
    dropped or written with extents outside the image.

    Returns:
        The normalized box, or None when the input cannot produce a valid
        box: not exactly two corner points, a corner that is not an (x, y)
        pair, a non-positive image size, a zero-area rectangle, or a
        rectangle lying completely outside the image.
    """
    if len(points) != 2:
        return None
    # Normalizing by a zero or negative size would divide by zero or flip signs
    if img_width <= 0 or img_height <= 0:
        return None
    try:
        (x1, y1), (x2, y2) = points
    except (TypeError, ValueError):
        # A corner is not a two-element sequence
        return None

    # Order the corners so that min < max on both axes
    x_min, x_max = min(x1, x2), max(x1, x2)
    y_min, y_max = min(y1, y2), max(y1, y2)
    # Strict comparison rejects zero-area boxes; it is also False whenever a
    # coordinate is NaN, so NaN input never reaches the clipping step.
    if not (x_min < x_max and y_min < y_max):
        return None

    # Clip to the image rectangle [0, img_width] x [0, img_height]
    x_min = max(0.0, x_min)
    y_min = max(0.0, y_min)
    x_max = min(float(img_width), x_max)
    y_max = min(float(img_height), y_max)

    # Calculate center and dimensions of the clipped box
    width = x_max - x_min
    height = y_max - y_min
    if width <= 0 or height <= 0:
        # Nothing of the box is left inside the image
        return None
    x_center = x_min + width / 2.0
    y_center = y_min + height / 2.0

    # Normalize to [0, 1]
    return (
        x_center / img_width,
        y_center / img_height,
        width / img_width,
        height / img_height,
    )
def _shape_to_rect_points(shape: dict) -> Optional[List[List[float]]]:
    """
    Return the two-corner rectangle for one labelme shape, or None if the
    shape is not a usable rectangle or polygon.
    """
    shape_type = shape.get('shape_type', '')
    # "points": null in the JSON is treated like a missing/empty list
    points = shape.get('points') or []
    if shape_type == 'rectangle' and len(points) == 2:
        return points
    if shape_type == 'polygon' and len(points) >= 3:
        # For polygons, use the axis-aligned bounding box of the vertices
        try:
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
        except (TypeError, IndexError, KeyError):
            # A vertex is not an (x, y) sequence
            return None
        return [[min(xs), min(ys)], [max(xs), max(ys)]]
    return None


def convert_labelme_to_yolo(json_path: Path, img_path: Path, class_id: int) -> List[str]:
    """
    Convert labelme JSON to YOLO format label lines.

    Rectangle shapes are converted directly; polygon shapes are converted via
    their bounding box. Every shape is written with the same ``class_id``;
    the shape's own label is not consulted. Other shape types are ignored.

    Args:
        json_path: labelme annotation file.
        img_path: Image the annotation belongs to; opened only to read its
            pixel size for normalization.
        class_id: Class index written at the start of every label line.

    Returns:
        List of YOLO label strings: "class_id x_center y_center width height".
        The list is empty when either file cannot be read (an [ERROR] line is
        printed) or when no shape yields a valid box.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            labelme_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        print(f"[ERROR] Failed to read {json_path}: {e}")
        return []
    if not isinstance(labelme_data, dict):
        print(f"[ERROR] Failed to read {json_path}: top-level JSON value is not an object")
        return []

    # Get image dimensions
    try:
        with Image.open(img_path) as img:
            img_width, img_height = img.size
    except Exception as e:
        # Kept broad on purpose: one unreadable image must not abort the run
        print(f"[ERROR] Failed to read image {img_path}: {e}")
        return []

    yolo_lines = []
    # "shapes": null is treated like an empty list
    for shape in labelme_data.get('shapes') or []:
        if not isinstance(shape, dict):
            continue
        rect_points = _shape_to_rect_points(shape)
        if rect_points is None:
            continue
        yolo_bbox = labelme_rectangle_to_yolo(rect_points, img_width, img_height)
        if yolo_bbox is None:
            continue
        x_center, y_center, width, height = yolo_bbox
        yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
    return yolo_lines
def process_split(source_dir: Path, split: str, output_dir: Path, class_id: int) -> int:
    """
    Process train or val split.

    Scans ``source_dir/<split>/<subfolder>/images`` for JSON-PNG pairs,
    converts each annotation to YOLO label lines and writes the image to
    ``output_dir/images/<split>`` and the labels to
    ``output_dir/labels/<split>``. Pairs that yield no valid label are not
    copied.

    All subfolders share one flat output directory. When a later subfolder
    contains a file name that was already written during this call, the
    output name is prefixed with the subfolder name so the earlier image and
    label are not overwritten.

    Args:
        source_dir: Root directory holding the split directories.
        split: Split directory name, e.g. "train" or "val".
        output_dir: Root of the generated dataset.
        class_id: Class index written into every label line.

    Returns:
        Number of image/label pairs written; 0 if the split directory does
        not exist.
    """
    split_dir = source_dir / split
    if not split_dir.exists():
        print(f"[WARNING] {split_dir} does not exist, skipping")
        return 0

    # Create output directories
    output_images_dir = output_dir / "images" / split
    output_labels_dir = output_dir / "labels" / split
    output_images_dir.mkdir(parents=True, exist_ok=True)
    output_labels_dir.mkdir(parents=True, exist_ok=True)

    processed_count = 0
    skipped_count = 0
    # Output stems already written in this call, used to detect name clashes
    used_stems = set()

    # Iterate through all subdirectories
    for subfolder in sorted(split_dir.iterdir()):
        if not subfolder.is_dir():
            continue
        images_dir = subfolder / "images"
        if not images_dir.exists():
            continue

        # Find all JSON-PNG pairs
        pairs = find_matching_pairs(images_dir)
        if not pairs:
            print(f"[SKIP] No JSON-PNG pairs found in {subfolder.name}")
            skipped_count += 1
            continue

        # Process each pair
        for json_path, img_path in pairs:
            # Convert labelme to YOLO
            yolo_lines = convert_labelme_to_yolo(json_path, img_path, class_id)
            if not yolo_lines:
                print(f"[SKIP] No valid labels in {json_path.name}")
                continue

            # Keep the original name unless another subfolder already used it
            stem = img_path.stem
            if stem in used_stems:
                stem = f"{subfolder.name}_{img_path.stem}"
                attempt = 1
                while stem in used_stems:
                    stem = f"{subfolder.name}_{img_path.stem}_{attempt}"
                    attempt += 1
                print(f"[WARNING] Duplicate image name {img_path.name}, saving as {stem}{img_path.suffix}")
            used_stems.add(stem)

            # Copy image to output directory
            dst_img = output_images_dir / (stem + img_path.suffix)
            shutil.copy2(img_path, dst_img)

            # Write YOLO label file (same stem as the image)
            dst_label = output_labels_dir / (stem + ".txt")
            with open(dst_label, 'w', encoding='utf-8') as f:
                f.write('\n'.join(yolo_lines) + '\n')
            processed_count += 1

    print(f"[{split.upper()}] Processed {processed_count} images, skipped {skipped_count} subfolders")
    return processed_count
def generate_dataset_yaml(output_dir: Path, class_name: str) -> None:
    """
    Generate dataset.yaml file for YOLO training.

    The file is written to ``output_dir/dataset.yaml`` and contains the
    absolute dataset path, the relative train/val image directories and a
    single-entry class name list.

    Args:
        output_dir: Dataset root; must already exist.
        class_name: Name of the only class in the dataset.
    """
    yaml_path = output_dir / "dataset.yaml"
    # A JSON string literal is also a valid YAML double-quoted scalar, so
    # json.dumps gives correct quoting/escaping for paths and names that
    # contain YAML-significant characters (": ", ",", "#", "]", ...).
    quoted_root = json.dumps(str(output_dir.resolve()), ensure_ascii=False)
    quoted_name = json.dumps(class_name, ensure_ascii=False)
    content = (
        f"path: {quoted_root}\n"
        "train: images/train\n"
        "val: images/val\n"
        f"names: [{quoted_name}]\n"
    )
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f"[OK] Generated dataset.yaml: {yaml_path}")
def main() -> int:
    """
    Run the dataset preparation: parse options, convert the train and val
    splits, write dataset.yaml and print a summary.

    Returns:
        Process exit status: 0 on success, 1 when the source directory does
        not exist (nothing is written in that case).
    """
    args = parse_args()
    source_dir = Path(args.source_dir)
    output_dir = Path(args.output_dir)

    if not source_dir.exists():
        print(f"[ERROR] Source directory does not exist: {source_dir}")
        # Non-zero status so shells and pipelines can detect the failure
        return 1

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Source directory: {source_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Class name: {args.class_name}")
    print("-" * 60)

    # Process train and val splits
    class_id = 0  # YOLO uses 0-indexed class IDs
    train_count = process_split(source_dir, "train", output_dir, class_id)
    val_count = process_split(source_dir, "val", output_dir, class_id)

    # Generate dataset.yaml
    generate_dataset_yaml(output_dir, args.class_name)

    print("-" * 60)
    print("[SUMMARY]")
    print(f" Train images: {train_count}")
    print(f" Val images: {val_count}")
    print(f" Total images: {train_count + val_count}")
    print(f" Dataset ready at: {output_dir}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code
    raise SystemExit(main())