Files
FishServer/FishMeasure/utils/correct_tail_rotation.py

442 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Correct fish tail rotation using dual RANSAC plane fitting.
This script detects if a fish tail is rotated relative to the fish body by:
1. Using first RANSAC to fit the main fish body plane
2. Using second RANSAC to fit the remaining points (potential rotated tail)
3. If second RANSAC covers most remaining points, the tail is rotated
4. Rotating the tail points back to align with the body plane
This correction should be applied before template matching as it affects height/length ratio.
"""
import numpy as np
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import open3d as o3d
def fit_plane_ransac(points: np.ndarray,
                     distance_threshold: float = 5.0,
                     ransac_n: int = 3,
                     num_iterations: int = 1000) -> Tuple[Optional[np.ndarray], np.ndarray, Dict[str, Any]]:
    """
    Fit a single plane to a point cloud with Open3D's RANSAC segmentation.

    Args:
        points: Point cloud array (N, 3)
        distance_threshold: Maximum distance from point to plane to be considered inlier (mm)
        ransac_n: Number of points to sample for plane fitting
        num_iterations: Number of RANSAC iterations

    Returns:
        tuple: (plane_equation, inlier_indices, info_dict)
            - plane_equation: [a, b, c, d] where ax + by + cz + d = 0, or None when
              there are too few points to fit a plane
            - inlier_indices: Integer array of indices of inlier points (empty on error)
            - info_dict: Unit plane normal, signed offset ``d / |(a, b, c)|``, inlier
              counts/ratio; on failure it only contains an ``"error"`` message.
    """
    # Integer dtype so the result is always usable as an index array, even when empty.
    no_inliers = np.array([], dtype=int)
    if len(points) < 3:
        return None, no_inliers, {"error": "Not enough points (need at least 3)"}
    if len(points) < ransac_n:
        return None, no_inliers, {"error": f"Not enough points for RANSAC (need at least {ransac_n})"}

    # Open3D expects a contiguous float64 (N, 3) buffer.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(points, dtype=np.float64))

    plane_model, inliers = pcd.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations
    )

    # Plane equation: ax + by + cz + d = 0. Normalize so that normal is unit length
    # and plane_distance is the matching signed offset (normal . x + plane_distance = 0).
    a, b, c, d = plane_model
    normal = np.array([a, b, c], dtype=np.float64)
    normal_norm = float(np.linalg.norm(normal))
    if normal_norm > 0:
        normal = normal / normal_norm
        plane_distance = d / normal_norm
    else:
        # Degenerate model; fall back to a z-up normal through the origin.
        normal = np.array([0.0, 0.0, 1.0])
        plane_distance = 0.0

    inlier_indices = np.asarray(inliers, dtype=int)
    num_total = len(points)
    num_inliers = len(inlier_indices)

    info = {
        "plane_equation": plane_model,
        "plane_normal": normal,
        "plane_distance": float(plane_distance),
        "num_inliers": num_inliers,
        "num_total": num_total,
        "num_outliers": num_total - num_inliers,
        "inlier_ratio": num_inliers / num_total if num_total > 0 else 0.0,
        "distance_threshold": distance_threshold
    }
    return plane_model, inlier_indices, info
def calculate_rotation_between_planes(normal1: np.ndarray, normal2: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Calculate the smallest rotation that aligns plane1 with plane2.

    Plane normals have no meaningful sign (RANSAC may return either orientation),
    so ``normal2`` is first flipped into the same hemisphere as ``normal1``. The
    returned angle is therefore the acute angle between the planes, in [0, 90] degrees.

    Args:
        normal1: Normal vector of first plane (normalized internally)
        normal2: Normal vector of second plane (normalized internally)

    Returns:
        tuple: (rotation_matrix, angle_degrees)
            - rotation_matrix: 3x3 rotation matrix mapping normal1 onto the
              (hemisphere-aligned) normal2; identity when the planes are parallel
            - angle_degrees: Rotation angle in degrees
    """
    normal1 = np.asarray(normal1, dtype=np.float64)
    normal2 = np.asarray(normal2, dtype=np.float64)
    normal1 = normal1 / np.linalg.norm(normal1)
    normal2 = normal2 / np.linalg.norm(normal2)

    # Orient normal2 towards normal1: n and -n describe the same plane, so without
    # this a slightly tilted plane with an opposite-facing normal would read as ~180°.
    cos_angle = float(np.dot(normal1, normal2))
    if cos_angle < 0.0:
        normal2 = -normal2
        cos_angle = -cos_angle

    axis = np.cross(normal1, normal2)
    axis_norm = float(np.linalg.norm(axis))
    if axis_norm < 1e-6:
        # Planes are parallel, no rotation needed
        return np.eye(3), 0.0
    axis = axis / axis_norm

    # arctan2(sin, cos) is better conditioned than arccos for small angles.
    angle = float(np.arctan2(axis_norm, cos_angle))
    angle_degrees = float(np.degrees(angle))

    # Rodrigues' rotation formula: R = I + sin(t) K + (1 - cos(t)) K^2
    K = np.array([
        [0.0, -axis[2], axis[1]],
        [axis[2], 0.0, -axis[0]],
        [-axis[1], axis[0], 0.0]
    ])
    rotation_matrix = np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)
    return rotation_matrix, angle_degrees
