Files
FishServer/FishMeasure/detection/prepare_yolo_dataset.py

296 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
YOLO 数据预处理脚本
- 从 source_dir (默认: ../data) 自动发现 COCO 格式的 JSON 标注
- 按比例切分 train/val/test
- 转换为 YOLO TXT 标注,并整理为 Ultralytics YOLO 所需目录结构
- 生成 dataset.yaml 便于训练
用法示例:
python prepare_yolo_dataset.py \
--source_dir ../data \
--out_dir ../datasets/yolo_dataset \
--train_ratio 0.8 --val_ratio 0.1 --test_ratio 0.1 \
--copy
注意:
- 默认假设 JSON 为 COCO 格式images / annotations / categories
- 如果同目录存在多个 JSON会选择包含 images 数量最多的一个
- 支持拷贝(--copy)或硬链接(--link默认)方式整理图片
"""
import os
import sys
import json
import argparse
import random
import shutil
import unicodedata
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from urllib.parse import unquote
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="准备YOLO数据集(从COCO JSON)")
parser.add_argument("--source_dir", type=str, default="../data", help="原始数据根目录包含图像与json")
parser.add_argument("--json", type=str, default="", help="指定COCO json路径可选默认自动发现")
parser.add_argument("--out_dir", type=str, default="../datasets/yolo_dataset", help="输出数据集根目录")
parser.add_argument("--train_ratio", type=float, default=0.8, help="训练集比例")
parser.add_argument("--val_ratio", type=float, default=0.1, help="验证集比例")
parser.add_argument("--test_ratio", type=float, default=0.1, help="测试集比例")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
group = parser.add_mutually_exclusive_group()
group.add_argument("--copy", action="store_true", help="复制图片到目标目录(默认使用硬链接)")
group.add_argument("--symlink", action="store_true", help="使用软链接")
return parser.parse_args()
def find_images(root: Path) -> List[Path]:
return [p for p in root.rglob("*") if p.suffix.lower() in IMG_EXTS]
def auto_find_json(source_dir: Path) -> Path:
jsons = list(source_dir.rglob("*.json"))
if not jsons:
raise FileNotFoundError(f"未在 {source_dir} 下找到任何json标注文件")
# 优先选择images数量最多的coco json
best = None
best_n = -1
for j in jsons:
try:
with open(j, "r", encoding="utf-8") as f:
data = json.load(f)
n = len(data.get("images", []))
if n > best_n:
best, best_n = j, n
except Exception:
continue
if best is None:
raise RuntimeError("未找到可用的COCO json (无images字段)")
return best
def load_coco(json_path: Path):
with open(json_path, "r", encoding="utf-8") as f:
coco = json.load(f)
images = {img["id"]: img for img in coco.get("images", [])}
categories = {cat["id"]: cat for cat in coco.get("categories", [])}
# 将annotations按照image_id聚合
anns_by_img: Dict[int, List[Dict]] = {}
for ann in coco.get("annotations", []):
if ann.get("iscrowd", 0):
continue
img_id = ann["image_id"]
anns_by_img.setdefault(img_id, []).append(ann)
return coco, images, anns_by_img, categories
def bbox_coco_to_yolo(bbox: List[float], img_w: int, img_h: int):
# COCO: [x_min, y_min, w, h] (像素)
x_min, y_min, w, h = bbox
x_c = x_min + w / 2.0
y_c = y_min + h / 2.0
return x_c / img_w, y_c / img_h, w / img_w, h / img_h
def ensure_dirs(root: Path):
for sub in ["images/train", "images/val", "images/test", "labels/train", "labels/val", "labels/test"]:
(root / sub).mkdir(parents=True, exist_ok=True)
def place_image(src: Path, dst: Path, mode: str):
dst.parent.mkdir(parents=True, exist_ok=True)
if mode == "copy":
shutil.copy2(src, dst)
elif mode == "symlink":
if dst.exists():
dst.unlink()
os.symlink(src, dst)
else: # hardlink
if dst.exists():
dst.unlink()
# 对跨设备的硬链接失败回退为复制
try:
os.link(src, dst)
except OSError:
shutil.copy2(src, dst)
def write_label(label_path: Path, lines: List[str]):
label_path.parent.mkdir(parents=True, exist_ok=True)
with open(label_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + ("\n" if lines else ""))
def resolve_image_path(source_dir: Path, file_name: str) -> Optional[Path]:
"""尽最大可能解析COCO中的file_name为实际图片路径。"""
candidates = sanitize_file_name(file_name)
# 1) 直接拼接相对/绝对路径
for c in candidates:
p = (source_dir / c).resolve() if not os.path.isabs(c) else Path(c)
if p.exists() and p.suffix.lower() in IMG_EXTS:
return p
# 2) 在source_dir下用basename递归查找
base_names = [Path(c).name for c in candidates]
base_names = list(dict.fromkeys([b for b in base_names if b]))
for b in base_names:
found = list(source_dir.rglob(b))
for p in found:
if p.suffix.lower() in IMG_EXTS:
return p
# 3) 用stem + 扩展名的通配符在全目录搜索
for c in candidates:
stem = Path(c).stem
if not stem:
continue
for ext in IMG_EXTS:
found = list(source_dir.rglob(f"{stem}{ext}"))
if found:
return found[0]
return None
def generate_yaml(out_dir: Path, names: List[str]):
yaml_path = out_dir / "dataset.yaml"
content = (
f"path: {out_dir.resolve()}\n"
f"train: images/train\n"
f"val: images/val\n"
f"test: images/test\n"
f"names: {names}\n"
)
with open(yaml_path, "w", encoding="utf-8") as f:
f.write(content)
print(f"[OK] 生成数据集配置: {yaml_path}")
def main():
args = parse_args()
random.seed(args.seed)
source_dir = Path(args.source_dir).resolve()
out_dir = Path(args.out_dir).resolve()
json_path = Path(args.json).resolve() if args.json else None
if not source_dir.exists():
print(f"[错误] source_dir 不存在: {source_dir}")
sys.exit(1)
# 自动发现JSON
if json_path is None:
try:
json_path = auto_find_json(source_dir)
except Exception as e:
print(f"[错误] 自动发现json失败: {e}")
sys.exit(1)
print(f"使用标注: {json_path}")
# 加载COCO
coco, images, anns_by_img, categories = load_coco(json_path)
# 类别名列表按cat_id升序映射到从0开始的class_id
cat_id_list = sorted(categories.keys())
catid2cls = {cat_id: idx for idx, cat_id in enumerate(cat_id_list)}
names = [categories[cid]["name"] for cid in cat_id_list]
# 收集有效图像
valid: List[Tuple[int, Path]] = []
miss_count = 0
for img in images.values():
fn = img.get("file_name", "")
p = resolve_image_path(source_dir, fn)
if p is None:
miss_count += 1
if miss_count <= 20:
print(f"[警告] 未找到文件: '{fn}' (将跳过)")
continue
valid.append((img["id"], p))
if not valid:
print("[错误] 未找到任何有效图像文件,请检查 file_name 与路径/扩展名/编码")
sys.exit(1)
# 切分
img_ids = [img_id for img_id, _ in valid]
random.shuffle(img_ids)
n = len(img_ids)
n_train = int(n * args.train_ratio)
n_val = int(n * args.val_ratio)
n_test = n - n_train - n_val
train_ids = set(img_ids[:n_train])
val_ids = set(img_ids[n_train:n_train + n_val])
test_ids = set(img_ids[n_train + n_val:])
print(f"总图像: {n} | train: {len(train_ids)} val: {len(val_ids)} test: {len(test_ids)} | 缺失: {miss_count}")
# 准备目录
ensure_dirs(out_dir)
place_mode = "copy" if args.copy else ("symlink" if args.symlink else "hardlink")
# 处理每张图片与标注
for img_id, img_path in valid:
img_info = images[img_id]
img_w = int(img_info.get("width", 0) or 0)
img_h = int(img_info.get("height", 0) or 0)
if img_w <= 0 or img_h <= 0:
# 若缺失尝试用PIL读取尺寸
try:
from PIL import Image
with Image.open(img_path) as im:
img_w, img_h = im.size
except Exception:
print(f"[警告] 无法获取尺寸,跳过: {img_path}")
continue
# 生成YOLO标签行
yolo_lines: List[str] = []
for ann in anns_by_img.get(img_id, []):
cat_id = ann.get("category_id")
if cat_id not in catid2cls:
continue
cls_id = catid2cls[cat_id]
bbox = ann.get("bbox", None)
if not bbox:
continue
x_c, y_c, w, h = bbox_coco_to_yolo(bbox, img_w, img_h)
# 过滤异常框
if w <= 0 or h <= 0 or x_c <= 0 or y_c <= 0 or x_c >= 1 or y_c >= 1:
continue
yolo_lines.append(f"{cls_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")
# 目标子目录
if img_id in train_ids:
split = "train"
elif img_id in val_ids:
split = "val"
else:
split = "test"
dst_img = out_dir / f"images/{split}/{img_path.name}"
dst_lbl = out_dir / f"labels/{split}/{img_path.with_suffix('.txt').name}"
# 放置图片
try:
place_image(img_path, dst_img, place_mode)
except Exception as e:
print(f"[警告] 放置图片失败 {img_path} -> {dst_img}: {e}")
continue
# 写入标签
write_label(dst_lbl, yolo_lines)
# 写 dataset.yaml
generate_yaml(out_dir, names)
print("[完成] 数据集已整理为YOLO格式")
if __name__ == "__main__":
main()