Files
FishServer/FishMeasure/detect_refbox/prepare_refbox_dataset.py

318 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
从递归目录中的 LabelMe 风格 JSON 构建 YOLO 检测数据集(仅参考物 ref)。
- 仅使用同时满足「存在对应图像」与「JSON 内至少有一个可转换的 ref 矩形框」的样本。
- 图像与 JSON 同目录,或通过 imagePath 解析;若仅有 imageData,则解码写出。
- 输出扁平唯一文件名(相对路径中的 / 转为 __),避免不同子目录同名帧冲突。
"""
from __future__ import annotations
import argparse
import base64
import io
import json
import random
import shutil
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
# Image file extensions (lower-case, with leading dot) accepted as sources.
IMG_EXTS = set(".jpg .jpeg .png .bmp .tif .tiff".split())
# Shape labels treated as the reference object; all are emitted as class 0.
REF_LABELS: Set[str] = set("ref reference refbox 参考 参考物".split())
def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options.

    Options:
        --source       root directory searched recursively for JSON files.
        --out          dataset output root; None when not given.
        --val-ratio    fraction of samples used for validation (default 0.2).
        --seed         random seed for the train/val split (default 42).
        --copy-images  flag: copy images instead of hard-linking them.
    """
    parser = argparse.ArgumentParser(description="准备 ref 框 YOLO 数据集")
    default_source = Path("/home/ubuntu/data/fish/2016-1-22-last_images")
    parser.add_argument(
        "--source",
        help="含递归子目录与 JSON 的根目录",
        default=default_source,
        type=Path,
    )
    parser.add_argument(
        "--out",
        help="输出数据集根目录(默认:本仓库 detect_refbox/dataset",
        default=None,
        type=Path,
    )
    parser.add_argument(
        "--val-ratio", help="验证集比例", default=0.2, type=float
    )
    parser.add_argument(
        "--seed", help="划分随机种子", default=42, type=int
    )
    parser.add_argument(
        "--copy-images",
        help="复制图像;默认硬链接(同盘失败时回退复制)",
        action="store_true",
    )
    namespace = parser.parse_args()
    return namespace
def repo_root() -> Path:
    """Return the directory two levels above this (resolved) script file."""
    this_file = Path(__file__).resolve()
    return this_file.parent.parent
def norm_bbox_yolo(
    x1: float, y1: float, x2: float, y2: float, w: int, h: int
) -> Tuple[float, float, float, float]:
    """Convert two opposite box corners in pixels to normalized YOLO xywh.

    The corners may be given in any order. The box is clipped to the image
    rectangle ``[0, w] x [0, h]`` first. Annotation tools let corners land
    slightly outside the frame, and YOLO loaders (e.g. Ultralytics) reject
    label files whose coordinates fall outside [0, 1]. A box lying entirely
    outside the image collapses to zero width and/or height.

    Returns:
        ``(x_center, y_center, width, height)`` divided by the image
        width/height respectively.

    Raises:
        ValueError: if ``w`` or ``h`` is not positive.
    """
    if w <= 0 or h <= 0:
        raise ValueError("invalid image size")

    def _clip(v: float, hi: float) -> float:
        return min(max(v, 0.0), hi)

    # Clipping is monotone, so clipping before ordering the corners is
    # equivalent to ordering first.
    x_min, x_max = sorted((_clip(x1, w), _clip(x2, w)))
    y_min, y_max = sorted((_clip(y1, h), _clip(y2, h)))
    bw = x_max - x_min
    bh = y_max - y_min
    cx = x_min + bw / 2.0
    cy = y_min + bh / 2.0
    return cx / w, cy / h, bw / w, bh / h
def load_json(path: Path) -> Optional[Dict[str, Any]]:
    """Load a JSON annotation file whose text encoding is not known up front.

    The file is read once as bytes. The bytes are decoded with each candidate
    encoding in turn: UTF-8 (with or without BOM), GBK, GB2312, then Latin-1.
    The first decoding that also parses as JSON wins.

    Returns:
        The parsed JSON object (a dict). Returns None when any of these hold:
        the file cannot be read, no encoding yields valid JSON, or the
        top-level value is not an object (e.g. a list). Callers rely on
        ``dict.get``.
    """
    try:
        raw = path.read_bytes()
    except OSError:
        return None
    # "utf-8-sig" also accepts plain UTF-8. Plain "utf-8" would keep a BOM as
    # U+FEFF, which json rejects, so BOM files would never be parsed.
    for enc in ("utf-8-sig", "gbk", "gb2312", "latin-1"):
        try:
            obj = json.loads(raw.decode(enc))
        except (UnicodeDecodeError, json.JSONDecodeError):
            continue
        return obj if isinstance(obj, dict) else None
    return None