def _plane_intersection_pivot(body_normal: np.ndarray,
                              body_offset: float,
                              tail_normal: np.ndarray,
                              tail_offset: float,
                              reference: np.ndarray) -> Optional[np.ndarray]:
    """
    Return the point on the intersection line of two planes closest to ``reference``.

    Planes are given as ``normal . x + offset = 0`` with unit normals. Returns None
    when the planes are (numerically) parallel and have no unique intersection line.
    """
    k = float(np.dot(body_normal, tail_normal))
    det = 1.0 - k * k
    if det < 1e-12:
        return None
    # Signed distances of the reference point to each plane.
    e_body = float(np.dot(body_normal, reference) + body_offset)
    e_tail = float(np.dot(tail_normal, reference) + tail_offset)
    # Solve for p = reference + alpha*n_body + beta*n_tail lying on both planes;
    # the displacement is orthogonal to the line, so p is the closest line point.
    alpha = (-e_body + k * e_tail) / det
    beta = (-e_tail + k * e_body) / det
    return reference + alpha * body_normal + beta * tail_normal


def correct_tail_rotation(points: np.ndarray,
                          colors: Optional[np.ndarray] = None,
                          distance_threshold: float = 5.0,
                          ransac_n: int = 3,
                          num_iterations: int = 1000,
                          min_tail_ratio: float = 0.7,
                          min_angle_threshold: float = 5.0) -> Tuple[np.ndarray, Optional[np.ndarray], Dict[str, Any]]:
    """
    Detect and correct fish tail rotation using dual RANSAC.

    A first RANSAC fits the body plane; a second RANSAC fits the body outliers. If
    enough outliers lie on that second plane and the (acute) angle between the planes
    is large enough, the second-plane points are rotated about the line where the two
    planes meet so that they land on the body plane.

    Args:
        points: Point cloud array (N, 3)
        colors: Color array (N, 3), optional; returned unchanged
        distance_threshold: Maximum distance from point to plane to be considered inlier (mm)
        ransac_n: Number of points to sample for plane fitting
        num_iterations: Number of RANSAC iterations
        min_tail_ratio: Minimum ratio of outliers that must be on second plane to consider tail rotated (default: 0.7)
        min_angle_threshold: Minimum angle (degrees) between planes to apply correction (default: 5.0)

    Returns:
        tuple: (corrected_points, corrected_colors, info_dict)
            - corrected_points: Point cloud with tail rotated back to body plane (if rotation
              detected); otherwise the input points object itself
            - corrected_colors: Colors (same as input if provided)
            - info_dict: Detection results, plane fits, rotation angle, and, when a
              correction is applied, ``num_tail_points_rotated`` and ``rotation_pivot``.
              An ``"error"`` key is present when a fitting step could not run.
    """
    points = np.asarray(points)
    if len(points) < 6:  # Need at least 6 points for two planes
        return points, colors, {"error": "Not enough points (need at least 6)", "rotation_detected": False}

    info = {
        "rotation_detected": False,
        "rotation_applied": False,
        "body_plane": None,
        "tail_plane": None,
        "rotation_angle_degrees": 0.0,
        "body_inlier_ratio": 0.0,
        "tail_inlier_ratio": 0.0,
        "total_points": len(points),
        "body_points": 0,
        "tail_points": 0
    }

    # Step 1: First RANSAC - fit main fish body plane
    body_plane, body_inliers, body_info = fit_plane_ransac(
        points, distance_threshold, ransac_n, num_iterations
    )
    if body_plane is None or "error" in body_info:
        return points, colors, {**info, "error": body_info.get("error", "First RANSAC failed")}

    body_inlier_ratio = body_info["inlier_ratio"]
    body_outliers = np.setdiff1d(np.arange(len(points)), body_inliers)
    info["body_plane"] = body_info
    info["body_inlier_ratio"] = body_inlier_ratio
    info["body_points"] = len(body_inliers)

    # If most points are on the body plane, no rotation detected
    if body_inlier_ratio >= 0.95:
        return points, colors, info

    # Step 2: Second RANSAC - fit remaining points (potential rotated tail)
    if len(body_outliers) < ransac_n:
        info["error"] = f"Not enough outlier points for second RANSAC ({len(body_outliers)} < {ransac_n})"
        return points, colors, info

    tail_plane, tail_inliers_local, tail_info = fit_plane_ransac(
        points[body_outliers], distance_threshold, ransac_n, num_iterations
    )
    if tail_plane is None or "error" in tail_info:
        info["error"] = tail_info.get("error", "Second RANSAC failed")
        return points, colors, info

    tail_inliers_local = np.asarray(tail_inliers_local, dtype=int)
    tail_inlier_count = len(tail_inliers_local)
    info["tail_plane"] = tail_info
    info["tail_inlier_ratio"] = tail_info["inlier_ratio"]
    info["tail_points"] = tail_inlier_count

    # Step 3: Check if tail is rotated
    # Condition 1: Most outlier points should be on the second plane
    outlier_ratio_on_tail_plane = tail_inlier_count / len(body_outliers) if len(body_outliers) > 0 else 0.0

    # Condition 2: There should be a significant angle between the two planes.
    body_normal = np.asarray(body_info["plane_normal"], dtype=np.float64)
    body_offset = float(body_info["plane_distance"])
    tail_normal = np.asarray(tail_info["plane_normal"], dtype=np.float64)
    tail_offset = float(tail_info["plane_distance"])
    # RANSAC normals have arbitrary sign: orient the tail plane like the body plane so
    # the measured angle is the acute angle between the planes (never ~180°).
    if float(np.dot(body_normal, tail_normal)) < 0.0:
        tail_normal = -tail_normal
        tail_offset = -tail_offset
    # Rotation that takes the TAIL normal onto the BODY normal, i.e. undoes the tilt.
    rotation_matrix, rotation_angle = calculate_rotation_between_planes(tail_normal, body_normal)
    rotation_angle = float(rotation_angle)
    info["rotation_angle_degrees"] = rotation_angle

    # Plain Python bool so info stays JSON-serializable.
    rotation_detected = bool(
        outlier_ratio_on_tail_plane >= min_tail_ratio and
        rotation_angle >= min_angle_threshold
    )
    info["rotation_detected"] = rotation_detected
    info["outlier_ratio_on_tail_plane"] = outlier_ratio_on_tail_plane
    if not rotation_detected:
        return points, colors, info

    # Step 4: Rotate tail points back onto the body plane
    tail_inlier_indices_global = body_outliers[tail_inliers_local]
    tail_points_to_rotate = points[tail_inlier_indices_global].astype(np.float64)
    tail_centroid = np.mean(tail_points_to_rotate, axis=0)

    # Rotate about the hinge line where the planes meet: the rotation axis is parallel
    # to that line, so points of the tail plane map exactly onto the body plane and
    # the tail stays attached to the body. Rotating about the tail centroid instead
    # would only make the tail parallel to the body plane, offset from it.
    pivot = _plane_intersection_pivot(body_normal, body_offset, tail_normal, tail_offset, tail_centroid)
    if pivot is None:
        pivot = tail_centroid
    tail_points_rotated = (tail_points_to_rotate - pivot) @ rotation_matrix.T + pivot

    # Keep floating inputs' dtype; promote integer inputs so rotated coords aren't truncated.
    out_dtype = points.dtype if np.issubdtype(points.dtype, np.floating) else np.float64
    corrected_points = points.astype(out_dtype, copy=True)
    corrected_points[tail_inlier_indices_global] = tail_points_rotated

    info["rotation_applied"] = True
    info["num_tail_points_rotated"] = len(tail_inlier_indices_global)
    info["rotation_pivot"] = [float(v) for v in pivot]
    return corrected_points, colors, info
