Files
FishServer/FishMeasure/segmentation/visualize_yolo_seg_labels.py
2026-04-08 19:32:23 +08:00

216 lines
6.6 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Visualize YOLOv8-seg labels by drawing polygon masks on images.
This helps verify that label conversion (e.g., from Labelme JSON to YOLO .txt) is correct.
Example:
python3 segmentation/visualize_yolo_seg_labels.py \
--dataset ./datasets/fish_body_seg_filtered \
--output ./visualizations/yolo_seg_labels \
--max_images 50 \
--split train
"""
from __future__ import annotations

import argparse
import random
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import numpy as np
def parse_yolo_seg_label_with_classes(
    label_path: Path, img_w: int, img_h: int
) -> Tuple[List[np.ndarray], List[int]]:
    """
    Parse a YOLO segmentation label file, keeping each polygon's class id.

    Each line has the form ``class_id x1 y1 x2 y2 ... xn yn`` with coordinates
    normalized to [0, 1]. Lines that are blank, have fewer than 3 points, have an
    odd number of coordinates, or contain non-numeric / non-finite values are
    skipped, so one malformed line does not hide the rest of the file.

    Args:
        label_path: Path to the ``.txt`` label file.
        img_w: Image width in pixels (scales x coordinates).
        img_h: Image height in pixels (scales y coordinates).

    Returns:
        ``(polygons, class_ids)`` where each polygon is an ``Nx2`` int32 array of
        pixel coordinates and ``class_ids[i]`` is the class of ``polygons[i]``.
        Both lists are empty if ``label_path`` does not exist.
    """
    polygons: List[np.ndarray] = []
    class_ids: List[int] = []
    if not label_path.exists():
        return polygons, class_ids

    scale = np.array([img_w, img_h], dtype=np.float64)
    # "utf-8-sig" transparently drops a leading BOM, which would otherwise make
    # int() reject the first class id and silently lose the first polygon.
    for line in label_path.read_text(encoding="utf-8-sig").splitlines():
        parts = line.split()
        if len(parts) < 7:  # class_id + at least 3 points (x,y pairs)
            continue
        try:
            class_id = int(parts[0])
            coords = np.array([float(tok) for tok in parts[1:]], dtype=np.float64)
        except ValueError:
            continue
        # nan/inf would otherwise crash or produce garbage when cast to int32.
        if coords.size % 2 != 0 or not np.isfinite(coords).all():
            continue
        # Round rather than truncate: normalized values carry float error
        # (0.29 * 100 == 28.999999999999996), and int() would shift such
        # points one pixel toward the origin.
        points = np.rint(coords.reshape(-1, 2) * scale).astype(np.int32)
        polygons.append(points)
        class_ids.append(class_id)
    return polygons, class_ids


def parse_yolo_seg_label(label_path: Path, img_w: int, img_h: int) -> List[np.ndarray]:
    """
    Parse YOLO segmentation label file.

    Returns list of polygons (each as an Nx2 int32 numpy array in pixel
    coordinates). Class ids are discarded; use
    ``parse_yolo_seg_label_with_classes`` to keep them.
    """
    polygons, _class_ids = parse_yolo_seg_label_with_classes(label_path, img_w, img_h)
    return polygons


def draw_polygons_on_image(
    img: np.ndarray,
    polygons: List[np.ndarray],
    class_colors: List[Tuple[int, int, int]],
    alpha: float = 0.5,
    class_ids: Optional[List[int]] = None,
) -> np.ndarray:
    """
    Draw polygons as semi-transparent masks with opaque outlines.

    Args:
        img: BGR image; it is not modified.
        polygons: Nx2 int32 pixel-coordinate polygons.
        class_colors: BGR colors; must be non-empty.
        alpha: Weight of the filled overlay (0.0 = invisible fill, 1.0 = opaque).
        class_ids: Optional class id per polygon. When given, each polygon is
            colored ``class_colors[class_id % len(class_colors)]`` so all
            instances of a class share a color; otherwise colors cycle by
            polygon index.

    Returns:
        A new image with the overlays applied.

    Raises:
        ValueError: if ``class_colors`` is empty or ``class_ids`` does not have
            one entry per polygon.
    """
    if not class_colors:
        raise ValueError("class_colors must contain at least one color")
    if class_ids is not None and len(class_ids) != len(polygons):
        raise ValueError("class_ids must have one entry per polygon")

    keys = class_ids if class_ids is not None else range(len(polygons))
    colors = [class_colors[key % len(class_colors)] for key in keys]

    overlay = img.copy()
    for poly, color in zip(polygons, colors):
        cv2.fillPoly(overlay, [poly], color)
    # Blend overlay with original
    result = cv2.addWeighted(overlay, alpha, img, 1.0 - alpha, 0)
    # Outlines go on after blending so they stay fully opaque; drawn on the
    # overlay in the fill color they would just blend into the fill.
    for poly, color in zip(polygons, colors):
        cv2.polylines(result, [poly], isClosed=True, color=color, thickness=2)
    return result


def visualize_dataset(
    dataset_dir: Path,
    output_dir: Path,
    split: str = "train",
    max_images: int = 50,
    class_names: Optional[List[str]] = None,
    alpha: float = 0.5,
) -> None:
    """
    Visualize YOLO segmentation labels on images.

    Reads ``<dataset_dir>/images/<split>`` and ``<dataset_dir>/labels/<split>``,
    pairs each image with the same-stem ``.txt`` label, and writes
    ``<stem>_vis<ext>`` (or ``<stem>_no_labels<ext>`` when no polygon could be
    parsed) into ``output_dir``. Polygons are colored by class id.

    Args:
        dataset_dir: YOLO-seg dataset root.
        output_dir: Destination directory (created if missing).
        split: Sub-directory name under ``images/`` and ``labels/``.
        max_images: If > 0, a reproducible random sample of at most this many pairs.
        class_names: Determines how many palette colors are used; defaults to
            three placeholder classes.
        alpha: Mask overlay transparency passed to ``draw_polygons_on_image``.

    Raises:
        SystemExit: if a required directory is missing or no pairs are found.
    """
    dataset_dir = dataset_dir.expanduser().resolve()
    if not dataset_dir.exists():
        raise SystemExit(f"dataset_dir not found: {dataset_dir}")
    img_dir = dataset_dir / "images" / split
    lbl_dir = dataset_dir / "labels" / split
    if not img_dir.exists():
        raise SystemExit(f"images directory not found: {img_dir}")
    if not lbl_dir.exists():
        raise SystemExit(f"labels directory not found: {lbl_dir}")

    # Collect image-label pairs
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    pairs: List[Tuple[Path, Path]] = []
    for img_path in sorted(img_dir.iterdir()):
        if not img_path.is_file() or img_path.suffix.lower() not in image_suffixes:
            continue
        lbl_path = lbl_dir / f"{img_path.stem}.txt"
        if lbl_path.exists():
            pairs.append((img_path, lbl_path))
    if not pairs:
        raise SystemExit(f"No image-label pairs found in {split} split")

    # Limit number of images. A private fixed-seed RNG yields the same sample
    # as seeding the global one, without reseeding `random` process-wide.
    if max_images > 0 and len(pairs) > max_images:
        pairs = random.Random(42).sample(pairs, max_images)
    print(f"Visualizing {len(pairs)} images from {split} split...")

    # Generate colors for classes (BGR format for OpenCV)
    if class_names is None:
        class_names = ["class0", "class1", "class2"]
    palette = [
        (0, 255, 0),  # green
        (255, 0, 0),  # blue
        (0, 0, 255),  # red
        (255, 255, 0),  # cyan
        (255, 0, 255),  # magenta
        (0, 255, 255),  # yellow
    ]
    # Keep at least one color so an empty class list cannot break color lookup.
    class_colors = palette[: max(1, len(class_names))]

    output_dir = output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    saved = 0
    for img_path, lbl_path in pairs:
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"[warn] failed to load: {img_path}")
            continue
        h, w = img.shape[:2]
        polygons, class_ids = parse_yolo_seg_label_with_classes(lbl_path, w, h)
        if polygons:
            out_img = draw_polygons_on_image(
                img, polygons, class_colors, alpha=alpha, class_ids=class_ids
            )
            out_path = output_dir / f"{img_path.stem}_vis{img_path.suffix}"
        else:
            print(f"[warn] no polygons found in: {lbl_path}")
            # Still save the original image for reference
            out_img = img
            out_path = output_dir / f"{img_path.stem}_no_labels{img_path.suffix}"
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(str(out_path), out_img):
            saved += 1
        else:
            print(f"[warn] failed to write: {out_path}")
    print(f"[done] saved {saved} visualizations to: {output_dir}")
def main() -> None:
    """Parse command-line arguments and render label overlays for one dataset split."""
    parser = argparse.ArgumentParser(description="Visualize YOLOv8-seg labels on images")
    parser.add_argument("--dataset", type=str, required=True, help="Path to YOLO-seg dataset directory")
    parser.add_argument("--output", type=str, required=True, help="Output directory for visualizations")
    parser.add_argument(
        "--split", type=str, default="train", choices=["train", "val", "test"], help="Dataset split to visualize"
    )
    parser.add_argument("--max_images", type=int, default=50, help="Maximum number of images to visualize (0=all)")
    parser.add_argument(
        "--classes",
        type=str,
        default="fishbody",
        help="Comma-separated class names (for color assignment, e.g. 'fishbody' or 'body,fin,tail')",
    )
    parser.add_argument(
        "--alpha", type=float, default=0.5, help="Transparency of mask overlay (0.0=transparent, 1.0=opaque)"
    )
    args = parser.parse_args()
    # The overlay is blended with weights alpha and 1 - alpha; a value outside
    # [0, 1] (or NaN) yields a negative/invalid weight and a distorted image.
    if not 0.0 <= args.alpha <= 1.0:
        parser.error(f"--alpha must be between 0.0 and 1.0, got {args.alpha}")

    class_names = [c.strip() for c in (args.classes or "").split(",") if c.strip()]
    if not class_names:
        class_names = ["class0"]

    visualize_dataset(
        dataset_dir=Path(args.dataset),
        output_dir=Path(args.output),
        split=args.split,
        max_images=args.max_images,
        class_names=class_names,
        alpha=args.alpha,
    )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()