def find_image_path(json_path: Path, data: Dict[str, Any]) -> Optional[Path]:
    """Locate the image file a LabelMe JSON annotation refers to.

    Lookup order:
    1. ``imagePath`` resolved relative to the JSON's directory. It must be an
       existing file with an extension in IMG_EXTS.
    2. A file in the JSON's directory whose stem equals the stem of
       ``imagePath``'s file name.
    3. A file in the JSON's directory with the same stem as the JSON file.

    For steps 2-3, extensions are probed in sorted order. Each is tried in
    lower case, then upper case, so the result is deterministic when several
    candidates exist.

    Returns:
        The image path, or None if nothing was found.
    """
    # Iterating the IMG_EXTS set directly would follow string-hash
    # randomization, making the pick between e.g. a.jpg and a.png vary per run.
    suffixes = [s for ext in sorted(IMG_EXTS) for s in (ext, ext.upper())]

    def _probe(directory: Path, stem: str) -> Optional[Path]:
        # First existing file named <stem><suffix>, in suffix order.
        for suffix in suffixes:
            cand = directory / f"{stem}{suffix}"
            if cand.is_file():
                return cand
        return None

    ip = str(data.get("imagePath") or "").strip()
    if ip:
        # LabelMe on Windows stores imagePath with backslashes. POSIX paths
        # would treat those as part of a single file name.
        ip = ip.replace("\\", "/")
        cand = (json_path.parent / ip).resolve()
        if cand.is_file() and cand.suffix.lower() in IMG_EXTS:
            return cand
        found = _probe(json_path.parent, Path(ip).stem)
        if found:
            return found
    return _probe(json_path.parent, json_path.stem)
def image_from_image_data(data: Dict[str, Any]) -> Optional[bytes]:
    """Decode the base64 ``imageData`` field of a LabelMe annotation.

    Returns:
        The decoded bytes. Returns None when the field is missing, empty, not
        a string, or fails to decode.
    """
    encoded = data.get("imageData")
    if isinstance(encoded, str) and encoded:
        try:
            return base64.b64decode(encoded)
        except Exception:
            return None
    return None
def shapes_to_yolo_lines(
    data: Dict[str, Any], img_w: int, img_h: int
) -> List[str]:
    """Convert the ref rectangles of a LabelMe-style dict to YOLO label lines.

    Shapes are read from ``data["shapes"]``, or else ``data["annotations"]``.
    A shape is kept when all of these hold:
    - it is a dict;
    - its ``shape_type`` is missing or one of rectangle/bbox/box;
    - its label is in REF_LABELS (case-insensitive);
    - it has at least two points.
    The box is the axis-aligned bounding box of all its points. This equals
    the two corners of a LabelMe rectangle and also handles 4-corner boxes.
    Malformed shapes and zero-area boxes are skipped.

    Args:
        data: parsed annotation dict.
        img_w: image width in pixels.
        img_h: image height in pixels.

    Returns:
        Lines ``"0 xc yc w h"`` with normalized coordinates. The class id is
        always 0.
    """
    shapes = data.get("shapes") or data.get("annotations") or []
    if not isinstance(shapes, list):
        return []
    # Case-insensitive match, so "Ref" / "REF" are accepted too.
    wanted = {name.casefold() for name in REF_LABELS}
    lines: List[str] = []
    for shp in shapes:
        if not isinstance(shp, dict):
            continue
        if shp.get("shape_type") not in (None, "rectangle", "bbox", "box"):
            continue
        if str(shp.get("label", "")).strip().casefold() not in wanted:
            continue
        pts = shp.get("points")
        if not isinstance(pts, (list, tuple)) or len(pts) < 2:
            continue
        try:
            xs = [float(p[0]) for p in pts]
            ys = [float(p[1]) for p in pts]
            xc, yc, bw, bh = norm_bbox_yolo(
                min(xs), min(ys), max(xs), max(ys), img_w, img_h
            )
        except (LookupError, TypeError, ValueError):
            # Malformed point data or an invalid image size: skip this shape.
            continue
        if bw <= 0 or bh <= 0:
            continue
        lines.append(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}")
    return lines