def correct_tail_rotation_from_ply(ply_path: str,
                                   output_path: Optional[str] = None,
                                   distance_threshold: float = 5.0,
                                   ransac_n: int = 3,
                                   num_iterations: int = 1000,
                                   min_tail_ratio: float = 0.7,
                                   min_angle_threshold: float = 5.0) -> Tuple[bool, Dict[str, Any]]:
    """
    Correct tail rotation for a point cloud from a PLY file.

    Args:
        ply_path: Path to input PLY file
        output_path: Path to save corrected PLY file (optional). Parent directories are
            created. Only points and colors are written.
        distance_threshold: Maximum distance from point to plane to be considered inlier (mm)
        ransac_n: Number of points to sample for plane fitting
        num_iterations: Number of RANSAC iterations
        min_tail_ratio: Minimum ratio of outliers that must be on second plane to consider tail rotated
        min_angle_threshold: Minimum angle (degrees) between planes to apply correction

    Returns:
        tuple: (success: bool, info_dict: dict). On failure info_dict contains "error".
    """
    ply_path = Path(ply_path).expanduser().resolve()
    if not ply_path.is_file():
        return False, {"error": f"PLY file not found: {ply_path}"}

    # Load point cloud
    pcd = o3d.io.read_point_cloud(str(ply_path))
    if len(pcd.points) == 0:
        return False, {"error": "Empty point cloud"}

    # Open3D stores colors as floats in [0, 1]; the correction API works in 0-255.
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors) * 255.0 if pcd.has_colors() else None

    corrected_points, corrected_colors, info = correct_tail_rotation(
        points, colors, distance_threshold, ransac_n, num_iterations,
        min_tail_ratio, min_angle_threshold
    )
    if "error" in info:
        return False, info

    # Save corrected point cloud if output path is provided
    if output_path:
        output_path = Path(output_path).expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        corrected_pcd = o3d.geometry.PointCloud()
        corrected_pcd.points = o3d.utility.Vector3dVector(corrected_points.astype(np.float64))
        if corrected_colors is not None:
            corrected_pcd.colors = o3d.utility.Vector3dVector((corrected_colors / 255.0).astype(np.float64))
        # write_point_cloud reports failure through its return value, not an exception.
        if not o3d.io.write_point_cloud(str(output_path), corrected_pcd):
            info["input_file"] = str(ply_path)
            info["error"] = f"Failed to write PLY file: {output_path}"
            return False, info
        info["output_file"] = str(output_path)

    info["input_file"] = str(ply_path)
    return True, info
