Files
FishServer/FishMeasure/utils/keep_largest_cluster.py

379 lines
15 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Keep the largest cluster from a point cloud using DBSCAN clustering.
This script processes PLY files to identify clusters and keeps only the largest cluster,
removing smaller clusters and outliers.
"""
import argparse
import numpy as np
from pathlib import Path
import json
from typing import Tuple, Optional
# open3d is a hard requirement: every function in this module builds Open3D
# point clouds, so the module refuses to load without it.
try:
    import open3d as o3d
    O3D_AVAILABLE = True
except ImportError:
    O3D_AVAILABLE = False
    print("ERROR: open3d is required. Please install it: pip install open3d")
    # `exit` is a helper injected by the `site` module and is not defined under
    # `python -S` or in frozen/embedded interpreters; raising SystemExit is
    # always available and yields the same exit status (1).
    raise SystemExit(1)
def keep_largest_cluster(points: np.ndarray, eps: float = 10.0, min_points: int = 10) -> Tuple[np.ndarray, dict]:
    """
    Keep only the largest cluster from a point cloud using DBSCAN.

    Args:
        points: Point cloud array (N, 3) in mm
        eps: DBSCAN eps parameter (maximum distance between points in the same cluster, in mm)
        min_points: DBSCAN min_points parameter (minimum number of points to form a cluster)

    Returns:
        tuple: (filtered_points: np.ndarray, info: dict)
            filtered_points: Points belonging to the largest cluster. The input is
                returned unchanged when it is empty or has fewer than ``min_points``
                points; an empty slice of the input is returned when DBSCAN labels
                every point as noise.
            info: Dictionary with clustering statistics. It carries an ``"error"``
                key (empty input / no clusters) or a ``"warning"`` key (too few
                points) in the degenerate cases. On success it additionally holds
                ``"mask"``, a boolean array over the input points selecting the
                kept cluster, so callers can filter per-point attributes such as
                colors. Note that ``"mask"`` is a NumPy array and therefore not
                JSON-serializable.
    """
    num_points = len(points)
    if num_points == 0:
        return points, {"error": "Empty point cloud", "num_clusters": 0, "largest_cluster_size": 0}
    if num_points < min_points:
        # DBSCAN cannot form a cluster with fewer than min_points points;
        # hand the cloud back untouched and let the caller decide.
        return points, {"warning": f"Not enough points for clustering (need at least {min_points})",
                        "num_clusters": 0, "largest_cluster_size": num_points}

    # Convert to an Open3D point cloud and run DBSCAN; label -1 marks noise.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.astype(np.float64))
    labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_points, print_progress=False))

    # Cluster ids (sorted ascending, noise excluded) and their sizes in one pass.
    cluster_labels, cluster_counts = np.unique(labels[labels >= 0], return_counts=True)
    if cluster_labels.size == 0:
        # All points are noise. Slice the input so the empty result keeps the
        # input's dtype and trailing dimensions instead of collapsing to shape (0,).
        return points[:0], {"error": "No clusters found (all points are noise)",
                            "num_clusters": 0, "largest_cluster_size": 0, "noise_points": num_points}

    # argmax returns the first maximum, so ties go to the lowest cluster label.
    largest_idx = int(np.argmax(cluster_counts))
    largest_cluster_label = int(cluster_labels[largest_idx])
    largest_cluster_size = int(cluster_counts[largest_idx])

    # Keep only points from the largest cluster
    mask = labels == largest_cluster_label
    filtered_points = points[mask]

    info = {
        "num_clusters": len(cluster_labels),
        "largest_cluster_label": largest_cluster_label,
        "largest_cluster_size": largest_cluster_size,
        "original_points": num_points,
        "filtered_points": len(filtered_points),
        "noise_points": int(np.count_nonzero(labels == -1)),
        "removed_points": num_points - len(filtered_points),
        "cluster_sizes": {int(label): int(count) for label, count in zip(cluster_labels, cluster_counts)},
        "mask": mask,  # Return mask for filtering colors
    }
    return filtered_points, info
def keep_largest_cluster_with_colors(points: np.ndarray, colors: np.ndarray,
                                     eps: float = 10.0, min_points: int = 10) -> Tuple[np.ndarray, np.ndarray, dict]:
    """
    Keep only the largest cluster from a point cloud using DBSCAN, preserving colors.

    This is a convenience wrapper around keep_largest_cluster() that filters both
    points and colors with the same per-point mask.

    Args:
        points: Point cloud array (N, 3) in mm
        colors: Color array (N, 3) in range [0, 255] or [0, 1]; may be None or empty
        eps: DBSCAN eps parameter (maximum distance between points in the same cluster, in mm)
        min_points: DBSCAN min_points parameter (minimum number of points to form a cluster)

    Returns:
        tuple: (filtered_points: np.ndarray, filtered_colors: np.ndarray, info: dict)
            filtered_points: Points belonging to the largest cluster
            filtered_colors: Colors corresponding to filtered points (an empty
                array when no colors were supplied)
            info: Dictionary with clustering statistics. When colors are supplied
                the per-point ``"mask"`` entry is removed; when they are not, the
                dictionary is passed through from keep_largest_cluster() as-is.

    Raises:
        ValueError: If colors are supplied but their length differs from points.
    """
    if colors is None or len(colors) == 0:
        # Nothing to keep in sync: plain point filtering.
        filtered_points, info = keep_largest_cluster(points, eps=eps, min_points=min_points)
        return filtered_points, np.array([]), info

    if len(points) != len(colors):
        raise ValueError(f"Points and colors must have same length: {len(points)} != {len(colors)}")

    # Colors are non-empty and the lengths match, so points is non-empty here.
    if len(points) < min_points:
        # Too few points to cluster: return both arrays untouched.
        return points, colors, {"warning": f"Not enough points for clustering (need at least {min_points})",
                                "num_clusters": 0, "largest_cluster_size": len(points)}

    # Single source of truth for the clustering logic.
    filtered_points, info = keep_largest_cluster(points, eps=eps, min_points=min_points)
    mask = info.pop("mask", None)
    if mask is None:
        # No cluster was found (info carries the "error"). Return empty slices
        # so both outputs keep the dtype and trailing dimensions of the inputs.
        return points[:0], colors[:0], info
    return filtered_points, colors[mask], info
def process_ply_file(input_ply: Path, output_ply: Optional[Path] = None,
                     eps: float = 10.0, min_points: int = 10,
                     preserve_colors: bool = True) -> dict:
    """
    Process a single PLY file to keep only the largest cluster.

    Args:
        input_ply: Input PLY file path
        output_ply: Output PLY file path (if None, saves next to input_ply with a
            _largest_cluster suffix)
        eps: DBSCAN eps parameter (mm)
        min_points: DBSCAN min_points parameter
        preserve_colors: If True, preserve colors from input point cloud

    Returns:
        dict: Processing results and statistics. Contains an ``"error"`` key on
        failure, otherwise ``"success": True`` plus the clustering statistics.
        The result never contains the per-point ``"mask"`` array, so it can be
        passed to ``json.dump`` directly.
    """
    if not input_ply.exists():
        return {"error": f"Input file not found: {input_ply}"}
    try:
        # Load point cloud
        pcd = o3d.io.read_point_cloud(str(input_ply))
        if len(pcd.points) == 0:
            return {"error": "Empty point cloud loaded"}
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None

        # Keep largest cluster
        filtered_points, cluster_info = keep_largest_cluster(points, eps=eps, min_points=min_points)
        # The mask is a per-point boolean array: needed below to filter colors,
        # but it must not leak into the returned statistics (not JSON-serializable).
        mask = cluster_info.pop("mask", None)
        if "error" in cluster_info:
            return cluster_info
        if len(filtered_points) == 0:
            return {"error": "No points remaining after clustering", **cluster_info}

        # Create filtered point cloud
        filtered_pcd = o3d.geometry.PointCloud()
        filtered_pcd.points = o3d.utility.Vector3dVector(filtered_points.astype(np.float64))

        # Preserve colors if available
        if preserve_colors and colors is not None:
            if mask is not None:
                # Reuse the mask from the clustering pass instead of re-running DBSCAN.
                filtered_colors = colors[mask]
            elif len(filtered_points) == len(points):
                # No mask means clustering was skipped (too few points) and every
                # point was kept, so every color is kept as well.
                filtered_colors = colors
            else:
                filtered_colors = None
            if filtered_colors is not None and len(filtered_colors) > 0:
                filtered_pcd.colors = o3d.utility.Vector3dVector(filtered_colors.astype(np.float64))

        # Determine output path
        if output_ply is None:
            output_ply = input_ply.parent / f"{input_ply.stem}_largest_cluster.ply"

        # Save filtered point cloud
        success = o3d.io.write_point_cloud(str(output_ply), filtered_pcd)
        if not success:
            return {"error": f"Failed to save output file: {output_ply}"}

        return {
            "success": True,
            "input_file": str(input_ply),
            "output_file": str(output_ply),
            "original_points": len(points),
            "filtered_points": len(filtered_points),
            **cluster_info,
        }
    except Exception as e:
        # Deliberate catch-all: one unreadable or corrupt file must not abort a
        # whole folder run; the failure is reported in the per-file result.
        return {"error": f"Error processing {input_ply}: {str(e)}"}
def main():
    """
    Command-line entry point.

    Parses the arguments, runs process_ply_file() on a single PLY file (--ply) or
    on every PLY file directly inside a folder (--folder), prints a per-file
    report and, if --results-json is given, writes all per-file results to that
    JSON file.
    """
    parser = argparse.ArgumentParser(
        description="Keep the largest cluster from point cloud PLY files using DBSCAN clustering"
    )
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--ply", type=str, help="Path to a single PLY file")
    input_group.add_argument("--folder", type=str, help="Path to a folder containing PLY files")
    parser.add_argument("--output", type=str, default=None,
                        help="Output PLY file path (for single file) or output folder (for folder mode). "
                             "If not specified, saves to input location with _largest_cluster suffix")
    # %(default)s lets argparse print the real default, so the help text cannot
    # drift away from the value actually used.
    parser.add_argument("--eps", type=float, default=10.0,
                        help="DBSCAN eps parameter: maximum distance between points in the same cluster (mm). "
                             "Default: %(default)s")
    parser.add_argument("--min-points", type=int, default=300,
                        help="DBSCAN min_points parameter: minimum number of points to form a cluster. "
                             "Default: %(default)s")
    parser.add_argument("--no-colors", action="store_true",
                        help="Do not preserve colors from input point cloud")
    parser.add_argument("--results-json", type=str, default=None,
                        help="Path to save processing results as JSON (optional)")
    args = parser.parse_args()

    results = {}
    if args.ply:
        # Single file mode
        input_ply = Path(args.ply).expanduser().resolve()
        output_ply = Path(args.output).expanduser().resolve() if args.output else None
        print(f"Processing: {input_ply.name}")
        result = process_ply_file(
            input_ply=input_ply,
            output_ply=output_ply,
            eps=args.eps,
            min_points=args.min_points,
            preserve_colors=not args.no_colors
        )
        results[str(input_ply)] = result
        if "error" in result:
            print(f" ERROR: {result['error']}")
        elif "warning" in result:
            print(f" WARNING: {result['warning']}")
        else:
            print(f" ✓ Original points: {result['original_points']}")
            print(f" ✓ Filtered points: {result['filtered_points']}")
            print(f" ✓ Clusters found: {result['num_clusters']}")
            print(f" ✓ Largest cluster size: {result['largest_cluster_size']}")
            print(f" ✓ Removed points: {result['removed_points']}")
            print(f" ✓ Saved to: {result['output_file']}")
    elif args.folder:
        # Folder mode
        input_folder = Path(args.folder).expanduser().resolve()
        if not input_folder.is_dir():
            print(f"ERROR: Folder not found: {input_folder}")
            return

        # Find all PLY files. Matching on the lower-cased suffix picks up any
        # capitalisation exactly once; globbing "*.ply" and "*.PLY" separately
        # lists every file twice on case-insensitive filesystems.
        ply_files = sorted(
            entry for entry in input_folder.iterdir()
            if entry.is_file() and entry.suffix.lower() == ".ply"
        )
        if not ply_files:
            print(f"ERROR: No PLY files found in {input_folder}")
            return
        print(f"Found {len(ply_files)} PLY file(s) in {input_folder}")
        print("="*60)

        # Determine output folder
        if args.output:
            output_folder = Path(args.output).expanduser().resolve()
            output_folder.mkdir(parents=True, exist_ok=True)
        else:
            output_folder = input_folder

        # Process each file
        for idx, ply_file in enumerate(ply_files):
            print(f"\n[{idx + 1}/{len(ply_files)}] Processing: {ply_file.name}")
            if args.output:
                output_ply = output_folder / f"{ply_file.stem}_largest_cluster.ply"
            else:
                output_ply = None  # Will use default suffix
            result = process_ply_file(
                input_ply=ply_file,
                output_ply=output_ply,
                eps=args.eps,
                min_points=args.min_points,
                preserve_colors=not args.no_colors
            )
            results[str(ply_file)] = result
            if "error" in result:
                print(f" ERROR: {result['error']}")
            elif "warning" in result:
                print(f" WARNING: {result['warning']}")
            else:
                print(f"{result['original_points']} -> {result['filtered_points']} points "
                      f"({result['num_clusters']} clusters, kept largest)")

        succeeded = sum(1 for r in results.values() if r.get("success"))
        failed = sum(1 for r in results.values() if "error" in r)
        print(f"\n{'='*60}")
        print("Processing complete!")
        print(f"Processed: {succeeded} files")
        print(f"Errors: {failed} files")
        if args.output:
            print(f"Output folder: {output_folder}")

    # Save results to JSON if requested
    if args.results_json:
        results_path = Path(args.results_json).expanduser().resolve()
        # A per-point "mask" array, if present in a result, is neither
        # JSON-serializable nor useful in a summary report, so it is left out.
        serializable = {
            path: {key: value for key, value in result.items() if key != "mask"}
            for path, result in results.items()
        }
        with open(results_path, 'w', encoding="utf-8") as f:
            json.dump(serializable, f, indent=2)
        print(f"\nResults saved to: {results_path}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()