#!/usr/bin/env python3
"""
Build a dataset index mapping point clouds (.ply) to weight labels.

Data:
- Point clouds live under:
    /home/ubuntu/data/fish/2025-11-19-output/{sample_id}/cloud/*.ply
- Labels live at:
    /home/ubuntu/projects/FishMeasure/measure/data/label.csv

CSV format:
- Column B (index 1): sample_id (svo2 name / folder name)
- Column F (index 5): weight in grams (float)

Output JSON:
{
  "meta": {...},
  "items": [
    {"ply": "<ply_path>", "sample_id": "...", "weight_g": 123.45}
  ],
  "mapping": {
    "<ply_path>": 123.45
  }
}
"""

from __future__ import annotations

import argparse
import csv
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output"
DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"

# Accepted values for the ``duplicate_policy`` argument / --duplicate-policy.
_DUPLICATE_POLICIES = ("mean", "first", "error")


@dataclass(frozen=True)
class LabelRow:
    """One usable row of label.csv: a sample id and its (optional) weight."""

    sample_id: str
    weight_g: Optional[float]


def _parse_weight_g(cell: str) -> Optional[float]:
    """
    Parse a weight cell into a finite float.

    Returns None for empty cells, non-numeric text, and non-finite values
    ("nan", "inf"), so that such rows are treated as "weight missing" rather
    than leaking NaN/Infinity into the JSON output.
    """
    text = (cell or "").strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        return None
    return value if math.isfinite(value) else None


def load_labels(label_csv: Path) -> List[LabelRow]:
    """
    Load labels from label.csv.

    Rows with fewer than two columns, an empty sample_id, or the placeholder
    sample_id "xxx" (case-insensitive) are skipped. A missing or unparsable
    weight column yields ``weight_g=None``.

    Returns a list of LabelRow (keeps duplicates).
    """
    rows: List[LabelRow] = []
    # newline="" lets the csv module handle embedded/quoted newlines itself.
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        for record in csv.reader(f):
            if len(record) < 2:
                continue
            sample_id = (record[1] or "").strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            weight_cell = record[5] if len(record) > 5 else ""
            rows.append(
                LabelRow(sample_id=sample_id, weight_g=_parse_weight_g(weight_cell))
            )
    return rows


def resolve_sample_weights(
    label_rows: List[LabelRow],
    duplicate_policy: str = "mean",
) -> Tuple[Dict[str, Optional[float]], Dict[str, List[Optional[float]]]]:
    """
    Resolve sample_id -> single weight according to policy.

    duplicate_policy:
    - mean: average of non-null weights
    - first: first non-null weight
    - error: raise if a sample_id has multiple distinct non-null weights

    Returns:
    - resolved: {sample_id: weight_g or None}
    - raw: {sample_id: [weight_g_or_None, ...]}

    Raises:
    - ValueError: if duplicate_policy is unknown, or if the policy is "error"
      and a sample_id has conflicting weights.
    """
    # Validate up front so a bad policy is reported even when no sample has
    # a usable weight (or the label list is empty).
    if duplicate_policy not in _DUPLICATE_POLICIES:
        raise ValueError(f"Unknown duplicate_policy: {duplicate_policy}")

    raw: Dict[str, List[Optional[float]]] = {}
    for row in label_rows:
        raw.setdefault(row.sample_id, []).append(row.weight_g)

    resolved: Dict[str, Optional[float]] = {}
    for sample_id, weights in raw.items():
        non_null = [w for w in weights if w is not None]
        if not non_null:
            resolved[sample_id] = None
        elif duplicate_policy == "first":
            resolved[sample_id] = non_null[0]
        elif duplicate_policy == "mean":
            resolved[sample_id] = sum(non_null) / len(non_null)
        else:  # "error"
            distinct = sorted(set(non_null))
            if len(distinct) > 1:
                raise ValueError(
                    f"Duplicate sample_id with multiple weights: {sample_id}: {distinct}"
                )
            resolved[sample_id] = distinct[0]

    return resolved, raw


def collect_ply_items(
    data_root: Path,
    sample_weights: Dict[str, Optional[float]],
    relative_paths: bool,
) -> Tuple[List[dict], Dict[str, float], Dict[str, int]]:
    """
    Build (items list, mapping dict, stats dict).

    Samples are visited in sorted sample_id order. Samples without a weight
    are counted in ``missing_weights``; samples whose ``cloud`` directory is
    absent or contains no ``*.ply`` files are counted in ``missing_folders``.
    When ``relative_paths`` is true, PLY paths are stored relative to
    ``data_root``; otherwise they are stored as produced by the glob.
    """
    items: List[dict] = []
    mapping: Dict[str, float] = {}
    missing_folders = 0
    missing_weights = 0
    total_plys = 0

    for sample_id in sorted(sample_weights):
        weight_g = sample_weights[sample_id]
        if weight_g is None:
            missing_weights += 1
            continue

        cloud_dir = data_root / sample_id / "cloud"
        ply_files = sorted(cloud_dir.glob("*.ply")) if cloud_dir.is_dir() else []
        if not ply_files:
            # A missing directory and an empty one are both "no data".
            missing_folders += 1
            continue

        weight = float(weight_g)
        for ply in ply_files:
            total_plys += 1
            ply_path = ply
            if relative_paths:
                try:
                    ply_path = ply.relative_to(data_root)
                except ValueError:
                    # Not under data_root: fall back to the path as found.
                    ply_path = ply
            ply_key = str(ply_path)
            items.append(
                {"ply": ply_key, "sample_id": sample_id, "weight_g": weight}
            )
            mapping[ply_key] = weight

    stats = {
        "missing_folders": missing_folders,
        "missing_weights": missing_weights,
        "total_plys": total_plys,
    }
    return items, mapping, stats


