288 lines
8.9 KiB
Python
Executable File
288 lines
8.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Filter PLY point clouds based on ground-truth length from label.csv.

Scale = actual_length_mm / pc_length_mm (PCA length of point cloud).

Actions:
- scale < 1 or scale > 1.2: Remove (delete) PLY - bad quality
- 1.0 <= scale <= 1.2: Leave as-is (no rescale)

Samples whose actual_length exceeds max_length_mm (default 380 mm) are
treated as invalid and are not processed.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
from pathlib import Path
|
|
from typing import Dict, Optional
|
|
|
|
import numpy as np
|
|
|
|
try:
|
|
import open3d as o3d
|
|
O3D_AVAILABLE = True
|
|
except ImportError:
|
|
O3D_AVAILABLE = False
|
|
|
|
|
|
# Root directory with one sub-folder per sample id; PLYs are looked up in
# "<root>/<sample_id>/cloud".
DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output"

# Label file: sample id in column B, ground-truth length (cm) in column C.
DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"

# Labels longer than this (in millimetres) are considered invalid.
DEFAULT_MAX_LENGTH_MM = 380.0
|
|
|
|
|
|
def estimate_pointcloud_length_pca(points: np.ndarray) -> float:
    """Estimate length as the extent along the 1st PCA axis.

    The first three columns of ``points`` are centred on their mean, the
    dominant principal axis is taken from an SVD of the centred data, and the
    length is the spread (max - min) of the points projected onto that axis.
    The result is in the units of the input coordinates (the callers in this
    script treat it as millimetres).

    Args:
        points: Array-like of shape (N, >=3); only the first three columns
            are used.

    Returns:
        The extent along the principal axis, or ``nan`` when the input is
        ``None``, cannot be converted to a numeric array, is not 2-D with at
        least 3 rows and 3 columns, contains non-finite coordinates, or the
        SVD does not converge.
    """
    if points is None:
        return float("nan")
    try:
        # Work in float64: down-casting to float32 *before* centring discards
        # precision whenever the coordinates share a large common offset.
        pts = np.asarray(points, dtype=np.float64)
    except (TypeError, ValueError):
        return float("nan")
    if pts.ndim != 2 or pts.shape[0] < 3 or pts.shape[1] < 3:
        return float("nan")

    pts = pts[:, :3]
    # Reject NaN/Inf explicitly instead of relying on how LAPACK reacts to them.
    if not np.isfinite(pts).all():
        return float("nan")

    centered = pts - pts.mean(axis=0, keepdims=True)
    try:
        _u, _s, vt = np.linalg.svd(centered, full_matrices=False)
    except np.linalg.LinAlgError:
        return float("nan")

    # vt[0] is the direction of largest variance.
    proj = centered @ vt[0]
    return float(proj.max() - proj.min())
|
|
|
|
|
|
def load_lengths_from_csv(
    label_csv: Path,
    max_length_mm: float = DEFAULT_MAX_LENGTH_MM,
) -> Dict[str, float]:
    """Load sample_id -> length_mm from label.csv.

    Column B (index 1): sample_id
    Column C (index 2): length in cm

    Rows are ignored when they have fewer than three columns, when the
    sample id is empty or the placeholder "xxx" (case-insensitive), or when
    the length is not a number.  A length is also rejected when it is
    non-finite, not strictly positive, or when ``length_cm * 10`` exceeds
    ``max_length_mm``: a zero/negative/NaN label would otherwise produce a
    meaningless scale for every PLY of that sample.

    For duplicate sample_ids, the mean of the valid lengths is used.

    Args:
        label_csv: Path of the CSV file (read as UTF-8).
        max_length_mm: Upper bound for a valid length, in millimetres.

    Returns:
        Mapping from sample id to length in millimetres; ids without any
        valid length are absent.
    """
    lengths_by_id: Dict[str, list] = {}
    # newline="" is what the csv module asks for, so that quoted fields with
    # embedded line breaks are parsed correctly.
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            sample_id = row[1].strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            try:
                length_mm = float(row[2].strip()) * 10.0
            except ValueError:
                # Header rows and blank / non-numeric cells end up here.
                continue
            if not np.isfinite(length_mm) or length_mm <= 0 or length_mm > max_length_mm:
                continue  # Skip invalid
            lengths_by_id.setdefault(sample_id, []).append(length_mm)

    return {
        sample_id: float(np.mean(values))
        for sample_id, values in lengths_by_id.items()
    }
