#!/usr/bin/env python3 """ Filter PLY point clouds based on ground-truth length from label.csv. Scale = actual_length_mm / pc_length_mm (PCA length of point cloud). Actions: - scale < 1 or scale > 1.2: Remove (delete) PLY - bad quality - 1.0 <= scale <= 1.2: Leave as-is (no rescale) Filters out samples where actual_length > max_length_mm (default 380mm) as invalid. """ from __future__ import annotations import argparse import csv from pathlib import Path from typing import Dict, Optional import numpy as np try: import open3d as o3d O3D_AVAILABLE = True except ImportError: O3D_AVAILABLE = False DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output" DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv" DEFAULT_MAX_LENGTH_MM = 380.0 def estimate_pointcloud_length_pca(points: np.ndarray) -> float: """Estimate length as extent along 1st PCA axis (mm).""" if points is None or points.ndim != 2 or points.shape[0] < 3 or points.shape[1] < 3: return float("nan") pts = points[:, :3].astype(np.float32, copy=False) pts = pts - pts.mean(axis=0, keepdims=True) try: _u, _s, vt = np.linalg.svd(pts, full_matrices=False) axis = vt[0] proj = pts @ axis return float(np.max(proj) - np.min(proj)) except Exception: return float("nan") def load_lengths_from_csv( label_csv: Path, max_length_mm: float = DEFAULT_MAX_LENGTH_MM, ) -> Dict[str, float]: """ Load sample_id -> length_mm from label.csv. Column B (index 1): sample_id Column C (index 2): length in cm Excludes rows where length_cm * 10 > max_length_mm. For duplicate sample_ids, uses mean of valid lengths. """ raw: Dict[str, list] = {} with label_csv.open("r", encoding="utf-8") as f: reader = csv.reader(f) for row in reader: if len(row) < 3: continue sample_id = (row[1] or "").strip() if not sample_id or sample_id.lower() == "xxx": continue try: length_cm = float((row[2] or "").strip()) except (ValueError, TypeError): continue length_mm = length_cm * 10.0 if length_mm > max_length_mm: continue # Skip invalid raw.setdefault(sample_id, []).append(length_mm) resolved: Dict[str, float] = {} for sample_id, lengths in raw.items(): if lengths: resolved[sample_id] = float(np.mean(lengths)) return resolved def load_ply_points(ply_path: Path): """Load points from PLY (point cloud or mesh). Returns (points, colors or None, mesh or None).""" # Try point cloud first pcd = o3d.io.read_point_cloud(str(ply_path)) if pcd is not None and len(pcd.points) > 0: pts = np.asarray(pcd.points, dtype=np.float64) colors = np.asarray(pcd.colors, dtype=np.float64) if pcd.has_colors() else None return pts, colors, None # Fallback: mesh mesh = o3d.io.read_triangle_mesh(str(ply_path)) if mesh is not None and len(mesh.vertices) > 0: pts = np.asarray(mesh.vertices, dtype=np.float64) return pts, None, mesh return None, None, None def save_ply(ply_path: Path, points: np.ndarray, colors: Optional[np.ndarray] = None, mesh: Optional[o3d.geometry.TriangleMesh] = None): """Save rescaled points to PLY.""" if mesh is not None: mesh.vertices = o3d.utility.Vector3dVector(points) o3d.io.write_triangle_mesh(str(ply_path), mesh) else: pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) if colors is not None: pcd.colors = o3d.utility.Vector3dVector(colors) o3d.io.write_point_cloud(str(ply_path), pcd) def rescale_ply_folder( cloud_dir: Path, actual_length_mm: float, dry_run: bool = False, remove_bad: bool = True, verbose: bool = True, ) -> tuple[int, int, int]: """ Process PLYs in cloud_dir based on scale = actual_length_mm / pc_length. - scale < 1 or scale > 1.2: Remove (delete) PLY - 1.0 <= scale <= 1.2: Leave as-is Returns (kept_count, removed_count, skipped_count). """ ply_files = sorted(cloud_dir.glob("*.ply")) kept = 0 removed = 0 skipped = 0 for ply_path in ply_files: data = load_ply_points(ply_path) if data[0] is None: skipped += 1 if verbose: print(f" Skip (empty): {ply_path.name}") continue points, colors, mesh = data pc_length = estimate_pointcloud_length_pca(points) if not np.isfinite(pc_length) or pc_length <= 0: skipped += 1 if verbose: print(f" Skip (invalid length): {ply_path.name}") continue scale = actual_length_mm / pc_length # Remove bad PLYs: scale < 1 or scale > 1.2 if scale < 1.0 or scale > 1.2: if not dry_run and remove_bad: ply_path.unlink() removed += 1 if verbose: action = "Remove" if remove_bad else "Skip (would remove)" print(f" {action} (bad scale={scale:.3f}): {ply_path.name}") continue # 1.0 <= scale <= 1.2: leave as-is kept += 1 if verbose: print(f" Keep (scale={scale:.3f}): {ply_path.name}") return kept, removed, skipped def main() -> None: parser = argparse.ArgumentParser( description="Rescale PLY point clouds to match ground-truth length from label.csv" ) parser.add_argument( "--data-root", type=str, default=DEFAULT_DATA_ROOT, help=f"Data root with subfolders (default: {DEFAULT_DATA_ROOT})", ) parser.add_argument( "--label-csv", type=str, default=DEFAULT_LABEL_CSV, help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})", ) parser.add_argument( "--max-length-mm", type=float, default=DEFAULT_MAX_LENGTH_MM, help=f"Exclude samples with length > this (mm). Default: {DEFAULT_MAX_LENGTH_MM}", ) parser.add_argument( "--dry-run", action="store_true", help="Print what would be done without modifying files", ) parser.add_argument( "--quiet", action="store_true", help="Less verbose output", ) parser.add_argument( "--no-remove", action="store_true", help="Do not delete bad PLYs (scale<1 or >1.5); only rescale and skip", ) args = parser.parse_args() if not O3D_AVAILABLE: print("ERROR: open3d is required. pip install open3d") return data_root = Path(args.data_root).expanduser().resolve() label_csv = Path(args.label_csv).expanduser().resolve() if not data_root.exists(): print(f"ERROR: Data root not found: {data_root}") return if not label_csv.exists(): print(f"ERROR: Label CSV not found: {label_csv}") return print(f"Loading lengths from: {label_csv}") lengths = load_lengths_from_csv(label_csv, max_length_mm=args.max_length_mm) print(f" Found {len(lengths)} samples with valid length (<= {args.max_length_mm}mm)") # Count excluded all_ids = set() with label_csv.open("r", encoding="utf-8") as f: reader = csv.reader(f) for row in reader: if len(row) >= 3 and row[1].strip() and row[1].strip().lower() != "xxx": all_ids.add(row[1].strip()) excluded = all_ids - set(lengths.keys()) if excluded: print(f" Excluded {len(excluded)} sample(s) with length > {args.max_length_mm}mm or missing length") if args.dry_run: print("\n[DRY RUN - no files will be modified]\n") total_kept = 0 total_removed = 0 total_skipped = 0 folders_processed = 0 folders_skipped = 0 for sample_id in sorted(lengths.keys()): cloud_dir = data_root / sample_id / "cloud" if not cloud_dir.exists(): folders_skipped += 1 continue actual_length_mm = lengths[sample_id] if not args.quiet: print(f"\n[{sample_id}] actual_length={actual_length_mm:.1f}mm") kept, removed, skipped = rescale_ply_folder( cloud_dir, actual_length_mm, dry_run=args.dry_run, remove_bad=not args.no_remove, verbose=not args.quiet, ) total_kept += kept total_removed += removed total_skipped += skipped folders_processed += 1 print("\n" + "=" * 60) print("Summary") print("=" * 60) print(f" Folders processed: {folders_processed}") print(f" Folders skipped (no cloud): {folders_skipped}") print(f" PLYs kept (1.0<=scale<=1.2): {total_kept}") print(f" PLYs removed (scale<1 or >1.2): {total_removed}") print(f" PLYs skipped (empty/invalid): {total_skipped}") if args.dry_run: print(" [DRY RUN - no files were modified]") if __name__ == "__main__": main()