def correct_tail_rotation_array(points: np.ndarray,
                                colors: Optional[np.ndarray] = None,
                                distance_threshold: float = 5.0,
                                ransac_n: int = 3,
                                num_iterations: int = 1000,
                                min_tail_ratio: float = 0.7,
                                min_angle_threshold: float = 5.0,
                                verbose: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray], bool, Dict[str, Any]]:
    """
    Array-level convenience wrapper around ``correct_tail_rotation``.

    Runs the dual-RANSAC tail correction and additionally reports, as a separate
    boolean, whether the tail points were actually rotated.

    Args:
        points: Point cloud array (N, 3)
        colors: Color array (N, 3), optional
        distance_threshold: Maximum distance from point to plane to be considered inlier (mm)
        ransac_n: Number of points to sample for plane fitting
        num_iterations: Number of RANSAC iterations
        min_tail_ratio: Minimum ratio of outliers that must be on second plane to consider tail rotated
        min_angle_threshold: Minimum angle (degrees) between planes to apply correction
        verbose: If True, print a short summary when a rotation is detected

    Returns:
        tuple: (corrected_points, corrected_colors, correction_applied, info_dict)
            - corrected_points: Point cloud with tail rotated (if rotation detected) or original points
            - corrected_colors: Colors (same as input)
            - correction_applied: True if rotation was detected and corrected
            - info_dict: Detection results from ``correct_tail_rotation``
    """
    new_points, new_colors, details = correct_tail_rotation(
        points,
        colors,
        distance_threshold,
        ransac_n,
        num_iterations,
        min_tail_ratio,
        min_angle_threshold,
    )
    applied = details.get("rotation_applied", False)

    # Summary output is opt-in and only relevant once a rotation has been detected.
    if not (verbose and details.get("rotation_detected", False)):
        return new_points, new_colors, applied, details

    angle = details.get('rotation_angle_degrees', 0)
    body_pct = details.get('body_inlier_ratio', 0) * 100
    tail_pct = details.get('tail_inlier_ratio', 0) * 100
    print(f" Tail rotation detected: angle={angle:.2f}°, "
          f"body={body_pct:.1f}%, "
          f"tail={tail_pct:.1f}%")
    if applied:
        rotated_count = details.get('num_tail_points_rotated', 0)
        print(f" Applied rotation correction to {rotated_count} tail points")
    return new_points, new_colors, applied, details