|
|
|
|
|
|
def load_ply_points(ply_path: Path):
    """Read vertex positions from a PLY holding a point cloud or a mesh.

    The file is first read as a point cloud; if that yields no points it is
    read again as a triangle mesh.

    Returns:
        ``(points, colors, mesh)``.  ``points`` is a float64 array of
        positions.  ``colors`` is a float64 array only for a point cloud that
        has colours, otherwise ``None``.  ``mesh`` is the Open3D mesh only
        when the mesh fallback was used, otherwise ``None``.  When neither
        reader finds any geometry the result is ``(None, None, None)``.
    """
    path_str = str(ply_path)

    # Preferred interpretation: a plain point cloud.
    cloud = o3d.io.read_point_cloud(path_str)
    if cloud is not None and len(cloud.points) > 0:
        positions = np.asarray(cloud.points, dtype=np.float64)
        cloud_colors = None
        if cloud.has_colors():
            cloud_colors = np.asarray(cloud.colors, dtype=np.float64)
        return positions, cloud_colors, None

    # Fallback: use the vertices of a triangle mesh.
    tri_mesh = o3d.io.read_triangle_mesh(path_str)
    if tri_mesh is not None and len(tri_mesh.vertices) > 0:
        return np.asarray(tri_mesh.vertices, dtype=np.float64), None, tri_mesh

    return None, None, None
|
|
|
|
|
|
def save_ply(ply_path: Path, points: np.ndarray, colors: Optional[np.ndarray] = None, mesh: Optional[o3d.geometry.TriangleMesh] = None):
    """Write ``points`` to ``ply_path`` as a triangle mesh or a point cloud.

    Args:
        ply_path: Destination file.
        points: Positions to store.
        colors: Optional per-point colours; only used on the point-cloud path.
        mesh: When given, its vertices are replaced by ``points`` (the passed
            object is modified in place) and the mesh is written instead of a
            point cloud.
    """
    target = str(ply_path)

    if mesh is None:
        # Point-cloud path: build a fresh cloud from the arrays.
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(points)
        if colors is not None:
            cloud.colors = o3d.utility.Vector3dVector(colors)
        o3d.io.write_point_cloud(target, cloud)
        return

    # Mesh path: reuse the caller's mesh so faces are kept alongside the new vertices.
    mesh.vertices = o3d.utility.Vector3dVector(points)
    o3d.io.write_triangle_mesh(target, mesh)
|
|
|
|
|
|
def rescale_ply_folder(
    cloud_dir: Path,
    actual_length_mm: float,
    dry_run: bool = False,
    remove_bad: bool = True,
    verbose: bool = True,
    *,
    min_scale: float = 1.0,
    max_scale: float = 1.2,
) -> tuple[int, int, int]:
    """Filter the PLYs in ``cloud_dir`` by scale = actual_length_mm / pc_length.

    Despite its name this function never rescales or rewrites a file; it only
    keeps or deletes PLYs.

    - scale < min_scale or scale > max_scale: Remove (delete) PLY
    - min_scale <= scale <= max_scale: Leave as-is

    Args:
        cloud_dir: Folder whose ``*.ply`` files are examined (non-recursive).
        actual_length_mm: Ground-truth length used as the numerator.
        dry_run: When true, no file is deleted.
        remove_bad: When false, out-of-range files are reported but not deleted.
        verbose: Print one line per file.
        min_scale: Lower bound of the accepted scale range (default 1.0).
        max_scale: Upper bound of the accepted scale range (default 1.2).

    Returns:
        ``(kept_count, removed_count, skipped_count)``.  ``removed_count``
        counts every out-of-range file, including those that were only
        flagged because of ``dry_run`` or ``remove_bad=False``.
        ``skipped_count`` covers files that are empty, have no usable PCA
        length, or could not be deleted.

    Raises:
        ValueError: If ``min_scale`` is greater than ``max_scale``.
    """
    if min_scale > max_scale:
        raise ValueError(
            f"min_scale ({min_scale}) must not exceed max_scale ({max_scale})"
        )

    kept = 0
    removed = 0
    skipped = 0

    for ply_path in sorted(cloud_dir.glob("*.ply")):
        # Only the positions are needed here; colours and mesh are ignored.
        points, _colors, _mesh = load_ply_points(ply_path)
        if points is None:
            skipped += 1
            if verbose:
                print(f" Skip (empty): {ply_path.name}")
            continue

        pc_length = estimate_pointcloud_length_pca(points)
        if not np.isfinite(pc_length) or pc_length <= 0:
            skipped += 1
            if verbose:
                print(f" Skip (invalid length): {ply_path.name}")
            continue

        scale = actual_length_mm / pc_length

        # Remove bad PLYs: scale outside [min_scale, max_scale]
        if scale < min_scale or scale > max_scale:
            if remove_bad and not dry_run:
                try:
                    ply_path.unlink()
                except OSError as exc:
                    # One undeletable file must not abort the whole batch.
                    skipped += 1
                    print(f" Skip (delete failed: {exc}): {ply_path.name}")
                    continue
            removed += 1
            if verbose:
                action = "Remove" if remove_bad else "Skip (would remove)"
                print(f" {action} (bad scale={scale:.3f}): {ply_path.name}")
            continue

        # In range: leave as-is
        kept += 1
        if verbose:
            print(f" Keep (scale={scale:.3f}): {ply_path.name}")

    return kept, removed, skipped
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Filter PLY point clouds by comparing their PCA length with the "
            "ground-truth length from label.csv"
        )
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default=DEFAULT_DATA_ROOT,
        help=f"Data root with subfolders (default: {DEFAULT_DATA_ROOT})",
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--max-length-mm",
        type=float,
        default=DEFAULT_MAX_LENGTH_MM,
        help=f"Exclude samples with length > this (mm). Default: {DEFAULT_MAX_LENGTH_MM}",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would be done without modifying files",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Less verbose output",
    )
    parser.add_argument(
        "--no-remove",
        action="store_true",
        help="Do not delete bad PLYs (scale<1 or >1.2); only report them",
    )
    return parser


