Files

348 lines
10 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""
Build a dataset index mapping point clouds (.ply) to weight labels.

Data:
- Point clouds live under:
    /home/ubuntu/data/fish/2025-11-19-output/{sample_id}/cloud/*.ply
- Labels live at:
    /home/ubuntu/projects/FishMeasure/measure/data/label.csv

CSV format:
- Column B (index 1): sample_id (svo2 name / folder name)
- Column F (index 5): weight in grams (float)

Output JSON:
{
  "meta": {...},
  "items": [
    {"ply": "<abs_or_rel_path>", "sample_id": "...", "weight_g": 123.45}
  ],
  "mapping": {
    "<abs_or_rel_path>": 123.45
  }
}
"""
from __future__ import annotations

import argparse
import csv
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Default dataset root; point clouds are looked up at {sample_id}/cloud/*.ply below it.
DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output"
# Default label CSV (column B holds the sample_id, column F the weight in grams).
DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"
@dataclass(frozen=True)
class LabelRow:
    """One usable row of the label CSV (immutable)."""

    # Sample identifier taken from CSV column B; doubles as the data folder name.
    sample_id: str
    # Weight in grams from CSV column F, or None when the cell could not be parsed.
    weight_g: Optional[float]
def _parse_weight_g(cell: str) -> Optional[float]:
    """Parse a CSV cell into a weight in grams.

    Args:
        cell: Raw cell text. ``None`` is tolerated and treated as empty.

    Returns:
        The parsed finite float, or ``None`` when the cell is empty, is not
        a number, or parses to NaN / infinity.
    """
    text = (cell or "").strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        return None
    # float() happily accepts "nan" and "inf"; such values would poison the
    # mean of duplicate rows and serialise as invalid JSON, so treat them as
    # missing labels instead.
    return value if math.isfinite(value) else None
def load_labels(label_csv: Path) -> List[LabelRow]:
    """Load label rows from the label CSV.

    Column B (index 1) is read as the sample_id and column F (index 5) as
    the weight in grams. Rows with fewer than two columns, an empty
    sample_id, or the placeholder sample_id ``xxx`` (case-insensitive) are
    skipped. Rows are not de-duplicated.

    Args:
        label_csv: Path to the CSV file.

    Returns:
        One ``LabelRow`` per accepted CSV row, in file order. ``weight_g`` is
        ``None`` when column F is absent or cannot be parsed.
    """
    rows: List[LabelRow] = []
    # newline="" is what the csv module expects so it can handle line endings
    # (including newlines embedded in quoted fields) itself. "utf-8-sig"
    # reads plain UTF-8 unchanged but drops a leading BOM, which would
    # otherwise break quoting of the first cell and shift the columns.
    with label_csv.open("r", encoding="utf-8-sig", newline="") as f:
        for record in csv.reader(f):
            if len(record) < 2:
                continue
            sample_id = (record[1] or "").strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            weight_cell = record[5] if len(record) > 5 else ""
            rows.append(
                LabelRow(sample_id=sample_id, weight_g=_parse_weight_g(weight_cell))
            )
    return rows
def resolve_sample_weights(
    label_rows: List[LabelRow],
    duplicate_policy: str = "mean",
) -> Tuple[Dict[str, Optional[float]], Dict[str, List[Optional[float]]]]:
    """Collapse label rows into a single weight per sample_id.

    Args:
        label_rows: Rows exposing ``sample_id`` and ``weight_g`` attributes.
        duplicate_policy: How to combine several rows of one sample_id:
            - ``"mean"``: average of the non-null weights
            - ``"first"``: first non-null weight
            - ``"error"``: raise if there are multiple distinct non-null weights

    Returns:
        A ``(resolved, raw)`` tuple:
        - resolved: ``{sample_id: weight_g or None}`` (``None`` when every
          row of that sample has a null weight)
        - raw: ``{sample_id: [weight_g_or_None, ...]}`` in row order

    Raises:
        ValueError: If ``duplicate_policy`` is unknown, or if the policy is
            ``"error"`` and a sample_id carries conflicting weights.
    """
    # Validate up front: otherwise a bad policy goes unnoticed whenever no
    # sample has a non-null weight (including an empty label list).
    if duplicate_policy not in ("mean", "first", "error"):
        raise ValueError(f"Unknown duplicate_policy: {duplicate_policy}")

    raw: Dict[str, List[Optional[float]]] = {}
    for row in label_rows:
        raw.setdefault(row.sample_id, []).append(row.weight_g)

    resolved: Dict[str, Optional[float]] = {}
    for sample_id, weights in raw.items():
        non_null = [w for w in weights if w is not None]
        if not non_null:
            resolved[sample_id] = None
        elif duplicate_policy == "first":
            resolved[sample_id] = non_null[0]
        elif duplicate_policy == "mean":
            resolved[sample_id] = sum(non_null) / len(non_null)
        else:  # "error"
            distinct = sorted(set(non_null))
            if len(distinct) > 1:
                raise ValueError(
                    f"Duplicate sample_id with multiple weights: {sample_id}: {distinct}"
                )
            resolved[sample_id] = distinct[0]
    return resolved, raw
def collect_ply_items(
    data_root: Path,
    sample_weights: Dict[str, Optional[float]],
    relative_paths: bool,
) -> Tuple[List[dict], Dict[str, float], Dict[str, int]]:
    """Pair every ``*.ply`` under ``{data_root}/{sample_id}/cloud`` with its weight.

    Samples are visited in sorted sample_id order and files in sorted path
    order, so the output is deterministic.

    Args:
        data_root: Directory containing one folder per sample_id.
        sample_weights: ``{sample_id: weight_g or None}``; samples whose
            weight is ``None`` are skipped and counted as missing weights.
        relative_paths: When true, store paths relative to ``data_root``
            instead of the globbed path.

    Returns:
        ``(items, mapping, stats)`` where
        - items: list of ``{"ply", "sample_id", "weight_g"}`` dicts
        - mapping: ``{ply_path_str: weight_g}``
        - stats: counts for ``missing_folders`` (no cloud directory, or a
          cloud directory without any .ply file), ``missing_weights`` and
          ``total_plys``
    """
    items: List[dict] = []
    mapping: Dict[str, float] = {}
    missing_folders = 0
    missing_weights = 0
    total_plys = 0

    for sample_id in sorted(sample_weights):
        weight_g = sample_weights[sample_id]
        if weight_g is None:
            missing_weights += 1
            continue

        cloud_dir = data_root / sample_id / "cloud"
        ply_files = sorted(cloud_dir.glob("*.ply")) if cloud_dir.is_dir() else []
        if not ply_files:
            # A missing folder and an empty folder are both "no data" cases.
            missing_folders += 1
            continue

        weight = float(weight_g)  # convert once per sample, not once per file
        for ply in ply_files:
            total_plys += 1
            ply_path = ply
            if relative_paths:
                try:
                    ply_path = ply.relative_to(data_root)
                except ValueError:
                    # Not located under data_root: keep the path as globbed.
                    ply_path = ply
            ply_key = str(ply_path)
            items.append(
                {
                    "ply": ply_key,
                    "sample_id": sample_id,
                    "weight_g": weight,
                }
            )
            mapping[ply_key] = weight

    stats = {
        "missing_folders": missing_folders,
        "missing_weights": missing_weights,
        "total_plys": total_plys,
    }
    return items, mapping, stats
def build_index(
    data_root: Path,
    label_csv: Path,
    output_json: Path,
    duplicate_policy: str = "mean",
    relative_paths: bool = False,
) -> dict:
    """Build the dataset index, write it to ``output_json`` and return it.

    Labels are loaded from ``label_csv``, duplicate rows are resolved with
    ``duplicate_policy``, and every matching PLY file under ``data_root`` is
    paired with its weight. Parent directories of ``output_json`` are
    created as needed.

    Returns:
        The index dict with the keys ``"meta"``, ``"items"`` and ``"mapping"``.
    """
    rows = load_labels(label_csv)
    weight_by_sample, weights_per_sample = resolve_sample_weights(
        rows, duplicate_policy=duplicate_policy
    )
    items, mapping, stats = collect_ply_items(
        data_root=data_root,
        sample_weights=weight_by_sample,
        relative_paths=relative_paths,
    )

    # Sample ids that appear on more than one label row.
    repeated_ids = [sid for sid, ws in weights_per_sample.items() if len(ws) > 1]

    meta = {
        "data_root": str(data_root),
        "label_csv": str(label_csv),
        "weight_column": "F",
        "duplicate_policy": duplicate_policy,
        "relative_paths": relative_paths,
        "num_label_rows": len(rows),
        "num_sample_ids": len(weight_by_sample),
        "num_sample_ids_with_duplicate_rows": len(repeated_ids),
    }
    meta.update(stats)
    index = {"meta": meta, "items": items, "mapping": mapping}

    output_json.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(index, indent=2, ensure_ascii=False)
    output_json.write_text(serialized, encoding="utf-8")
    return index
def prune_missing_files(index_json: Path, data_root: Optional[Path] = None) -> dict:
    """Load an index JSON and drop items whose PLY file no longer exists.

    The file on disk is not modified; the caller decides whether to write
    the result back.

    Args:
        index_json: Path to an existing index JSON.
        data_root: Base directory for relative ``"ply"`` entries. Falls back
            to ``meta.data_root`` from the index, then to ``/``.

    Returns:
        A new index dict with the surviving ``"items"``, a rebuilt
        ``"mapping"``, and ``"meta"`` updated with ``total_plys`` (items
        kept) and ``pruned_missing`` (items removed).
    """
    with index_json.open("r", encoding="utf-8") as f:
        index = json.load(f)

    meta = index.get("meta", {})
    root = Path(data_root or meta.get("data_root") or "/").expanduser().resolve()

    kept: List[dict] = []
    removed = 0
    for item in index.get("items", []):
        ply_str = item.get("ply", "")
        ply_path = Path(ply_str)
        if not ply_path.is_absolute():
            ply_path = root / ply_path
        ply_path = ply_path.expanduser().resolve()
        # An empty/missing "ply" would resolve to the root directory itself,
        # and a directory is never a point cloud: require a real file.
        if ply_str and ply_path.is_file():
            kept.append(item)
        else:
            removed += 1

    mapping = {it["ply"]: it["weight_g"] for it in kept}
    meta["total_plys"] = len(kept)
    meta["pruned_missing"] = removed
    return {
        "meta": meta,
        "items": kept,
        "mapping": mapping,
    }
def main() -> None:
    """Command-line entry point: build the dataset index, or prune an existing one.

    Without ``--prune-missing`` the index is built from ``--data-root`` and
    ``--label-csv`` and written to ``--output``. With ``--prune-missing`` the
    existing ``--output`` file is loaded, entries whose PLY file is gone are
    removed, and the file is overwritten.

    Raises:
        SystemExit: If a required input path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Create a JSON mapping each .ply to weight (grams) using label.csv column F."
    )
    parser.add_argument(
        "--data-root",
        type=str,
        # No argparse-level default: "not given" must stay distinguishable so
        # that --prune-missing can fall back to the data_root stored in the
        # index instead of always using DEFAULT_DATA_ROOT.
        default=None,
        help=(
            "Dataset root containing {sample_id}/cloud/*.ply "
            f"(default: {DEFAULT_DATA_ROOT}; with --prune-missing the "
            "data_root recorded in the index is used when omitted)"
        ),
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="weight_estimator/dataset_index.json",
        help="Output JSON path (default: weight_estimator/dataset_index.json)",
    )
    parser.add_argument(
        "--duplicate-policy",
        type=str,
        default="mean",
        choices=["mean", "first", "error"],
        help="How to resolve duplicate sample_id rows in label.csv (default: mean).",
    )
    parser.add_argument(
        "--relative-paths",
        action="store_true",
        help="Store PLY paths relative to --data-root instead of absolute paths.",
    )
    parser.add_argument(
        "--prune-missing",
        action="store_true",
        help="Load existing index, remove entries for non-existent PLY files, and overwrite.",
    )
    args = parser.parse_args()
    output_json = Path(args.output).expanduser().resolve()

    if args.prune_missing:
        if not output_json.exists():
            raise SystemExit(f"Index JSON not found (required for --prune-missing): {output_json}")
        # None lets prune_missing_files use the data_root stored in the index.
        data_root = Path(args.data_root).expanduser().resolve() if args.data_root else None
        out = prune_missing_files(output_json, data_root=data_root)
        output_json.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
        meta = out["meta"]
        print("Pruned dataset index (removed non-existent PLY entries).")
        print(f" output: {output_json}")
        print(f" total_plys: {meta['total_plys']} (removed {meta.get('pruned_missing', 0)} missing)")
        return

    data_root = Path(args.data_root or DEFAULT_DATA_ROOT).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()
    if not data_root.exists():
        raise SystemExit(f"data root does not exist: {data_root}")
    if not label_csv.exists():
        raise SystemExit(f"label csv does not exist: {label_csv}")

    out = build_index(
        data_root=data_root,
        label_csv=label_csv,
        output_json=output_json,
        duplicate_policy=args.duplicate_policy,
        relative_paths=args.relative_paths,
    )
    meta = out["meta"]
    print("Dataset index written.")
    print(f" output: {output_json}")
    print(f" total_plys: {meta['total_plys']}")
    print(f" missing_weights: {meta['missing_weights']}")
    print(f" missing_folders: {meta['missing_folders']}")
    print(f" duplicate_policy: {meta['duplicate_policy']}")


if __name__ == "__main__":
    main()