348 lines
10 KiB
Python
Executable File
348 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Build a dataset index mapping point clouds (.ply) to weight labels.
|
|
|
|
Data:
|
|
- Point clouds live under:
|
|
/home/ubuntu/data/fish/2025-11-19-output/{sample_id}/cloud/*.ply
|
|
- Labels live at:
|
|
/home/ubuntu/projects/FishMeasure/measure/data/label.csv
|
|
|
|
CSV format:
|
|
- Column B (index 1): sample_id (svo2 name / folder name)
|
|
- Column F (index 5): weight in grams (float)
|
|
|
|
Output JSON:
|
|
{
|
|
"meta": {...},
|
|
"items": [
|
|
{"ply": "<abs_or_rel_path>", "sample_id": "...", "weight_g": 123.45}
|
|
],
|
|
"mapping": {
|
|
"<abs_or_rel_path>": 123.45
|
|
}
|
|
}
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
import csv
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
# Default for --data-root: directory expected to contain {sample_id}/cloud/*.ply.
DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output"
# Default for --label-csv: column B holds the sample_id, column F the weight in grams.
DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"
|
|
|
|
|
|
@dataclass(frozen=True)
class LabelRow:
    """One usable row of label.csv: a sample id and its weight, if any."""

    # Value of CSV column B (index 1), stripped; also the sample's folder
    # name under the data root.
    sample_id: str
    # Weight in grams parsed from CSV column F (index 5); None when the cell
    # is absent, empty, or cannot be parsed as a number.
    weight_g: Optional[float]
|
|
|
|
|
|
def _parse_weight_g(cell: str) -> Optional[float]:
    """
    Parse one CSV cell into a weight in grams.

    Returns None when the cell is None/empty/blank, is not a valid number,
    or parses to a non-finite value (NaN or +/-infinity), so callers can
    treat every unusable cell uniformly as "no weight".
    """
    text = (cell or "").strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        return None
    # float() accepts spellings such as "nan" and "inf"; those would poison
    # averaged weights and cannot be represented in strict JSON, so they are
    # reported as missing instead.
    return value if math.isfinite(value) else None
|
|
|
|
|
|
def load_labels(label_csv: Path) -> List[LabelRow]:
    """
    Load labels from label.csv.

    Column B (index 1) is read as the sample_id and column F (index 5) as
    the weight cell. Rows with fewer than two columns, an empty sample_id,
    or the placeholder sample_id "xxx" (case-insensitive) are skipped. A row
    with no column F is passed on with an empty weight cell.

    Returns a list of LabelRow in file order (keeps duplicates).
    """
    rows: List[LabelRow] = []
    # newline="" is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly. utf-8-sig drops a leading
    # BOM (which would otherwise glue itself onto the first cell) and reads
    # BOM-less UTF-8 unchanged.
    with label_csv.open("r", encoding="utf-8-sig", newline="") as f:
        for record in csv.reader(f):
            if len(record) < 2:
                continue
            sample_id = (record[1] or "").strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            weight_cell = record[5] if len(record) > 5 else ""
            rows.append(
                LabelRow(sample_id=sample_id, weight_g=_parse_weight_g(weight_cell))
            )
    return rows
|
|
|
|
|
|
def resolve_sample_weights(
    label_rows: List[LabelRow],
    duplicate_policy: str = "mean",
) -> Tuple[Dict[str, Optional[float]], Dict[str, List[Optional[float]]]]:
    """
    Resolve sample_id -> single weight according to policy.

    duplicate_policy:
    - mean: average of non-null weights
    - first: first non-null weight
    - error: raise if a sample_id has multiple distinct non-null weights

    A sample_id whose rows carry no non-null weight resolves to None under
    every policy.

    Returns:
    - resolved: {sample_id: weight_g or None}
    - raw: {sample_id: [weight_g_or_None, ...]} in input order

    Raises:
    - ValueError: unknown duplicate_policy, or (policy "error") a sample_id
      with more than one distinct non-null weight.
    """
    # Validate up front so a bad policy is rejected even when the input is
    # empty or no sample has a usable weight.
    if duplicate_policy not in ("mean", "first", "error"):
        raise ValueError(f"Unknown duplicate_policy: {duplicate_policy}")

    raw: Dict[str, List[Optional[float]]] = {}
    for row in label_rows:
        raw.setdefault(row.sample_id, []).append(row.weight_g)

    resolved: Dict[str, Optional[float]] = {}
    for sample_id, weights in raw.items():
        non_null = [w for w in weights if w is not None]
        if not non_null:
            resolved[sample_id] = None
        elif duplicate_policy == "first":
            resolved[sample_id] = non_null[0]
        elif duplicate_policy == "mean":
            resolved[sample_id] = sum(non_null) / len(non_null)
        else:  # "error": identical repeated weights are fine, conflicts are not
            distinct = sorted(set(non_null))
            if len(distinct) > 1:
                raise ValueError(
                    f"Duplicate sample_id with multiple weights: {sample_id}: {distinct}"
                )
            resolved[sample_id] = distinct[0]

    return resolved, raw
|
|
|
|
|
|
def collect_ply_items(
    data_root: Path,
    sample_weights: Dict[str, Optional[float]],
    relative_paths: bool,
) -> Tuple[List[dict], Dict[str, float], Dict[str, int]]:
    """
    Build (items list, mapping dict, stats dict).

    For every sample_id (visited in sorted order) that has a weight, each
    regular file matching {data_root}/{sample_id}/cloud/*.ply (in sorted
    order) yields one item {"ply", "sample_id", "weight_g"} and one mapping
    entry ply -> weight_g.

    When relative_paths is true, the "ply" path is stored relative to
    data_root; otherwise it is the path obtained by joining onto data_root.

    stats keys:
    - missing_weights: sample_ids whose weight is None (skipped)
    - missing_folders: sample_ids whose cloud directory is absent or holds
      no .ply file
    - total_plys: number of .ply files indexed
    """
    items: List[dict] = []
    mapping: Dict[str, float] = {}
    missing_folders = 0
    missing_weights = 0
    total_plys = 0

    for sample_id in sorted(sample_weights):
        weight_g = sample_weights[sample_id]
        if weight_g is None:
            missing_weights += 1
            continue

        cloud_dir = data_root / sample_id / "cloud"
        if cloud_dir.is_dir():
            # is_file() keeps out directories that merely carry a .ply name.
            ply_files = sorted(p for p in cloud_dir.glob("*.ply") if p.is_file())
        else:
            ply_files = []
        if not ply_files:
            # A cloud directory without any .ply is counted like a missing one.
            missing_folders += 1
            continue

        weight = float(weight_g)
        for ply in ply_files:
            total_plys += 1
            ply_path = ply
            if relative_paths:
                try:
                    ply_path = ply.relative_to(data_root)
                except ValueError:
                    # Not under data_root: fall back to the path as found.
                    ply_path = ply

            ply_key = str(ply_path)
            items.append({"ply": ply_key, "sample_id": sample_id, "weight_g": weight})
            mapping[ply_key] = weight

    stats = {
        "missing_folders": missing_folders,
        "missing_weights": missing_weights,
        "total_plys": total_plys,
    }
    return items, mapping, stats
|
|
|
|
|
|
def build_index(
    data_root: Path,
    label_csv: Path,
    output_json: Path,
    duplicate_policy: str = "mean",
    relative_paths: bool = False,
) -> dict:
    """
    Build the PLY -> weight index, write it to output_json and return it.

    Labels are loaded from label_csv, collapsed to one weight per sample_id
    using duplicate_policy, and matched against the .ply files found under
    data_root. The parent directory of output_json is created if needed.

    Returns the dict that was serialized: {"meta", "items", "mapping"}.
    """
    labels = load_labels(label_csv)
    weight_by_sample, weights_per_sample = resolve_sample_weights(
        labels, duplicate_policy=duplicate_policy
    )
    items, mapping, stats = collect_ply_items(
        data_root=data_root,
        sample_weights=weight_by_sample,
        relative_paths=relative_paths,
    )

    # Sample ids that occur on more than one label row.
    repeated = [ws for ws in weights_per_sample.values() if len(ws) > 1]

    meta = {
        "data_root": str(data_root),
        "label_csv": str(label_csv),
        "weight_column": "F",
        "duplicate_policy": duplicate_policy,
        "relative_paths": relative_paths,
        "num_label_rows": len(labels),
        "num_sample_ids": len(weight_by_sample),
        "num_sample_ids_with_duplicate_rows": len(repeated),
    }
    # Counters from the file scan go last, matching the on-disk key order.
    meta.update(stats)

    index = {"meta": meta, "items": items, "mapping": mapping}

    output_json.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(index, indent=2, ensure_ascii=False)
    output_json.write_text(serialized, encoding="utf-8")
    return index
|
|
|
|
|
|
def prune_missing_files(index_json: Path, data_root: Optional[Path] = None) -> dict:
    """
    Load index JSON, remove items whose PLY file does not exist, and return updated index.

    Relative "ply" paths are resolved against data_root when given, else
    against meta.data_root stored in the index, else against "/". An item is
    kept only when its path is an existing regular file; items with a
    missing or empty "ply" value are counted as removed.

    The index file itself is not modified; writing the result back is left
    to the caller.

    Returns {"meta", "items", "mapping"} where mapping is rebuilt from the
    kept items, meta["total_plys"] is the number kept and
    meta["pruned_missing"] the number removed.
    """
    with index_json.open("r", encoding="utf-8") as f:
        index = json.load(f)

    meta = index.get("meta", {})
    root = Path(data_root or meta.get("data_root") or "/").expanduser().resolve()

    kept: List[dict] = []
    removed = 0
    for item in index.get("items", []):
        ply_path = Path(item.get("ply") or "")
        if not ply_path.is_absolute():
            ply_path = root / ply_path
        ply_path = ply_path.expanduser().resolve()

        # is_file() rather than exists(): an empty "ply" resolves to the root
        # directory itself, and a directory is never a valid point cloud.
        if ply_path.is_file():
            kept.append(item)
        else:
            removed += 1

    mapping = {it["ply"]: it["weight_g"] for it in kept}
    meta["total_plys"] = len(kept)
    meta["pruned_missing"] = removed

    return {
        "meta": meta,
        "items": kept,
        "mapping": mapping,
    }
|
|
|
|
|
|
def main() -> None:
    """
    Command-line entry point.

    Default mode builds the index from --data-root and --label-csv and
    writes it to --output. With --prune-missing, the existing --output index
    is loaded instead, entries whose PLY file does not exist are removed,
    and the file is overwritten.

    Exits via SystemExit when a required input path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Create a JSON mapping each .ply to weight (grams) using label.csv column F."
    )
    parser.add_argument(
        "--data-root",
        type=str,
        # No argparse default: prune mode must be able to tell "not given"
        # apart from an explicit value. Build mode applies DEFAULT_DATA_ROOT.
        default=None,
        help=(
            f"Dataset root containing {{sample_id}}/cloud/*.ply (default: {DEFAULT_DATA_ROOT}; "
            "with --prune-missing, defaults to the data_root recorded in the index)"
        ),
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="weight_estimator/dataset_index.json",
        help="Output JSON path (default: weight_estimator/dataset_index.json)",
    )
    parser.add_argument(
        "--duplicate-policy",
        type=str,
        default="mean",
        choices=["mean", "first", "error"],
        help="How to resolve duplicate sample_id rows in label.csv (default: mean).",
    )
    parser.add_argument(
        "--relative-paths",
        action="store_true",
        help="Store PLY paths relative to --data-root instead of absolute paths.",
    )
    parser.add_argument(
        "--prune-missing",
        action="store_true",
        help="Load existing index, remove entries for non-existent PLY files, and overwrite.",
    )

    args = parser.parse_args()

    output_json = Path(args.output).expanduser().resolve()

    if args.prune_missing:
        if not output_json.exists():
            raise SystemExit(f"Index JSON not found (required for --prune-missing): {output_json}")
        # Pass a root only when the user supplied one; otherwise
        # prune_missing_files falls back to the data_root stored in the
        # index's own meta, i.e. the root the index was built against.
        data_root = Path(args.data_root).expanduser().resolve() if args.data_root else None
        out = prune_missing_files(output_json, data_root=data_root)
        output_json.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
        meta = out["meta"]
        print("Pruned dataset index (removed non-existent PLY entries).")
        print(f"  output: {output_json}")
        print(f"  total_plys: {meta['total_plys']} (removed {meta.get('pruned_missing', 0)} missing)")
        return

    data_root = Path(args.data_root or DEFAULT_DATA_ROOT).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()

    if not data_root.exists():
        raise SystemExit(f"data root does not exist: {data_root}")
    if not label_csv.exists():
        raise SystemExit(f"label csv does not exist: {label_csv}")

    out = build_index(
        data_root=data_root,
        label_csv=label_csv,
        output_json=output_json,
        duplicate_policy=args.duplicate_policy,
        relative_paths=args.relative_paths,
    )

    meta = out["meta"]
    print("Dataset index written.")
    print(f"  output: {output_json}")
    print(f"  total_plys: {meta['total_plys']}")
    print(f"  missing_weights: {meta['missing_weights']}")
    print(f"  missing_folders: {meta['missing_folders']}")
    print(f"  duplicate_policy: {meta['duplicate_policy']}")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|