def _labelled_sample_ids(label_csv: Path) -> set[str]:
    """Return every non-placeholder sample id in the CSV, valid length or not."""
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        return {
            row[1].strip()
            for row in csv.reader(f)
            if len(row) >= 3 and row[1].strip() and row[1].strip().lower() != "xxx"
        }


def main() -> None:
    """Command-line entry point.

    Loads the ground-truth lengths from the label CSV, then, for every
    labelled sample that has a ``<data-root>/<sample_id>/cloud`` folder,
    filters the PLYs in that folder and finally prints a summary.

    Raises:
        SystemExit: With an error message (printed to stderr, exit status 1)
            when open3d is missing, the data root is not a directory, or the
            label CSV is not a file.  argparse also raises it for invalid
            command-line arguments.
    """
    args = _build_arg_parser().parse_args()

    # Fatal problems terminate with a non-zero status so that shell pipelines
    # and schedulers can detect the failure.
    if not O3D_AVAILABLE:
        raise SystemExit("ERROR: open3d is required. pip install open3d")

    data_root = Path(args.data_root).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()

    if not data_root.is_dir():
        raise SystemExit(f"ERROR: Data root not found: {data_root}")
    if not label_csv.is_file():
        raise SystemExit(f"ERROR: Label CSV not found: {label_csv}")

    print(f"Loading lengths from: {label_csv}")
    lengths = load_lengths_from_csv(label_csv, max_length_mm=args.max_length_mm)
    print(f" Found {len(lengths)} samples with valid length (<= {args.max_length_mm}mm)")

    # Count excluded: ids present in the CSV that ended up without a valid length.
    excluded = _labelled_sample_ids(label_csv) - set(lengths)
    if excluded:
        print(f" Excluded {len(excluded)} sample(s) with length > {args.max_length_mm}mm or missing length")

    if args.dry_run:
        print("\n[DRY RUN - no files will be modified]\n")

    verbose = not args.quiet
    total_kept = 0
    total_removed = 0
    total_skipped = 0
    folders_processed = 0
    folders_skipped = 0

    for sample_id in sorted(lengths):
        cloud_dir = data_root / sample_id / "cloud"
        if not cloud_dir.is_dir():
            folders_skipped += 1
            continue

        actual_length_mm = lengths[sample_id]
        if verbose:
            print(f"\n[{sample_id}] actual_length={actual_length_mm:.1f}mm")

        kept, removed, skipped = rescale_ply_folder(
            cloud_dir,
            actual_length_mm,
            dry_run=args.dry_run,
            remove_bad=not args.no_remove,
            verbose=verbose,
        )
        total_kept += kept
        total_removed += removed
        total_skipped += skipped
        folders_processed += 1

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f" Folders processed: {folders_processed}")
    print(f" Folders skipped (no cloud): {folders_skipped}")
    print(f" PLYs kept (1.0<=scale<=1.2): {total_kept}")
    print(f" PLYs removed (scale<1 or >1.2): {total_removed}")
    print(f" PLYs skipped (empty/invalid): {total_skipped}")
    if args.dry_run:
        print(" [DRY RUN - no files were modified]")
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|