if __name__ == "__main__":
    import argparse
    import json

    def _to_jsonable(value: Any) -> Any:
        """Recursively convert numpy arrays/scalars at any nesting depth to plain Python types."""
        if isinstance(value, dict):
            return {key: _to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_to_jsonable(item) for item in value]
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            # np.bool_ / np.float64 / np.int64 etc. are not all accepted by json.
            return value.item()
        return value

    parser = argparse.ArgumentParser(
        description="Detect and correct fish tail rotation using dual RANSAC plane fitting"
    )
    parser.add_argument("--ply", type=str, required=True, help="Path to input PLY file")
    parser.add_argument("--output", type=str, help="Path to save corrected PLY file (optional)")
    parser.add_argument("--distance-threshold", type=float, default=5.0,
                        help="Maximum distance from point to plane to be considered inlier (mm, default: 5.0)")
    parser.add_argument("--ransac-n", type=int, default=3,
                        help="Number of points to sample for plane fitting (default: 3)")
    parser.add_argument("--num-iterations", type=int, default=1000,
                        help="Number of RANSAC iterations (default: 1000)")
    parser.add_argument("--min-tail-ratio", type=float, default=0.7,
                        help="Minimum ratio of outliers on tail plane to detect rotation (default: 0.7)")
    parser.add_argument("--min-angle-threshold", type=float, default=5.0,
                        help="Minimum angle (degrees) between planes to apply correction (default: 5.0)")
    parser.add_argument("--save-info", type=str, help="Path to save JSON file with correction info")
    args = parser.parse_args()

    # Correct tail rotation
    success, info = correct_tail_rotation_from_ply(
        args.ply,
        output_path=args.output,
        distance_threshold=args.distance_threshold,
        ransac_n=args.ransac_n,
        num_iterations=args.num_iterations,
        min_tail_ratio=args.min_tail_ratio,
        min_angle_threshold=args.min_angle_threshold
    )
    if not success:
        print(f"ERROR: {info.get('error', 'Unknown error')}")
        # SystemExit does not depend on the `site`-injected `exit` builtin.
        raise SystemExit(1)

    # Print results
    print(f"Input file: {args.ply}")
    print(f"Total points: {info['total_points']}")
    print(f"Body plane inliers: {info['body_points']} ({info['body_inlier_ratio']*100:.1f}%)")
    print(f"Tail plane inliers: {info['tail_points']} ({info['tail_inlier_ratio']*100:.1f}%)")
    print(f"Rotation angle: {info['rotation_angle_degrees']:.2f} degrees")
    print(f"Rotation detected: {info['rotation_detected']}")
    if info['rotation_detected']:
        print(f"Rotation applied: {info['rotation_applied']}")
        print(f"Tail points rotated: {info.get('num_tail_points_rotated', 0)}")
        if args.output:
            print(f"Corrected point cloud saved to: {args.output}")
    else:
        print("No rotation detected - point cloud unchanged")
        if args.output:
            print(f"Original point cloud copied to: {args.output}")

    # Save info to JSON if requested
    if args.save_info:
        with open(args.save_info, 'w', encoding="utf-8") as f:
            json.dump(_to_jsonable(info), f, indent=2)
        print(f"\nCorrection info saved to: {args.save_info}")