Files
FishServer/FishMeasure/utils/rescale_ply_by_label_length.py
2026-04-08 19:32:23 +08:00

288 lines
8.9 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Filter PLY point clouds based on ground-truth length from label.csv.
Scale = actual_length_mm / pc_length_mm (PCA length of point cloud).
Actions:
- scale < 1 or scale > 1.2: Remove (delete) PLY - bad quality
- 1.0 <= scale <= 1.2: Leave as-is (no rescale)
Filters out samples where actual_length > max_length_mm (default 380mm) as invalid.
"""
from __future__ import annotations
import argparse
import csv
from pathlib import Path
from typing import Dict, Optional
import numpy as np
try:
import open3d as o3d
O3D_AVAILABLE = True
except ImportError:
O3D_AVAILABLE = False
# Default data root: main() looks for <root>/<sample_id>/cloud/*.ply under it.
DEFAULT_DATA_ROOT: str = "/home/ubuntu/data/fish/2025-11-19-output"
# Default label file: sample_id in column B (index 1), length in cm in column C (index 2).
DEFAULT_LABEL_CSV: str = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"
# Labelled lengths above this value (mm) are treated as invalid and excluded.
DEFAULT_MAX_LENGTH_MM: float = 380.0
def estimate_pointcloud_length_pca(points: np.ndarray) -> float:
    """Return the extent of *points* along their first principal (PCA) axis.

    The points are centred. The principal direction is the eigenvector of the
    3x3 scatter matrix with the largest eigenvalue, which equals the first
    right-singular vector of the centred points. The length is
    ``max(proj) - min(proj)`` of the projections onto that axis, in the same
    units as the input coordinates.

    Only the first three columns are used; extra columns are ignored. Rows
    containing NaN/inf in those columns are dropped first, because a single
    non-finite point would make the mean, and therefore the result, NaN.

    Args:
        points: Array-like of shape (N, >=3).

    Returns:
        The PCA extent as a float, or NaN in any of these cases:
        - the input is None or not 2-D;
        - it has fewer than 3 columns;
        - fewer than 3 finite rows remain;
        - the eigen-decomposition fails.
    """
    if points is None:
        return float("nan")
    arr = np.asarray(points)
    if arr.ndim != 2 or arr.shape[0] < 3 or arr.shape[1] < 3:
        return float("nan")
    # float64 avoids precision loss on large coordinates before centring.
    pts = arr[:, :3].astype(np.float64, copy=False)
    pts = pts[np.isfinite(pts).all(axis=1)]
    if pts.shape[0] < 3:
        return float("nan")
    pts = pts - pts.mean(axis=0, keepdims=True)
    try:
        # A 3x3 eigen-decomposition yields the same principal axis as an SVD
        # of the (N, 3) matrix. It avoids allocating the N x 3 U factor.
        _eigvals, eigvecs = np.linalg.eigh(pts.T @ pts)
    except np.linalg.LinAlgError:
        return float("nan")
    axis = eigvecs[:, -1]  # eigh returns eigenvalues in ascending order
    proj = pts @ axis
    return float(proj.max() - proj.min())
def load_lengths_from_csv(
    label_csv: Path,
    max_length_mm: float = DEFAULT_MAX_LENGTH_MM,
) -> Dict[str, float]:
    """Load a ``sample_id -> length_mm`` mapping from *label_csv*.

    Column layout (0-based):
      - index 1 (column B): sample_id. Rows with an empty id, or the
        placeholder ``xxx`` (case-insensitive), are ignored.
      - index 2 (column C): length in cm, converted to mm (x10).

    A row is skipped if any of the following holds:
      - it has fewer than 3 columns;
      - its length cannot be parsed (e.g. a header row);
      - its length is non-finite or <= 0;
      - ``length_mm > max_length_mm``.

    When a sample_id appears on several valid rows, the mean of its valid
    lengths is returned.

    Args:
        label_csv: Path to the label CSV file.
        max_length_mm: Inclusive upper bound for a valid length, in mm.

    Returns:
        Dict mapping sample_id to its (mean) length in mm.
    """
    raw: Dict[str, list] = {}
    # The csv module requires newline="" for correct handling of quoted fields.
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            sample_id = (row[1] or "").strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            try:
                length_cm = float((row[2] or "").strip())
            except (ValueError, TypeError):
                continue
            length_mm = length_cm * 10.0
            # float() accepts "nan", "inf" and negative values.
            # - NaN compares False against max_length_mm and would slip
            #   through; a NaN scale later keeps every PLY.
            # - A length <= 0 would make every scale <= 0, flagging every
            #   PLY of the sample for deletion.
            if not np.isfinite(length_mm) or length_mm <= 0:
                continue
            if length_mm > max_length_mm:
                continue  # Exceeds the validity cap.
            raw.setdefault(sample_id, []).append(length_mm)
    # Every list in raw has at least one entry, so no emptiness check is needed.
    return {sample_id: float(np.mean(lengths)) for sample_id, lengths in raw.items()}
def load_ply_points(ply_path: Path):
    """Read vertex data from a PLY file.

    The file is first read as a point cloud. If that yields no points, it is
    read again as a triangle mesh.

    Returns:
        A 3-tuple ``(points, colors, mesh)``:
        - point-cloud read succeeded: ``(float64 points, float64 colours or
          None, None)``;
        - mesh fallback succeeded: ``(float64 vertices, None, mesh)``;
        - neither yielded data: ``(None, None, None)``.
    """
    path_str = str(ply_path)

    cloud = o3d.io.read_point_cloud(path_str)
    if cloud is not None and len(cloud.points) > 0:
        xyz = np.asarray(cloud.points, dtype=np.float64)
        rgb = None
        if cloud.has_colors():
            rgb = np.asarray(cloud.colors, dtype=np.float64)
        return xyz, rgb, None

    # Point-cloud read produced nothing: try the file as a mesh.
    tri = o3d.io.read_triangle_mesh(path_str)
    if tri is None or len(tri.vertices) == 0:
        return None, None, None
    return np.asarray(tri.vertices, dtype=np.float64), None, tri
def save_ply(ply_path: Path, points: np.ndarray, colors: Optional[np.ndarray] = None, mesh: Optional[o3d.geometry.TriangleMesh] = None):
    """Write *points* to *ply_path* as a PLY file.

    If *mesh* is given, its vertices are replaced by *points* and the mesh is
    written. The caller's mesh object is mutated, and *colors* is ignored.
    Otherwise a new point cloud is built from *points* (plus *colors*, if
    given) and written.

    Args:
        ply_path: Destination path.
        points: (N, 3) coordinates.
        colors: Optional (N, 3) per-point colours, used only for the
            point-cloud path.
        mesh: Optional mesh whose vertices should be replaced by *points*.

    Raises:
        OSError: If open3d reports that the file could not be written.
    """
    # Vector3dVector converts a float64 (N, 3) numpy array.
    xyz = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))
    if mesh is not None:
        mesh.vertices = xyz
        ok = o3d.io.write_triangle_mesh(str(ply_path), mesh)
    else:
        pcd = o3d.geometry.PointCloud()
        pcd.points = xyz
        if colors is not None:
            pcd.colors = o3d.utility.Vector3dVector(np.asarray(colors, dtype=np.float64))
        ok = o3d.io.write_point_cloud(str(ply_path), pcd)
    # open3d reports write failure through a False return value, not an exception.
    if not ok:
        raise OSError(f"Failed to write PLY: {ply_path}")
def rescale_ply_folder(
    cloud_dir: Path,
    actual_length_mm: float,
    dry_run: bool = False,
    remove_bad: bool = True,
    verbose: bool = True,
    min_scale: float = 1.0,
    max_scale: float = 1.2,
) -> tuple[int, int, int]:
    """Keep or delete each ``*.ply`` directly inside *cloud_dir* by its length scale.

    For every PLY, ``scale = actual_length_mm / pc_length``, where
    ``pc_length`` is the extent along the cloud's first PCA axis.

    - ``scale < min_scale`` or ``scale > max_scale``: bad quality. The PLY is
      deleted, unless *dry_run* is set or *remove_bad* is False.
    - ``min_scale <= scale <= max_scale``: the PLY is left as-is.

    Nothing is rescaled; files are only kept or deleted.

    Args:
        cloud_dir: Directory containing the PLY files (not searched recursively).
        actual_length_mm: Ground-truth sample length in mm. If it is not
            finite and positive, no file is touched and every PLY is counted
            as skipped.
        dry_run: Report what would happen without deleting anything.
        remove_bad: If False, bad PLYs are reported and counted but never deleted.
        verbose: Print one line per PLY.
        min_scale: Inclusive lower bound of the accepted scale range.
        max_scale: Inclusive upper bound of the accepted scale range.

    Returns:
        ``(kept_count, removed_count, skipped_count)``.
        - ``removed_count`` counts PLYs classified as bad, including in
          dry-run and no-remove modes.
        - ``skipped_count`` covers PLYs that are empty, have an invalid PCA
          length, or could not be deleted.

    Raises:
        ValueError: If ``min_scale > max_scale``.
    """
    if min_scale > max_scale:
        raise ValueError(f"min_scale ({min_scale}) must be <= max_scale ({max_scale})")
    ply_files = sorted(cloud_dir.glob("*.ply"))
    # A NaN reference length makes every scale NaN; NaN fails both range
    # comparisons, so every PLY would be silently kept. A length <= 0 makes
    # every scale <= 0, so every PLY would be deleted. Refuse to act on either.
    if not np.isfinite(actual_length_mm) or actual_length_mm <= 0:
        print(f" Skip folder (invalid actual length {actual_length_mm}): {cloud_dir}")
        return 0, 0, len(ply_files)

    kept = 0
    removed = 0
    skipped = 0
    for ply_path in ply_files:
        points, _colors, _mesh = load_ply_points(ply_path)
        if points is None:
            skipped += 1
            if verbose:
                print(f" Skip (empty): {ply_path.name}")
            continue
        pc_length = estimate_pointcloud_length_pca(points)
        if not np.isfinite(pc_length) or pc_length <= 0:
            skipped += 1
            if verbose:
                print(f" Skip (invalid length): {ply_path.name}")
            continue
        scale = actual_length_mm / pc_length
        if min_scale <= scale <= max_scale:
            kept += 1
            if verbose:
                print(f" Keep (scale={scale:.3f}): {ply_path.name}")
            continue
        # Out of the accepted range: bad quality.
        if dry_run or not remove_bad:
            action = "Would remove" if remove_bad else "Skip (would remove)"
        else:
            try:
                ply_path.unlink()
            except OSError as err:
                # Report and continue rather than aborting the whole batch.
                skipped += 1
                print(f" Skip (delete failed: {err}): {ply_path.name}")
                continue
            action = "Remove"
        removed += 1
        if verbose:
            print(f" {action} (bad scale={scale:.3f}): {ply_path.name}")
    return kept, removed, skipped
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser; defaults come from the module constants."""
    parser = argparse.ArgumentParser(
        description=(
            "Delete PLY point clouds whose PCA length disagrees with the "
            "ground-truth length from label.csv"
        )
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default=DEFAULT_DATA_ROOT,
        help=f"Data root with subfolders (default: {DEFAULT_DATA_ROOT})",
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--max-length-mm",
        type=float,
        default=DEFAULT_MAX_LENGTH_MM,
        help=f"Exclude samples with length > this (mm). Default: {DEFAULT_MAX_LENGTH_MM}",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would be done without modifying files",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Less verbose output",
    )
    parser.add_argument(
        "--no-remove",
        action="store_true",
        help="Do not delete bad PLYs (scale<1 or >1.2); only report them",
    )
    return parser