def build_index(
    data_root: Path,
    label_csv: Path,
    output_json: Path,
    duplicate_policy: str = "mean",
    relative_paths: bool = False,
) -> dict:
    """
    Build the dataset index, write it to ``output_json``, and return it.

    The parent directory of ``output_json`` is created if needed. The
    returned dict has the keys "meta", "items" and "mapping".
    """
    label_rows = load_labels(label_csv)
    sample_weights, raw_weights = resolve_sample_weights(
        label_rows, duplicate_policy=duplicate_policy
    )
    items, mapping, stats = collect_ply_items(
        data_root=data_root,
        sample_weights=sample_weights,
        relative_paths=relative_paths,
    )

    num_duplicates = sum(1 for ws in raw_weights.values() if len(ws) > 1)

    out = {
        "meta": {
            "data_root": str(data_root),
            "label_csv": str(label_csv),
            "weight_column": "F",
            "duplicate_policy": duplicate_policy,
            "relative_paths": relative_paths,
            "num_label_rows": len(label_rows),
            "num_sample_ids": len(sample_weights),
            "num_sample_ids_with_duplicate_rows": num_duplicates,
            **stats,
        },
        "items": items,
        "mapping": mapping,
    }

    output_json.parent.mkdir(parents=True, exist_ok=True)
    output_json.write_text(
        json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    return out


def prune_missing_files(index_json: Path, data_root: Optional[Path] = None) -> dict:
    """
    Load index JSON, remove items whose PLY file does not exist, and return
    updated index.

    Relative PLY paths are resolved against ``data_root`` if given, else the
    ``data_root`` recorded in the index metadata, else "/". Items with a
    missing or empty "ply" entry are treated as missing and removed. The
    index file itself is not rewritten here; the caller decides.
    """
    with index_json.open("r", encoding="utf-8") as f:
        index = json.load(f)

    meta = index.get("meta", {})
    root = Path(data_root or meta.get("data_root") or "/").expanduser().resolve()

    kept: List[dict] = []
    removed = 0
    for item in index.get("items", []):
        ply_str = item.get("ply", "")
        if not ply_str:
            # An empty path would resolve to the root directory itself and
            # be wrongly counted as an existing file.
            removed += 1
            continue
        ply_path = Path(ply_str)
        if not ply_path.is_absolute():
            ply_path = root / ply_path
        if ply_path.expanduser().resolve().exists():
            kept.append(item)
        else:
            removed += 1

    meta["total_plys"] = len(kept)
    meta["pruned_missing"] = removed

    return {
        "meta": meta,
        "items": kept,
        "mapping": {it["ply"]: it["weight_g"] for it in kept},
    }


def main() -> None:
    """Command-line entry point: build the index, or prune an existing one."""
    parser = argparse.ArgumentParser(
        description="Create a JSON mapping each .ply to weight (grams) using label.csv column F."
    )
    parser.add_argument(
        "--data-root",
        type=str,
        # No argparse-level default: prune mode must be able to tell "not
        # given" apart from "given", so it can fall back to the data_root
        # stored in the index. Build mode applies DEFAULT_DATA_ROOT below.
        default=None,
        help=f"Dataset root containing {{sample_id}}/cloud/*.ply (default: {DEFAULT_DATA_ROOT})",
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="weight_estimator/dataset_index.json",
        help="Output JSON path (default: weight_estimator/dataset_index.json)",
    )
    parser.add_argument(
        "--duplicate-policy",
        type=str,
        default="mean",
        choices=list(_DUPLICATE_POLICIES),
        help="How to resolve duplicate sample_id rows in label.csv (default: mean).",
    )
    parser.add_argument(
        "--relative-paths",
        action="store_true",
        help="Store PLY paths relative to --data-root instead of absolute paths.",
    )
    parser.add_argument(
        "--prune-missing",
        action="store_true",
        help="Load existing index, remove entries for non-existent PLY files, and overwrite.",
    )
    args = parser.parse_args()

    output_json = Path(args.output).expanduser().resolve()

    if args.prune_missing:
        if not output_json.exists():
            raise SystemExit(
                f"Index JSON not found (required for --prune-missing): {output_json}"
            )
        data_root = (
            Path(args.data_root).expanduser().resolve() if args.data_root else None
        )
        out = prune_missing_files(output_json, data_root=data_root)
        output_json.write_text(
            json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        meta = out["meta"]
        print("Pruned dataset index (removed non-existent PLY entries).")
        print(f"  output: {output_json}")
        print(
            f"  total_plys: {meta['total_plys']} (removed {meta.get('pruned_missing', 0)} missing)"
        )
        return

    data_root = Path(args.data_root or DEFAULT_DATA_ROOT).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()

    if not data_root.exists():
        raise SystemExit(f"data root does not exist: {data_root}")
    if not label_csv.exists():
        raise SystemExit(f"label csv does not exist: {label_csv}")

    out = build_index(
        data_root=data_root,
        label_csv=label_csv,
        output_json=output_json,
        duplicate_policy=args.duplicate_policy,
        relative_paths=args.relative_paths,
    )

    meta = out["meta"]
    print("Dataset index written.")
    print(f"  output: {output_json}")
    print(f"  total_plys: {meta['total_plys']}")
    print(f"  missing_weights: {meta['missing_weights']}")
    print(f"  missing_folders: {meta['missing_folders']}")
    print(f"  duplicate_policy: {meta['duplicate_policy']}")


if __name__ == "__main__":
    main()