def ensure_image_size(
    data: Dict[str, Any], img_path: Optional[Path]
) -> Tuple[int, int]:
    """Return the (width, height) of the annotated image.

    Prefers ``imageWidth``/``imageHeight`` from the JSON. A value that is
    missing, non-positive or not convertible to int counts as unknown. When
    either is unknown, reads the size with Pillow: first from *img_path*
    (if it exists), then from the embedded ``imageData``.

    Returns:
        ``(width, height)``. Returns ``(0, 0)`` when the size cannot be
        determined, including when Pillow is not installed.
    """

    def _dim(key: str) -> int:
        # Tolerate malformed values (e.g. "abc") instead of letting int()
        # abort the whole dataset build.
        try:
            return int(data.get(key, 0) or 0)
        except (TypeError, ValueError):
            return 0

    w, h = _dim("imageWidth"), _dim("imageHeight")
    if w > 0 and h > 0:
        return w, h

    try:
        from PIL import Image
    except ImportError:
        # Pillow is optional; without it the size is unknown.
        return 0, 0

    def _read_size(src: Any) -> Optional[Tuple[int, int]]:
        # Best effort: an unreadable/corrupt image just yields None.
        try:
            with Image.open(src) as im:
                return im.size
        except Exception:
            return None

    if img_path and img_path.exists():
        size = _read_size(img_path)
        if size is not None:
            return size
    blob = image_from_image_data(data)
    if blob:
        size = _read_size(io.BytesIO(blob))
        if size is not None:
            return size
    return 0, 0
def unique_stem(json_path: Path, source_root: Path) -> str:
    """Build a flat output stem that is unique across sub-directories.

    The JSON's directory, relative to *source_root*, is joined with the JSON
    stem using ``__``. For example, ``a/b/f1.json`` becomes ``a__b__f1``. A
    JSON directly in *source_root* keeps its bare stem. The old form
    ``.__f1`` was a hidden file that glob-based loaders such as Ultralytics
    never pick up.

    Raises:
        ValueError: if *json_path* is not under *source_root*.
    """
    parts = json_path.parent.relative_to(source_root).parts
    return "__".join((*parts, json_path.stem))
def write_dataset_yaml(out_root: Path) -> Path:
    """Write the Ultralytics dataset config ``refbox.yaml`` into *out_root*.

    ``path`` is the absolute dataset root. ``train``/``val`` are image
    directories relative to it. There is a single class, 0, named ``ref``.

    Returns:
        Path of the written YAML file.
    """
    yaml_path = out_root / "refbox.yaml"
    # Emit the root as a double-quoted scalar (a JSON string is valid YAML).
    # Quoting stops paths containing ": ", " #", leading special characters
    # or backslashes (Windows) from being misparsed.
    root = json.dumps(str(out_root.resolve()), ensure_ascii=False)
    text = (
        f"path: {root}\n"
        "train: images/train\n"
        "val: images/val\n"
        "nc: 1\n"
        "names:\n"
        " 0: ref\n"
    )
    yaml_path.write_text(text, encoding="utf-8")
    return yaml_path
def _collect_records(
    source: Path, tmp_dir: Path
) -> Tuple[List[Tuple[Path, Path, List[str], str]], int]:
    """Scan *source* recursively and build one record per usable JSON file.

    A record is ``(json_path, src_image_path, yolo_lines, output_file_name)``.
    A JSON file counts as skipped when any of these hold:
    - it does not parse to a non-empty object;
    - it has no image (neither a file on disk nor embedded ``imageData``);
    - its image size is unknown;
    - it contains no convertible ref rectangle.
    Images that exist only as ``imageData`` are decoded into *tmp_dir*.

    Returns:
        ``(records, skipped_count)``, with records in sorted JSON-path order.
    """
    records: List[Tuple[Path, Path, List[str], str]] = []
    skipped = 0
    for jp in sorted(source.rglob("*.json")):
        data = load_json(jp)
        if not isinstance(data, dict) or not data:
            skipped += 1
            continue
        stem = unique_stem(jp, source)
        img_path = find_image_path(jp, data)
        if not img_path:
            blob = image_from_image_data(data)
            if not blob:
                skipped += 1
                continue
            ext = Path(str(data.get("imagePath") or "img.png")).suffix.lower()
            if ext not in IMG_EXTS:
                ext = ".png"
            img_path = tmp_dir / f"{stem}{ext}"
            img_path.parent.mkdir(parents=True, exist_ok=True)
            img_path.write_bytes(blob)
        iw, ih = ensure_image_size(data, img_path)
        if iw <= 0 or ih <= 0:
            skipped += 1
            continue
        lines = shapes_to_yolo_lines(data, iw, ih)
        if not lines:
            skipped += 1
            continue
        ext = img_path.suffix.lower()
        if ext not in IMG_EXTS:
            ext = ".png"
        records.append((jp, img_path, lines, stem + ext))
    return records, skipped