def _collect_sample_ids(label_csv: Path) -> set[str]:
    """Return every non-placeholder sample_id in column B of *label_csv*, whether or not its length is valid."""
    ids: set[str] = set()
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            sample_id = row[1].strip()
            if sample_id and sample_id.lower() != "xxx":
                ids.add(sample_id)
    return ids


def main() -> None:
    """Command-line entry point.

    Steps:
    1. Load ground-truth lengths from the label CSV.
    2. For each sample that has ``<data-root>/<sample_id>/cloud``, keep or
       delete its PLYs via ``rescale_ply_folder``.
    3. Print a summary.

    Exits with a non-zero status (via ``SystemExit``) when open3d is missing
    or an input path does not exist.
    """
    args = _build_arg_parser().parse_args()
    if not O3D_AVAILABLE:
        raise SystemExit("ERROR: open3d is required. pip install open3d")
    data_root = Path(args.data_root).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()
    if not data_root.exists():
        raise SystemExit(f"ERROR: Data root not found: {data_root}")
    if not label_csv.exists():
        raise SystemExit(f"ERROR: Label CSV not found: {label_csv}")

    print(f"Loading lengths from: {label_csv}")
    lengths = load_lengths_from_csv(label_csv, max_length_mm=args.max_length_mm)
    print(f" Found {len(lengths)} samples with valid length (<= {args.max_length_mm}mm)")
    excluded = _collect_sample_ids(label_csv) - set(lengths)
    if excluded:
        print(f" Excluded {len(excluded)} sample(s) with length > {args.max_length_mm}mm or missing/invalid length")
    if args.dry_run:
        print("\n[DRY RUN - no files will be modified]\n")

    total_kept = 0
    total_removed = 0
    total_skipped = 0
    folders_processed = 0
    folders_skipped = 0
    for sample_id in sorted(lengths):
        cloud_dir = data_root / sample_id / "cloud"
        if not cloud_dir.is_dir():
            folders_skipped += 1
            continue
        actual_length_mm = lengths[sample_id]
        if not args.quiet:
            print(f"\n[{sample_id}] actual_length={actual_length_mm:.1f}mm")
        kept, removed, skipped = rescale_ply_folder(
            cloud_dir,
            actual_length_mm,
            dry_run=args.dry_run,
            remove_bad=not args.no_remove,
            verbose=not args.quiet,
        )
        total_kept += kept
        total_removed += removed
        total_skipped += skipped
        folders_processed += 1

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f" Folders processed: {folders_processed}")
    print(f" Folders skipped (no cloud): {folders_skipped}")
    print(f" PLYs kept (1.0<=scale<=1.2): {total_kept}")
    print(f" PLYs removed (scale<1 or >1.2): {total_removed}")
    print(f" PLYs skipped (empty/invalid): {total_skipped}")
    if args.dry_run:
        print(" [DRY RUN - no files were modified]")


if __name__ == "__main__":
    main()