318 lines
9.7 KiB
Python
318 lines
9.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
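Build a YOLO detection dataset (reference object "ref" only) from LabelMe-style
JSON files found recursively under a directory.

- Only samples that have both a corresponding image and at least one
  convertible ref rectangle inside the JSON are used.
- The image lives next to the JSON or is resolved through imagePath; if only
  imageData is present, it is decoded and written out.
- Output file names are flat and unique (relative path with separators turned
  into "__"), so same-named frames in different subdirectories do not collide.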
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import base64
|
||
import io
|
||
import json
|
||
import random
|
||
import shutil
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||
|
||
# Lower-case file suffixes (leading dot included) that are treated as images.
IMG_EXTS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".tif",
    ".tiff",
}

# Label spellings that are all mapped onto the single class "ref" (class 0).
REF_LABELS: Set[str] = {
    "ref",
    "reference",
    "refbox",
    "参考",
    "参考物",
}
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the dataset preparation script.

    Returns:
        Namespace with ``source`` (Path), ``out`` (Path or None),
        ``val_ratio`` (float), ``seed`` (int) and ``copy_images`` (bool).
    """
    parser = argparse.ArgumentParser(description="准备 ref 框 YOLO 数据集")

    # Input / output locations.
    parser.add_argument(
        "--source",
        default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"),
        type=Path,
        help="含递归子目录与 JSON 的根目录",
    )
    parser.add_argument(
        "--out",
        default=None,
        type=Path,
        help="输出数据集根目录(默认:本仓库 detect_refbox/dataset)",
    )

    # Train / validation split controls.
    parser.add_argument("--val-ratio", default=0.2, type=float, help="验证集比例")
    parser.add_argument("--seed", default=42, type=int, help="划分随机种子")

    # How images are placed into the dataset.
    parser.add_argument(
        "--copy-images",
        action="store_true",
        help="复制图像;默认硬链接(同盘失败时回退复制)",
    )

    parsed = parser.parse_args()
    return parsed
|
||
|
||
|
||
def repo_root() -> Path:
    """Return the directory two levels above this file (parent of its folder)."""
    this_file = Path(__file__).resolve()
    # parents[0] is the folder holding this file; parents[1] is one level up.
    return this_file.parents[1]
|
||
|
||
|
||
def norm_bbox_yolo(
    x1: float, y1: float, x2: float, y2: float, w: int, h: int
) -> Tuple[float, float, float, float]:
    """Convert two opposite corners in pixels to a normalized YOLO box.

    The corners may be given in any order. The box is clipped to the image
    rectangle ``[0, w] x [0, h]`` before normalization, so every returned
    value lies in ``[0, 1]`` even when the annotation was dragged slightly
    past the image border.

    Args:
        x1, y1: First corner, in pixels.
        x2, y2: Opposite corner, in pixels.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        ``(x_center, y_center, width, height)``, each divided by the image
        size. A box lying entirely outside the image collapses to zero
        width and/or height.

    Raises:
        ValueError: If ``w`` or ``h`` is not positive.
    """
    # Validate before doing any arithmetic that divides by the image size.
    if w <= 0 or h <= 0:
        raise ValueError("invalid image size")

    def _clip(value: float, upper: float) -> float:
        return min(max(value, 0.0), upper)

    x_min = _clip(min(x1, x2), float(w))
    x_max = _clip(max(x1, x2), float(w))
    y_min = _clip(min(y1, y2), float(h))
    y_max = _clip(max(y1, y2), float(h))

    bw = x_max - x_min
    bh = y_max - y_min
    cx = x_min + bw / 2.0
    cy = y_min + bh / 2.0
    return cx / w, cy / h, bw / w, bh / h
|
||
|
||
|
||
def load_json(path: Path) -> Optional[Dict[str, Any]]:
    """Read a JSON annotation file, trying several text encodings.

    The file is read once as bytes and decoded with each candidate encoding
    in turn; the first one that both decodes and parses wins.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed top-level JSON object, or ``None`` when no encoding yields
        valid JSON or the top-level value is not an object (dict).

    Raises:
        OSError: If the file cannot be read.
    """
    raw = path.read_bytes()
    # "utf-8-sig" accepts plain UTF-8 and also strips a leading BOM, which
    # json would otherwise reject. latin-1 never fails to decode, so it is
    # the last resort.
    for enc in ("utf-8-sig", "gbk", "gb2312", "latin-1"):
        try:
            parsed = json.loads(raw.decode(enc))
        except (UnicodeDecodeError, json.JSONDecodeError):
            continue
        # Callers use dict methods on the result; other JSON values (list,
        # number, string, ...) are not usable annotations.
        return parsed if isinstance(parsed, dict) else None
    return None
|
||
|
||
|
||
def find_image_path(json_path: Path, data: Dict[str, Any]) -> Optional[Path]:
    """Locate the image file that belongs to a LabelMe-style JSON file.

    Lookup order:
      1. ``imagePath`` resolved relative to the JSON's directory (must be an
         existing file with a suffix in ``IMG_EXTS``).
      2. A file next to the JSON named after the stem of ``imagePath`` with
         any suffix in ``IMG_EXTS``.
      3. A file next to the JSON named after the JSON's own stem with any
         suffix in ``IMG_EXTS``.

    Args:
        json_path: Path of the annotation file.
        data: Parsed annotation content.

    Returns:
        Path of the first match, or ``None`` if nothing was found.
    """
    folder = json_path.parent
    # Sets have no stable iteration order across runs; sort so that the same
    # file is picked every time when several suffixes exist for one stem.
    ordered_exts = sorted(IMG_EXTS)

    def _sibling_for(stem: str) -> Optional[Path]:
        if not stem:
            return None
        for ext in ordered_exts:
            candidate = folder / f"{stem}{ext}"
            if candidate.is_file():
                return candidate
        return None

    raw_ip = data.get("imagePath")
    ip = raw_ip.strip() if isinstance(raw_ip, str) else ""
    if ip:
        # Try the value as written first, then with backslashes turned into
        # forward slashes so Windows-style paths also resolve on POSIX.
        variants = [ip]
        normalized = ip.replace("\\", "/")
        if normalized != ip:
            variants.append(normalized)

        for variant in variants:
            cand = (folder / variant).resolve()
            if cand.is_file() and cand.suffix.lower() in IMG_EXTS:
                return cand

        for variant in variants:
            found = _sibling_for(Path(variant).stem)
            if found is not None:
                return found

    return _sibling_for(json_path.stem)
|
||
|
||
|
||
def image_from_image_data(data: Dict[str, Any]) -> Optional[bytes]:
    """Decode the base64 ``imageData`` field of an annotation, if present.

    Args:
        data: Parsed annotation content.

    Returns:
        The decoded bytes, or ``None`` when the field is missing, empty, not
        a string, or not valid base64.
    """
    raw = data.get("imageData")
    if not raw or not isinstance(raw, str):
        return None
    try:
        return base64.b64decode(raw)
    except ValueError:
        # binascii.Error (bad padding / length) is a ValueError subclass, and
        # non-ASCII input raises ValueError directly.
        return None
|
||
|
||
|
||
def shapes_to_yolo_lines(
    data: Dict[str, Any], img_w: int, img_h: int
) -> List[str]:
    """Turn the ref rectangles of one annotation into YOLO label lines.

    Shapes are read from ``data["shapes"]`` (or ``data["annotations"]``).
    A shape is kept when its ``shape_type`` is missing or one of
    ``rectangle`` / ``bbox`` / ``box``, its stripped label is in
    ``REF_LABELS`` and it has at least two points. The box is the axis-aligned
    extent of all its points, so both two-corner and four-corner rectangles
    are handled. Malformed entries are skipped instead of aborting.

    Args:
        data: Parsed annotation content.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        One ``"0 xc yc w h"`` line (six decimals) per usable box; empty list
        if there is none.
    """
    shapes = data.get("shapes") or data.get("annotations") or []
    if not isinstance(shapes, list):
        return []

    lines: List[str] = []
    for shp in shapes:
        # Entries that are not JSON objects cannot carry a label or points.
        if not isinstance(shp, dict):
            continue
        if shp.get("shape_type") not in (None, "rectangle", "bbox", "box"):
            continue
        if str(shp.get("label", "")).strip() not in REF_LABELS:
            continue
        pts = shp.get("points")
        if not isinstance(pts, (list, tuple)) or len(pts) < 2:
            continue
        try:
            xs = [float(pt[0]) for pt in pts]
            ys = [float(pt[1]) for pt in pts]
            xc, yc, bw, bh = norm_bbox_yolo(
                min(xs), min(ys), max(xs), max(ys), img_w, img_h
            )
        except (TypeError, ValueError, IndexError, KeyError):
            # Non-numeric / wrongly shaped points, or an invalid image size.
            continue
        # Degenerate boxes are useless as detection targets.
        if bw <= 0 or bh <= 0:
            continue
        lines.append(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}")
    return lines
|
||
|
||
|
||
def ensure_image_size(
    data: Dict[str, Any], img_path: Optional[Path]
) -> Tuple[int, int]:
    """Determine the image size for an annotation.

    Sources are tried in order: the ``imageWidth`` / ``imageHeight`` fields,
    the image file at ``img_path``, then the embedded ``imageData``.

    Args:
        data: Parsed annotation content.
        img_path: Image file on disk, or ``None``.

    Returns:
        ``(width, height)``; ``(0, 0)`` when no source yields a size.
    """
    width = int(data.get("imageWidth", 0) or 0)
    height = int(data.get("imageHeight", 0) or 0)
    if width > 0 and height > 0:
        return width, height

    def _probe(target):
        # Pillow is imported lazily; any failure (import included) counts as
        # "size unknown" so the next source can be tried.
        try:
            from PIL import Image

            with Image.open(target) as im:
                return im.size
        except Exception:
            return None

    if img_path and img_path.exists():
        size = _probe(img_path)
        if size is not None:
            return size

    blob = image_from_image_data(data)
    if blob:
        size = _probe(io.BytesIO(blob))
        if size is not None:
            return size

    return 0, 0
|
||
|
||
|
||
def unique_stem(json_path: Path, source_root: Path) -> str:
    """Build a flat file stem from a JSON's location under the source root.

    The directory of ``json_path`` relative to ``source_root`` is joined with
    ``"__"`` and prefixed to the JSON stem, e.g. ``a/b/x.json`` -> ``a__b__x``.
    A JSON located directly in ``source_root`` yields just its own stem; the
    relative directory is ``"."`` there, and prefixing it would produce a
    hidden file name starting with a dot.

    Args:
        json_path: Path of the annotation file (must be under ``source_root``).
        source_root: Root directory of the scan.

    Returns:
        The flattened stem.

    Raises:
        ValueError: If ``json_path`` is not located under ``source_root``.
    """
    rel = json_path.parent.relative_to(source_root)
    if not rel.parts:
        return json_path.stem
    prefix = "__".join(rel.parts)
    return f"{prefix}__{json_path.stem}"
|
||
|
||
|
||
def write_dataset_yaml(out_root: Path) -> Path:
    """Write ``refbox.yaml`` describing the dataset into ``out_root``.

    Args:
        out_root: Dataset root directory (must already exist).

    Returns:
        Path of the written YAML file.
    """
    # Ultralytics: "path" is the dataset root; train/val are image
    # directories relative to that path.
    entries = [
        f"path: {out_root.resolve()}",
        "train: images/train",
        "val: images/val",
        "nc: 1",
        "names:",
        " 0: ref",
    ]
    content = "\n".join(entries) + "\n"

    target = out_root / "refbox.yaml"
    target.write_text(content, encoding="utf-8")
    return target
|
||
|
||
|
||
def main() -> int:
    """Scan the source tree, build the YOLO dataset and write its YAML.

    Steps:
      1. Collect every ``*.json`` under ``--source`` that has an image
         (on disk or embedded) and at least one convertible ref box.
      2. Shuffle with ``--seed`` and split off the validation part
         (at least one sample in each split when two or more are available;
         a single sample is written to both splits).
      3. Hard-link or copy the images, write the label files and the YAML.

    Returns:
        Process exit code: ``0`` on success, ``1`` when the source directory
        is missing or no usable sample was found.
    """
    args = parse_args()
    source = args.source.expanduser().resolve()
    if not source.is_dir():
        print(f"[错误] 数据目录不存在: {source}", file=sys.stderr)
        return 1

    out_root = (
        args.out.expanduser().resolve()
        if args.out
        else repo_root() / "detect_refbox" / "dataset"
    )
    img_train = out_root / "images" / "train"
    img_val = out_root / "images" / "val"
    lbl_train = out_root / "labels" / "train"
    lbl_val = out_root / "labels" / "val"
    for d in (img_train, img_val, lbl_train, lbl_val):
        d.mkdir(parents=True, exist_ok=True)

    # Scratch directory for images decoded from embedded imageData.
    tmp_root = out_root / "_tmp_decode"

    # Each record: (json_path, src_image_path, yolo_lines, output_file_name)
    records: List[Tuple[Path, Path, List[str], str]] = []

    json_files = sorted(source.rglob("*.json"))
    skipped = 0
    for jp in json_files:
        data = load_json(jp)
        # Anything but a non-empty JSON object cannot be an annotation.
        if not data or not isinstance(data, dict):
            skipped += 1
            continue

        img_path = find_image_path(jp, data)
        if not img_path:
            # No image on disk: fall back to the embedded base64 payload.
            blob = image_from_image_data(data)
            if not blob:
                skipped += 1
                continue
            ext = Path((data.get("imagePath") or "img.png")).suffix.lower()
            if ext not in IMG_EXTS:
                ext = ".png"
            stem = unique_stem(jp, source)
            tmp_img = tmp_root / f"{stem}{ext}"
            tmp_img.parent.mkdir(parents=True, exist_ok=True)
            tmp_img.write_bytes(blob)
            img_path = tmp_img

        iw, ih = ensure_image_size(data, img_path)
        if iw <= 0 or ih <= 0:
            skipped += 1
            continue

        lines = shapes_to_yolo_lines(data, iw, ih)
        if not lines:
            skipped += 1
            continue

        stem = unique_stem(jp, source)
        ext = img_path.suffix.lower()
        if ext not in IMG_EXTS:
            ext = ".png"
        records.append((jp, img_path, lines, stem + ext))

    if not records:
        # Do not leave decoded scratch images behind on the failure path.
        shutil.rmtree(tmp_root, ignore_errors=True)
        print("[错误] 没有可用样本(需 JSON + 图像 + ref 矩形)", file=sys.stderr)
        return 1

    rng = random.Random(args.seed)
    rng.shuffle(records)
    n_val = int(round(len(records) * args.val_ratio))
    n_val = max(1, n_val) if len(records) >= 2 else 0
    n_val = min(n_val, len(records) - 1) if len(records) >= 2 else 0
    # Only one sample: train and val share it (written to both directories)
    # so that YOLO does not end up without a validation set.
    single_dup = len(records) == 1
    val_set = set(range(len(records) - n_val, len(records))) if not single_dup else set()

    n_tr = 0
    n_va = 0

    def materialize(
        src_img: Path, lines: List[str], fname: str, is_val: bool
    ) -> bool:
        """Place one image and its label file into the chosen split."""
        nonlocal n_tr, n_va, skipped
        idir = img_val if is_val else img_train
        ldir = lbl_val if is_val else lbl_train
        dst_img = idir / fname
        stem = Path(fname).stem
        dst_lbl = ldir / f"{stem}.txt"
        dst_img.parent.mkdir(parents=True, exist_ok=True)
        try:
            # Always drop a stale destination first. If an earlier run
            # hard-linked dst to src, shutil.copy2 would otherwise raise
            # SameFileError and the sample would be skipped.
            dst_img.unlink(missing_ok=True)
            if args.copy_images:
                shutil.copy2(src_img, dst_img)
            else:
                try:
                    dst_img.hardlink_to(src_img)
                except OSError:
                    # E.g. source and destination on different file systems.
                    shutil.copy2(src_img, dst_img)
        except Exception as e:
            print(f"[跳过] 复制图像失败 {src_img}: {e}", file=sys.stderr)
            skipped += 1
            return False
        dst_lbl.write_text("\n".join(lines) + "\n", encoding="utf-8")
        if is_val:
            n_va += 1
        else:
            n_tr += 1
        return True

    for i, (_, src_img, lines, fname) in enumerate(records):
        if single_dup:
            materialize(src_img, lines, fname, is_val=False)
            materialize(src_img, lines, fname, is_val=True)
            break
        is_val = i in val_set
        materialize(src_img, lines, fname, is_val=is_val)

    yaml_path = write_dataset_yaml(out_root)
    # Hard links created from the scratch images keep their data alive, so
    # the scratch directory can be removed now.
    if tmp_root.exists():
        shutil.rmtree(tmp_root, ignore_errors=True)

    print(
        f"完成: 训练 {n_tr} / 验证 {n_va}(跳过 {skipped} 个 JSON)\n"
        f"数据集: {out_root}\n"
        f"YAML: {yaml_path}"
    )
    return 0
|
||
|
||
|
||
# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())
|