def _val_count(n: int, ratio: float) -> int:
    """Return the number of validation samples out of *n*.

    This is ``round(n * ratio)`` clamped to ``[1, n - 1]``, so both splits
    are non-empty. Returns 0 when ``n < 2``.
    """
    if n < 2:
        return 0
    return min(max(1, int(round(n * ratio))), n - 1)


def _place_image(src: Path, dst: Path, copy: bool) -> None:
    """Make *dst* a copy of *src*, or a hard link to it.

    Copies when *copy* is true or when hard-linking fails (e.g. across
    file systems); otherwise hard-links.

    Raises:
        OSError: if the old *dst* cannot be removed or the copy fails.
    """
    # Remove leftovers from earlier runs first. hardlink_to refuses to
    # overwrite, and copy2 onto an existing hard link of src raises
    # shutil.SameFileError.
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    if not copy:
        try:
            dst.hardlink_to(src)
            return
        except OSError:
            pass
    shutil.copy2(src, dst)


def main() -> int:
    """Build the YOLO ref-box dataset and return a process exit code.

    Steps:
    1. Collect usable samples from ``--source``.
    2. Shuffle them with ``--seed`` and split off ``--val-ratio`` of them as
       validation: at least 1 when there are >= 2 samples, never all of them.
       A single sample goes to both splits.
    3. Hard-link or copy the images into ``images/{train,val}`` and write
       YOLO labels into ``labels/{train,val}``.
    4. Remove copies of the same samples left in the opposite split by
       earlier runs.
    5. Write ``refbox.yaml``.

    Returns:
        1 when the source directory is missing or no usable sample is found,
        otherwise 0.
    """
    args = parse_args()
    source = args.source.expanduser().resolve()
    if not source.is_dir():
        print(f"[错误] 数据目录不存在: {source}", file=sys.stderr)
        return 1
    out_root = (
        args.out.expanduser().resolve()
        if args.out
        else repo_root() / "detect_refbox" / "dataset"
    )
    # split_dirs[is_val] -> (image dir, label dir)
    split_dirs: Dict[bool, Tuple[Path, Path]] = {
        False: (out_root / "images" / "train", out_root / "labels" / "train"),
        True: (out_root / "images" / "val", out_root / "labels" / "val"),
    }
    for img_dir, lbl_dir in split_dirs.values():
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
    tmp_dir = out_root / "_tmp_decode"
    try:
        records, skipped = _collect_records(source, tmp_dir)
        if not records:
            print("[错误] 没有可用样本(需 JSON + 图像 + ref 矩形)", file=sys.stderr)
            return 1
        rng = random.Random(args.seed)
        rng.shuffle(records)
        n = len(records)
        if n == 1:
            # Only one sample: write it to both splits so YOLO still has a
            # validation set.
            plan = [(records[0], False), (records[0], True)]
        else:
            n_val = _val_count(n, args.val_ratio)
            # The last n_val shuffled records form the validation split.
            plan = [(rec, i >= n - n_val) for i, rec in enumerate(records)]

        counts = {False: 0, True: 0}
        written: Dict[bool, Set[Tuple[str, str]]] = {False: set(), True: set()}
        for (_, src_img, lines, fname), is_val in plan:
            img_dir, lbl_dir = split_dirs[is_val]
            try:
                _place_image(src_img, img_dir / fname, args.copy_images)
            except OSError as e:
                print(f"[跳过] 复制图像失败 {src_img}: {e}", file=sys.stderr)
                skipped += 1
                continue
            lbl_name = f"{Path(fname).stem}.txt"
            (lbl_dir / lbl_name).write_text(
                "\n".join(lines) + "\n", encoding="utf-8"
            )
            written[is_val].add((fname, lbl_name))
            counts[is_val] += 1

        # A sample that changed split since an earlier run (other seed,
        # ratio or source) still has its old copy in the opposite split.
        # Remove it so train and val stay disjoint.
        removed = 0
        for is_val, (img_dir, lbl_dir) in split_dirs.items():
            for img_name, lbl_name in written[not is_val] - written[is_val]:
                for stale in (img_dir / img_name, lbl_dir / lbl_name):
                    if stale.is_file():
                        stale.unlink()
                        removed += 1
        if removed:
            print(f"[清理] 删除另一划分中的旧文件 {removed} 个")

        yaml_path = write_dataset_yaml(out_root)
    finally:
        # Decoded imageData files are only needed until they are linked or
        # copied. Clean up on every exit path, including the early return.
        shutil.rmtree(tmp_dir, ignore_errors=True)
    print(
        f"完成: 训练 {counts[False]} / 验证 {counts[True]}(跳过 {skipped} 个 JSON\n"
        f"数据集: {out_root}\n"
        f"YAML: {yaml_path}"
    )
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())