2434 lines
122 KiB
Python
2434 lines
122 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Standalone image preview with YOLO detection and SAM segmentation.
|
|
Pure OpenCV + YOLO + SAM for viewing images from a folder.
|
|
"""
|
|
|
|
import argparse
|
|
import cv2
|
|
import json
|
|
import numpy as np
|
|
import torch
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
|
|
def _open_video_writer(path: Path, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Return a video writer for *path*, trying a Jetson NVENC GStreamer pipeline first.

    The hardware path pushes frames through ``nvvidconv`` + ``nvv4l2h264enc``
    into an MP4 container. If that pipeline cannot be opened (no GStreamer
    support, not a Jetson, etc.), a software ``mp4v`` writer is returned instead.

    Args:
        path: Output video file.
        fps: Frames per second.
        size: Frame size as (width, height).

    Returns:
        An OpenCV ``VideoWriter``. The software fallback is returned without
        checking ``isOpened()``; callers should check it.
    """
    width, height = size

    try:
        if hasattr(cv2, "CAP_GSTREAMER"):
            # The path is embedded inside a quoted pipeline property, so escape quotes.
            escaped_location = str(path).replace('"', '\\"')
            pipeline = (
                'appsrc ! videoconvert ! video/x-raw,format=BGRx ! '
                'nvvidconv ! video/x-raw(memory:NVMM) ! '
                'nvv4l2h264enc bitrate=4000000 ! h264parse ! '
                f'mp4mux ! filesink location="{escaped_location}"'
            )
            hw_writer = cv2.VideoWriter(pipeline, cv2.CAP_GSTREAMER, 0, fps, (width, height))
            if hw_writer.isOpened():
                return hw_writer
            hw_writer.release()
    except Exception:
        # Best effort only: any failure in the hardware path falls through to software.
        pass

    software_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(str(path), software_fourcc, fps, (width, height))
|
|
from ultralytics import YOLO
|
|
from seg import init_models
|
|
from pointcloud_filter import filter_point_cloud
|
|
import os
|
|
import sys
|
|
import importlib
|
|
import tempfile
|
|
import subprocess
|
|
|
|
try:
|
|
import pyzed.sl as sl
|
|
ZED_AVAILABLE = True
|
|
except ImportError:
|
|
ZED_AVAILABLE = False
|
|
print("Warning: pyzed not available. SVO2 file support disabled.")
|
|
|
|
try:
|
|
import open3d as o3d
|
|
O3D_AVAILABLE = True
|
|
except ImportError:
|
|
O3D_AVAILABLE = False
|
|
print("Warning: open3d not available. Point cloud classification may fail.")
|
|
|
|
from dataset.zed_reader import ZEDReader
|
|
from utils.evaluate_flatness import evaluate_flatness_ransac
|
|
from utils.keep_largest_cluster import keep_largest_cluster_with_colors
|
|
from utils.correct_tail_rotation import correct_tail_rotation_array
|
|
|
|
|
|
def draw_detections(image, results, class_names, depth_stats_list=None):
    """Annotate *image* in place with YOLO boxes, class/score labels, track IDs and depth.

    Args:
        image: BGR image that is drawn on in place.
        results: A single YOLO detection/tracking result exposing ``.boxes``
            (with ``xyxy``, ``conf``, optional ``cls`` and optional ``id``).
        class_names: Mapping from class index to name. Unknown indices, or an
            empty/None mapping, are labelled "fish".
        depth_stats_list: Optional list aligned with the detections. Each entry
            may be None or a dict; its ``'median_depth_mm'`` value, if present,
            is drawn in a cyan tag under the top edge of the box.

    Returns:
        The same ``image`` object. It is returned untouched when ``results``
        or ``results.boxes`` is None.
    """
    if results is None or results.boxes is None:
        return image

    det = results.boxes
    xyxy = det.xyxy.cpu().numpy()
    scores = det.conf.cpu().numpy()
    if det.cls is not None:
        class_indices = det.cls.cpu().numpy().astype(int)
    else:
        class_indices = np.zeros(len(xyxy), dtype=int)

    # Tracker IDs only exist when YOLO ran in tracking mode.
    ids = None
    if hasattr(det, 'id') and det.id is not None:
        ids = det.id.cpu().numpy().astype(int)

    font = cv2.FONT_HERSHEY_SIMPLEX
    green = (0, 255, 0)
    black = (0, 0, 0)

    for k, (box, score, cls_idx) in enumerate(zip(xyxy, scores, class_indices)):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(image, (x1, y1), (x2, y2), green, 2)

        name = class_names[cls_idx] if class_names and cls_idx in class_names else "fish"
        prefix = f"ID:{ids[k]} " if ids is not None and k < len(ids) else ""
        label = f"{prefix}{name}: {score:.2f}"

        # Filled label background sits just above the box's top edge.
        (label_w, label_h), label_base = cv2.getTextSize(label, font, 0.6, 1)
        cv2.rectangle(image, (x1, y1 - label_h - label_base - 5), (x1 + label_w, y1), green, -1)
        cv2.putText(image, label, (x1, y1 - label_base - 2), font, 0.6, black, 1, cv2.LINE_AA)

        stats = None
        if depth_stats_list and k < len(depth_stats_list):
            stats = depth_stats_list[k]
        if stats is None:
            continue
        median_depth = stats.get('median_depth_mm', None)
        if median_depth is None:
            continue

        depth_text = f"Median: {median_depth:.1f}mm"
        (tag_w, tag_h), tag_base = cv2.getTextSize(depth_text, font, 0.5, 1)
        tag_top = y1 + label_h + label_base + 5
        tag_bottom = tag_top + tag_h
        # Only draw the depth tag if it fits inside the image vertically.
        if tag_bottom < image.shape[0]:
            cv2.rectangle(image, (x1, tag_top), (x1 + tag_w, tag_bottom), (255, 255, 0), -1)
            cv2.putText(image, depth_text, (x1, tag_bottom - tag_base), font, 0.5, black, 1, cv2.LINE_AA)

    return image
|
|
|
|
|
|
def segment_with_sam(sam_predictor, image_bgr, boxes_xyxy, device):
    """Run SAM box-prompted segmentation, producing one boolean mask per box.

    Args:
        sam_predictor: SAM predictor exposing ``set_image``, ``transform.apply_boxes_torch``
            and ``predict_torch``.
        image_bgr: BGR image; it is converted to RGB before being given to SAM.
        boxes_xyxy: Box prompts as [x1, y1, x2, y2] rows, in image pixel coordinates.
        device: Torch device on which the box tensor is created.

    Returns:
        A list of (H, W) bool numpy masks in the same order as the boxes, or an
        empty list when there are no boxes or SAM returns no masks.
    """
    if len(boxes_xyxy) == 0:
        return []

    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    sam_predictor.set_image(rgb)

    # Boxes must be mapped into SAM's resized input frame before prompting.
    raw_boxes = torch.tensor(boxes_xyxy, dtype=torch.float32, device=device)
    prompt_boxes = sam_predictor.transform.apply_boxes_torch(raw_boxes, rgb.shape[:2])

    masks, _scores, _logits = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=prompt_boxes,
        multimask_output=False,
    )

    if masks is None or masks.shape[0] == 0:
        return []

    # multimask_output=False -> one mask per prompt at channel index 0.
    return [per_box[0].cpu().numpy().astype(bool) for per_box in masks]
|
|
|
|
|
|
def create_segmentation_overlay(image, masks):
    """Return a copy of *image* with the union of *masks* tinted green and outlined.

    Args:
        image: BGR image (H, W, 3). It is not modified.
        masks: Iterable of (H, W) masks. An empty or None value yields a plain copy.

    Returns:
        New image in which masked pixels are blended 50/50 with green and the
        outer contours of the combined mask are drawn in green (2 px).
    """
    overlay = image.copy()
    if not (masks and len(masks) > 0):
        return overlay

    union = np.zeros(image.shape[:2], dtype=bool)
    for single in masks:
        union = np.logical_or(union, single)

    tint = np.zeros_like(overlay)
    tint[union] = [0, 255, 0]
    alpha = 0.5
    overlay = cv2.addWeighted(overlay, 1 - alpha, tint, alpha, 0)

    # findContours needs an 8-bit single-channel image.
    outline_source = union.astype(np.uint8) * 255
    contours, _ = cv2.findContours(outline_source, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)

    return overlay
|
|
|
|
|
|
def validate_depth_in_mask(depth_mm, mask, max_depth_mm=1200.0, min_depth_mm=100.0,
                           outlier_std_factor=3.0, min_valid_ratio=0.5):
    """Decide whether the depth inside *mask* is trustworthy.

    A masked depth sample is "valid" when it is finite, > 0 and within
    [min_depth_mm, max_depth_mm]. The region passes when:
      1. at least one valid sample exists,
      2. valid samples make up at least ``min_valid_ratio`` of the mask, and
      3. when the MAD is non-zero, no more than 30% of valid samples lie outside
         median +/- outlier_std_factor * 1.4826 * MAD (robust std estimate).

    Args:
        depth_mm: Depth map in millimeters (H, W).
        mask: Mask (H, W); non-bool masks are cast to bool.
        max_depth_mm: Upper bound for a valid depth sample.
        min_depth_mm: Lower bound for a valid depth sample.
        outlier_std_factor: Half-width of the inlier band in robust std units.
        min_valid_ratio: Minimum fraction of mask pixels that must be valid.

    Returns:
        (is_valid, reason, stats). ``stats`` is empty for the early failures,
        holds ratio/count info for the "too few" failure, the outlier bounds
        for the "too many outliers" failure, and summary statistics on success.
    """
    if mask is None:
        return False, "Mask is None", {}

    mask = mask if mask.dtype == bool else mask.astype(bool)

    if mask.shape != depth_mm.shape:
        return False, f"Mask shape {mask.shape} != depth shape {depth_mm.shape}", {}

    in_mask = depth_mm[mask]
    keep = (
        np.isfinite(in_mask)
        & (in_mask > 0)
        & (in_mask >= min_depth_mm)
        & (in_mask <= max_depth_mm)
    )
    samples = in_mask[keep]
    n_valid = len(samples)

    if n_valid == 0:
        return False, "No valid depth values in mask", {}

    mask_pixel_count = np.sum(mask)
    valid_ratio = n_valid / mask_pixel_count if mask_pixel_count > 0 else 0.0

    if valid_ratio < min_valid_ratio:
        return False, f"Too few valid depth pixels: {valid_ratio:.1%} < {min_valid_ratio:.1%}", {
            "valid_ratio": valid_ratio,
            "total_mask_pixels": mask_pixel_count,
            "valid_depth_count": n_valid,
        }

    median_depth = np.median(samples)
    mad = np.median(np.abs(samples - median_depth))

    # With MAD == 0 (e.g. a constant surface) the band would be empty, so skip it.
    if mad > 0:
        robust_std = 1.4826 * mad
        lower_bound = median_depth - outlier_std_factor * robust_std
        upper_bound = median_depth + outlier_std_factor * robust_std

        n_outliers = np.sum((samples < lower_bound) | (samples > upper_bound))
        outlier_ratio = n_outliers / n_valid if n_valid > 0 else 0.0

        if outlier_ratio > 0.3:
            return False, f"Too many depth outliers: {outlier_ratio:.1%} > 30%", {
                "median_depth": median_depth,
                "mad": mad,
                "outlier_ratio": outlier_ratio,
                "lower_bound": lower_bound,
                "upper_bound": upper_bound,
            }

    return True, "OK", {
        "median_depth": median_depth,
        "mean_depth": np.mean(samples),
        "std_depth": np.std(samples),
        "min_depth": np.min(samples),
        "max_depth": np.max(samples),
        "valid_ratio": valid_ratio,
        "valid_count": n_valid,
    }
|
|
|
|
|
|
def depth_mask_to_pointcloud(image_bgr, depth_mm, mask, intrinsics, max_depth_mm=1200.0):
    """Back-project masked depth pixels into a colored 3D point cloud.

    Uses the pinhole model ``X = (u - cx) * Z / fx``, ``Y = (v - cy) * Z / fy``,
    ``Z = depth``, where (u, v) are the pixel column and row.

    Args:
        image_bgr: BGR image (H, W, 3) aligned with ``depth_mm``.
        depth_mm: Depth map in millimeters (H, W).
        mask: Mask of any dtype; it is cast to bool. If its first two dimensions
            differ from (H, W), it is resized with nearest-neighbour
            interpolation and a warning is printed.
        intrinsics: Dict with pinhole intrinsics ``fx``, ``fy``, ``cx``, ``cy``.
        max_depth_mm: Pixels deeper than this are dropped.

    Returns:
        (points, colors): ``points`` is a float32 (N, 3) array in mm and
        ``colors`` is a uint8 (N, 3) RGB array, both in row-major pixel order.
        Returns (None, None) when ``mask`` is None or no pixel is both masked
        and has a finite depth in (0, max_depth_mm].
    """
    fx = intrinsics["fx"]
    fy = intrinsics["fy"]
    cx = intrinsics["cx"]
    cy = intrinsics["cy"]

    height, width = depth_mm.shape

    if mask is None:
        return None, None

    if mask.dtype != bool:
        mask = mask.astype(bool)

    if mask.shape[:2] != (height, width):
        original_shape = mask.shape
        mask = cv2.resize(mask.astype(np.uint8), (width, height),
                          interpolation=cv2.INTER_NEAREST).astype(bool)
        # Report the shape *before* resizing. The previous version printed it afterwards.
        print(f" Warning: Resized mask from {original_shape} to ({height}, {width})")

    depth_valid = np.isfinite(depth_mm) & (depth_mm > 0) & (depth_mm <= max_depth_mm)
    valid = mask & depth_valid

    if not np.any(valid):
        return None, None

    # np.nonzero yields coordinates in the same row-major order as boolean
    # indexing, without allocating a full-image meshgrid.
    y_coords, x_coords = np.nonzero(valid)
    z_vals = depth_mm[valid]

    x_vals = (x_coords - cx) * z_vals / fx
    y_vals = (y_coords - cy) * z_vals / fy

    points = np.stack([x_vals, y_vals, z_vals], axis=1).astype(np.float32)

    # BGR -> RGB
    colors = image_bgr[valid][:, [2, 1, 0]].astype(np.uint8)

    return points, colors
|
|
|
|
|
|
|
|
|
|
def write_ply_file(path, points, colors):
    """Write a colored point cloud to *path* as an ASCII PLY file.

    Args:
        path: Destination file; it is created or overwritten.
        points: Sequence of coordinates. Elements 0..2 are written as
            float x, y, z with 3 decimals.
        colors: Sequence of color triples. Elements 0..2 are cast with ``int``
            and written as uchar red, green, blue.

    Note:
        The vertex count in the header is ``len(points)``, while the rows are
        produced by ``zip(points, colors)``. The caller must pass equal-length
        sequences for the file to be consistent.
    """
    with open(path, 'w', encoding='utf-8') as out:
        header_lines = [
            "ply",
            "format ascii 1.0",
            f"element vertex {len(points)}",
            *(f"property float {axis}" for axis in ("x", "y", "z")),
            *(f"property uchar {channel}" for channel in ("red", "green", "blue")),
            "end_header",
        ]
        out.write("".join(line + "\n" for line in header_lines))
        out.writelines(
            f"{p[0]:.3f} {p[1]:.3f} {p[2]:.3f} {int(c[0])} {int(c[1])} {int(c[2])}\n"
            for p, c in zip(points, colors)
        )
|
|
|
|
|
|
def load_pointcloud_classifier(checkpoint_path: str, num_classes: int = 2, use_normals: bool = False,
                               device: str = None) -> Optional[torch.nn.Module]:
    """Build a PointNet/PointNet++ classifier and load its weights from a checkpoint.

    The model module name is found the same way ``test_classification.py`` finds it:
      1. the stem of the first ``*.txt`` in ``<experiment>/logs``,
      2. else the stem of the first ``pointnet*.py`` in ``<experiment>``,
      3. else ``pointnet2_cls_ssg``.
    Here ``<experiment>`` is the checkpoint's grandparent directory, e.g.
    ``log/classification/<name>`` for ``.../<name>/checkpoints/best_model.pth``.
    The repository root is assumed to be five levels above the checkpoint file.
    It and its ``models`` subfolder are prepended to ``sys.path`` so the module
    can be imported.

    Args:
        checkpoint_path: Path to the checkpoint (``~`` is expanded).
        num_classes: Number of output classes passed to ``get_model``.
        use_normals: Forwarded as ``normal_channel`` to ``get_model``.
        device: 'cuda' or 'cpu'. Defaults to 'cuda' when available.

    Returns:
        The model in eval mode, or None if the checkpoint is missing or the
        model module cannot be imported. Errors raised while building the
        model or loading the state dict propagate to the caller.
    """
    device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')

    ckpt = Path(checkpoint_path).expanduser().resolve()
    if not ckpt.exists():
        print(f"ERROR: Point cloud classifier checkpoint not found: {ckpt}")
        return None

    experiment_dir = ckpt.parent.parent
    logs_dir = experiment_dir / 'logs'

    # Resolve the model module name (logs/*.txt -> pointnet*.py -> default).
    log_txts = list(logs_dir.glob('*.txt')) if logs_dir.exists() else []
    if log_txts:
        model_name = log_txts[0].stem
        print(f"Found model name from logs directory: {model_name}")
    else:
        model_py_files = list(experiment_dir.glob('pointnet*.py'))
        if model_py_files:
            model_name = model_py_files[0].stem
            print(f"Found model file in experiment directory: {model_name}.py")
        else:
            model_name = 'pointnet2_cls_ssg'
            print(f"Using default model name: {model_name}")

    # checkpoints -> <experiment> -> classification -> log -> repository root
    pointnet_dir = ckpt
    for _ in range(5):
        pointnet_dir = pointnet_dir.parent
    models_dir = pointnet_dir / 'models'

    if not models_dir.exists():
        print(f"WARNING: Models directory not found: {models_dir}")
    elif str(models_dir) not in sys.path:
        sys.path.insert(0, str(models_dir))
        print(f"Added to sys.path: {models_dir}")

    if str(pointnet_dir) not in sys.path:
        sys.path.insert(0, str(pointnet_dir))
        print(f"Added to sys.path: {pointnet_dir}")

    print(f"Attempting to import model: {model_name}")
    try:
        model_module = importlib.import_module(model_name)
        print(f"✓ Successfully imported {model_name}")
    except ImportError as first_err:
        print(f"Failed to import {model_name}: {first_err}")
        print(f"Attempting to import from models.{model_name}")
        try:
            model_module = importlib.import_module(f'models.{model_name}')
            print(f"✓ Successfully imported models.{model_name}")
        except ImportError as second_err:
            print(f"ERROR: Could not import model {model_name} or models.{model_name}")
            print(f" First error: {first_err}")
            print(f" Second error: {second_err}")
            print(f" Current sys.path: {[p for p in sys.path if 'pointcloud' in p.lower() or 'pointnet' in p.lower()]}")
            return None

    print(f"Creating model with num_classes={num_classes}, use_normals={use_normals}")
    classifier = model_module.get_model(num_classes, normal_channel=use_normals).to(device)

    print(f"Loading checkpoint from: {ckpt}")
    state = torch.load(str(ckpt), map_location=device)
    # Training scripts wrap the weights in 'model_state_dict'; bare state dicts are also accepted.
    if 'model_state_dict' in state:
        classifier.load_state_dict(state['model_state_dict'])
        print("✓ Loaded model_state_dict from checkpoint")
    else:
        classifier.load_state_dict(state)
        print("✓ Loaded state dict directly from checkpoint")

    classifier.eval()
    print(f"✓ Loaded point cloud classifier from {ckpt}")
    return classifier
|
|
|
|
|
|
def classify_pointcloud_array(classifier: torch.nn.Module, points: np.ndarray, colors: np.ndarray,
                              num_point: int = 1024, vote_num: int = 3,
                              use_cpu: bool = False, confidence_threshold: float = 0.5) -> Tuple[bool, float, Dict]:
    """Classify an in-memory point cloud with ``test_classification.classify_single_pointcloud``.

    The cloud is written to a temporary PLY file with open3d. The path is then
    passed to the same ``classify_single_pointcloud`` function that
    ``test_classification.py`` uses, so results match that script. The
    temporary file is removed on every exit path.

    Args:
        classifier: Loaded classifier model.
        points: Point array (N, 3); anything ``np.asarray`` accepts.
        colors: Optional color array (N, 3) with 0-255 values. Used only when
            its length matches ``points``.
        num_point: Number of points sampled by the classifier.
        vote_num: Number of votes used by the classifier.
        use_cpu: Forwarded to ``classify_single_pointcloud``.
        confidence_threshold: Minimum confidence for a "good" (class 1) verdict.

    Returns:
        (is_good, confidence, result_dict). ``is_good`` is True only when the
        predicted ``class_id`` is 1 AND ``confidence >= confidence_threshold``.
        On success ``result_dict`` is the classifier's result with ``is_good``
        and ``meets_threshold`` added. On failure it is ``{}`` (no classifier)
        or ``{"error": ...}``, and the first two values are ``False, 0.0``.
    """
    if classifier is None:
        return False, 0.0, {}

    if points is None or len(points) == 0:
        return False, 0.0, {"error": "Empty point cloud"}

    if not O3D_AVAILABLE:
        return False, 0.0, {"error": "open3d not available"}

    temp_ply = None
    try:
        possible_paths = [
            Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "test_classification.py",
            Path(__file__).parent.parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "test_classification.py",
            Path("/home/ubuntu/projects/FishMeasure/pointcloud_classifier/Pointnet_Pointnet2_pytorch/test_classification.py"),
        ]
        test_classification_path = next((p for p in possible_paths if p.exists()), None)

        if test_classification_path is None:
            error_msg = f"test_classification.py not found. Tried paths: {[str(p) for p in possible_paths]}"
            print(f" ERROR: {error_msg}")
            return False, 0.0, {"error": error_msg}

        test_classification_dir = str(test_classification_path.parent)
        if test_classification_dir not in sys.path:
            sys.path.insert(0, test_classification_dir)

        try:
            from test_classification import classify_single_pointcloud
        except ImportError as e:
            error_msg = f"Failed to import classify_single_pointcloud from test_classification: {e}"
            print(f" ERROR: {error_msg}")
            print(f" sys.path includes: {[p for p in sys.path if 'pointcloud' in p or 'Pointnet' in p]}")
            return False, 0.0, {"error": error_msg}

        # Reserve a unique file name; open3d writes the actual contents below.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.ply', delete=False) as f:
            temp_ply = f.name

        xyz = np.asarray(points, dtype=np.float64)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(xyz)
        if colors is not None and len(colors) == len(xyz):
            pcd.colors = o3d.utility.Vector3dVector(np.asarray(colors, dtype=np.float64) / 255.0)
        o3d.io.write_point_cloud(temp_ply, pcd)

        result, error = classify_single_pointcloud(
            classifier, temp_ply, num_point=num_point, vote_num=vote_num,
            use_cpu=use_cpu, use_normals=False
        )

        if result is None:
            return False, 0.0, {"error": error if error else "Classification failed"}

        class_id = result.get('class_id', 0)
        confidence = result.get('confidence', 0.0)

        # "good" is class 1, and only when the confidence clears the threshold.
        is_good = (class_id == 1) and (confidence >= confidence_threshold)

        result['is_good'] = is_good
        result['meets_threshold'] = is_good

        return is_good, confidence, result

    except Exception as e:
        print(f" Warning: Point cloud classification failed: {e}")
        import traceback
        traceback.print_exc()
        return False, 0.0, {"error": str(e)}

    finally:
        # Single cleanup point for the temporary PLY on every exit path.
        if temp_ply is not None:
            try:
                os.unlink(temp_ply)
            except OSError:
                pass
|
|
|
|
|
|
def calculate_fish_depth_stats(depth_mm, mask, max_depth_mm=1200.0):
    """Summarize the valid depth values inside a fish mask.

    A pixel counts when it is inside ``mask`` and its depth is finite and in
    (0, max_depth_mm].

    Args:
        depth_mm: Depth map in millimeters (H, W).
        mask: Mask (H, W) of any dtype; it is cast to bool. Previously a
            non-bool mask turned the selection into integer fancy indexing.
        max_depth_mm: Upper depth limit in mm.

    Returns:
        Dict with ``mean_depth_mm``, ``median_depth_mm``, ``min_depth_mm``,
        ``max_depth_mm``, ``std_depth_mm`` (floats) and ``num_points`` (int),
        or None when ``mask`` is None or no pixel qualifies.
    """
    if mask is None:
        return None

    # A uint8/int mask combined with '&' stays integer-typed, and indexing with
    # it selects rows instead of pixels. Force a boolean mask.
    mask = np.asarray(mask, dtype=bool)

    valid = mask & np.isfinite(depth_mm) & (depth_mm > 0) & (depth_mm <= max_depth_mm)

    if not np.any(valid):
        return None

    valid_depths = depth_mm[valid]

    return {
        "mean_depth_mm": float(np.mean(valid_depths)),
        "median_depth_mm": float(np.median(valid_depths)),
        "min_depth_mm": float(np.min(valid_depths)),
        "max_depth_mm": float(np.max(valid_depths)),
        "std_depth_mm": float(np.std(valid_depths)),
        "num_points": int(valid_depths.size),
    }
|
|
|
|
|
|
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two [x1, y1, x2, y2] boxes.

    Boxes that only touch or do not overlap give 0.0. A zero or negative union
    (degenerate boxes) also gives 0.0.
    """
    overlap_w = min(box1[2], box2[2]) - max(box1[0], box2[0])
    overlap_h = min(box1[3], box2[3]) - max(box1[1], box2[1])

    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    inter_area = overlap_w * overlap_h
    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area_a + area_b - inter_area

    if union_area <= 0:
        return 0.0
    return inter_area / union_area
|
|
|
|
|
|
def calculate_bbox_center(box):
    """Return the (cx, cy) midpoint of an [x1, y1, x2, y2] box using true division."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    return (center_x, center_y)
|
|
|
|
|
|
def calculate_bbox_distance(box1, box2):
    """Return the Euclidean distance (pixels) between the centers of two [x1, y1, x2, y2] boxes."""
    (ax, ay) = calculate_bbox_center(box1)
    (bx, by) = calculate_bbox_center(box2)
    dx = ax - bx
    dy = ay - by
    return np.sqrt(dx**2 + dy**2)
|
|
|
|
|
|
def is_bbox_stationary(current_box, previous_boxes, movement_threshold=5.0):
    """Report whether a box's center moved less than a threshold since the last frame.

    Only the most recent entry of ``previous_boxes`` is compared.

    Args:
        current_box: Current bounding box [x1, y1, x2, y2].
        previous_boxes: Sequence (list or NumPy array) of earlier boxes for
            this track, oldest first. None or empty means no history.
        movement_threshold: Center displacement in pixels below which the box
            is considered stationary.

    Returns:
        True if the center moved less than ``movement_threshold`` pixels since
        the last recorded box. False when there is no history or the box moved.
    """
    # Explicit None/len check: `not previous_boxes` raises for NumPy arrays
    # holding more than one box.
    if previous_boxes is None or len(previous_boxes) == 0:
        return False

    last_box = previous_boxes[-1]
    distance = calculate_bbox_distance(current_box, last_box)

    return distance < movement_threshold
|
|
|
|
|
|
def track_fish_across_frames(current_boxes, previous_boxes, iou_threshold=0.3):
    """Greedily match current detections to previous-frame detections by IoU.

    Current boxes are processed in order. Each one takes the still-unclaimed
    previous box with the highest IoU, provided that IoU is both >=
    ``iou_threshold`` and strictly greater than 0. On ties, the earliest
    previous box wins. Each previous box is claimed at most once.

    Returns:
        A list with one entry per current box: the INDEX of the matched box in
        ``previous_boxes``, or None if unmatched. These are positions in the
        previous frame's list, not persistent track IDs.
    """
    n_current = len(current_boxes)
    if previous_boxes is None or len(previous_boxes) == 0:
        return [None] * n_current

    assignments = [None] * n_current
    claimed = set()

    for cur_idx, cur_box in enumerate(current_boxes):
        winner = None
        winner_iou = 0.0
        for prev_idx, prev_box in enumerate(previous_boxes):
            if prev_idx in claimed:
                continue
            overlap = calculate_iou(cur_box, prev_box)
            if overlap > winner_iou and overlap >= iou_threshold:
                winner, winner_iou = prev_idx, overlap
        if winner is not None:
            assignments[cur_idx] = winner
            claimed.add(winner)

    return assignments
|
|
|
|
|
|
def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_device,
|
|
conf=0.25, imgsz=640, max_frames=0, save_images=False, filter_pointcloud=False,
|
|
use_clustering_filter=False, use_density_filter=False,
|
|
pointcloud_classifier=None, use_pointcloud_classifier=False,
|
|
pointcloud_classifier_threshold=0.7,
|
|
flatness_threshold=0.0, use_flatness_filter=False,
|
|
run_template_matching=False, template_path=None, template_folder=None,
|
|
template_scale_factor=1.0, save_raw_pointclouds=False,
|
|
correct_tail_rotation=False, tail_rotation_distance_threshold=5.0,
|
|
tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0):
|
|
"""Process a single SVO2 file with pre-loaded YOLO and SAM models.
|
|
|
|
Args:
|
|
svo_path: Path to SVO2 file
|
|
output_base: Base output directory (will create subfolder with SVO name)
|
|
yolo_model: Pre-loaded YOLO model
|
|
sam_predictor: Pre-loaded SAM predictor
|
|
sam_device: SAM device (torch.device)
|
|
conf: YOLO confidence threshold
|
|
imgsz: YOLO image size
|
|
max_frames: Maximum frames to process (0 = all)
|
|
save_images: If True, save individual images instead of video
|
|
filter_pointcloud: If True, apply filtering to remove outliers from point clouds
|
|
use_clustering_filter: If True, use clustering to keep only largest cluster
|
|
use_density_filter: If True, use density filtering (requires at least 200 points within 100mm radius)
|
|
|
|
Returns:
|
|
True if successful, False otherwise
|
|
"""
|
|
if not ZED_AVAILABLE:
|
|
print("ERROR: pyzed not available. Cannot read SVO2 files.")
|
|
return False
|
|
|
|
svo_path = Path(svo_path).expanduser().resolve()
|
|
if not svo_path.exists():
|
|
print(f"ERROR: SVO2 file not found: {svo_path}")
|
|
return False
|
|
|
|
svo_name = svo_path.stem
|
|
class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
|
|
|
|
# Setup output folders
|
|
output_base = Path(output_base) / svo_name
|
|
output_images_folder = output_base / "images"
|
|
output_cloud_folder = output_base / "cloud"
|
|
output_raw_pc_folder = output_base / "raw_pc" if save_raw_pointclouds else None
|
|
output_images_folder.mkdir(parents=True, exist_ok=True)
|
|
output_cloud_folder.mkdir(parents=True, exist_ok=True)
|
|
if save_raw_pointclouds and output_raw_pc_folder:
|
|
output_raw_pc_folder.mkdir(parents=True, exist_ok=True)
|
|
|
|
print(f"Reading from SVO2 file: {svo_path.name}")
|
|
print(f"Output folder: {output_base.resolve()}")
|
|
|
|
# Check if output folder already exists and contains point clouds
|
|
# If so, skip data generation and directly run template matching
|
|
if output_base.exists() and output_cloud_folder.exists():
|
|
# Check if there are point cloud files
|
|
point_cloud_files = list(output_cloud_folder.glob("*.ply"))
|
|
if point_cloud_files and run_template_matching:
|
|
print(f"\n{'='*60}")
|
|
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
|
|
print(f"Skipping data generation, directly running template matching...")
|
|
print(f"{'='*60}")
|
|
|
|
# Load kept point clouds from file if it exists, otherwise use all point clouds
|
|
kept_pointclouds = []
|
|
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
|
|
if pointcloud_list_path.exists():
|
|
print(f" Loading point cloud list from: {pointcloud_list_path.name}")
|
|
with open(pointcloud_list_path, 'r', encoding='utf-8') as f:
|
|
kept_pointclouds = [line.strip() for line in f if line.strip()]
|
|
print(f" Found {len(kept_pointclouds)} point clouds in list")
|
|
else:
|
|
# If no list file, use all point clouds in the folder
|
|
kept_pointclouds = [str(f) for f in point_cloud_files]
|
|
print(f" No point cloud list file found, using all {len(kept_pointclouds)} point clouds")
|
|
|
|
if kept_pointclouds:
|
|
# Run template matching directly
|
|
print(f"\n{'='*60}")
|
|
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
|
|
print(f"{'='*60}")
|
|
|
|
# Validate template matching arguments
|
|
if not template_path and not template_folder:
|
|
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
|
|
return False
|
|
|
|
# Create output directory for template matching
|
|
template_output_dir = output_base / "template_matching"
|
|
template_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Import fish_align_cli function
|
|
try:
|
|
# Add template_matching directory to path
|
|
template_matching_dir = Path(__file__).parent / "template_matching"
|
|
if str(template_matching_dir) not in sys.path:
|
|
sys.path.insert(0, str(template_matching_dir))
|
|
|
|
from fish_align_cli import process_folder_mode
|
|
|
|
# Process all point clouds using folder mode
|
|
process_folder_mode(
|
|
input_folder=str(output_cloud_folder),
|
|
template_path=template_path,
|
|
template_folder=template_folder,
|
|
output_dir=str(template_output_dir),
|
|
template_scale_factor=template_scale_factor,
|
|
debug=False,
|
|
debug_dir=None
|
|
)
|
|
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
|
|
return True
|
|
except Exception as e:
|
|
print(f" ERROR: Failed to run template matching: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
return False
|
|
else:
|
|
print(f" Warning: No point clouds found to process")
|
|
return False
|
|
elif point_cloud_files and not run_template_matching:
|
|
print(f"\n{'='*60}")
|
|
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
|
|
print(f"Template matching not requested (--run-template-matching not set)")
|
|
print(f"Skipping processing...")
|
|
print(f"{'='*60}")
|
|
return True
|
|
# If folder exists but no point clouds, continue with normal processing
|
|
|
|
# Initialize ZED reader
|
|
zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
|
|
if not zed_reader.open():
|
|
print("ERROR: Failed to open SVO2 file")
|
|
return False
|
|
|
|
# Get camera intrinsics
|
|
calib_params = zed_reader.zed.get_camera_information().camera_configuration.calibration_parameters
|
|
camera_intrinsics = {
|
|
"fx": float(calib_params.left_cam.fx),
|
|
"fy": float(calib_params.left_cam.fy),
|
|
"cx": float(calib_params.left_cam.cx),
|
|
"cy": float(calib_params.left_cam.cy),
|
|
}
|
|
|
|
runtime_params = sl.RuntimeParameters()
|
|
left_image_mat = sl.Mat()
|
|
depth_mat = sl.Mat()
|
|
|
|
# Tracking variables
|
|
previous_boxes = None
|
|
fish_tracks: Dict[int, List[Dict[str, Any]]] = {}
|
|
track_bbox_history: Dict[int, List[List[float]]] = {}
|
|
track_stationary_count: Dict[int, int] = {}
|
|
next_track_id = 0
|
|
STATIONARY_THRESHOLD = 10
|
|
MOVEMENT_THRESHOLD = 5.0
|
|
|
|
video_frames = []
|
|
idx = 0
|
|
|
|
# List to track point clouds that passed PointNet++ classifier (if enabled)
|
|
kept_pointclouds = []
|
|
|
|
try:
|
|
while True:
|
|
if max_frames > 0 and idx >= max_frames:
|
|
break
|
|
|
|
err = zed_reader.zed.grab(runtime_params)
|
|
if err != sl.ERROR_CODE.SUCCESS:
|
|
break
|
|
|
|
# Retrieve images
|
|
zed_reader.zed.retrieve_image(left_image_mat, sl.VIEW.LEFT)
|
|
zed_reader.zed.retrieve_measure(depth_mat, sl.MEASURE.DEPTH)
|
|
|
|
# Convert to numpy
|
|
left_np = left_image_mat.get_data()
|
|
depth_np = depth_mat.get_data()
|
|
|
|
if left_np.shape[2] > 3:
|
|
img = left_np[:, :, :3].copy()
|
|
else:
|
|
img = left_np.copy()
|
|
|
|
if depth_np.dtype != np.float32:
|
|
depth_data = depth_np.astype(np.float32)
|
|
else:
|
|
depth_data = depth_np.copy()
|
|
|
|
frame_name = f"frame_{idx+1:06d}"
|
|
if idx % 30 == 0:
|
|
print(f"[{idx + 1}] {frame_name}")
|
|
|
|
# Run YOLO tracking
|
|
results = yolo_model.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
|
|
num_dets = len(results.boxes) if results.boxes is not None else 0
|
|
|
|
# Process detections (same logic as main function)
|
|
individual_masks = []
|
|
current_boxes = None
|
|
track_ids = []
|
|
depth_stats_list = [] # Store depth stats for each detection
|
|
|
|
try:
|
|
if num_dets > 0:
|
|
boxes_xyxy = results.boxes.xyxy.cpu().numpy()
|
|
current_boxes = boxes_xyxy.tolist()
|
|
|
|
# Get track IDs from YOLO tracking results
|
|
if hasattr(results.boxes, 'id') and results.boxes.id is not None:
|
|
track_ids = results.boxes.id.cpu().numpy().astype(int).tolist()
|
|
else:
|
|
# Fallback: assign sequential IDs if tracking not available
|
|
track_ids = list(range(next_track_id, next_track_id + len(current_boxes)))
|
|
next_track_id += len(current_boxes)
|
|
|
|
# Filter stationary fish
|
|
active_detections = []
|
|
active_masks = []
|
|
active_track_ids = []
|
|
active_boxes = []
|
|
|
|
for fish_idx, (box, track_id) in enumerate(zip(current_boxes, track_ids)):
|
|
if track_id not in track_bbox_history:
|
|
track_bbox_history[track_id] = []
|
|
track_stationary_count[track_id] = 0
|
|
|
|
is_stationary = False
|
|
if len(track_bbox_history[track_id]) > 0:
|
|
is_stationary = is_bbox_stationary(box, track_bbox_history[track_id], MOVEMENT_THRESHOLD)
|
|
|
|
if is_stationary:
|
|
track_stationary_count[track_id] += 1
|
|
if track_stationary_count[track_id] >= STATIONARY_THRESHOLD:
|
|
continue
|
|
else:
|
|
track_stationary_count[track_id] = 0
|
|
|
|
active_detections.append(fish_idx)
|
|
active_boxes.append(box)
|
|
active_track_ids.append(track_id)
|
|
track_bbox_history[track_id].append(box)
|
|
if len(track_bbox_history[track_id]) > 10:
|
|
track_bbox_history[track_id].pop(0)
|
|
|
|
if len(active_detections) == 0:
|
|
num_dets = 0
|
|
individual_masks = []
|
|
previous_boxes = None
|
|
depth_stats_list = []
|
|
else:
|
|
active_boxes_array = np.array(active_boxes)
|
|
all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device)
|
|
individual_masks = all_masks if all_masks else []
|
|
|
|
# Calculate depth stats and map to original detection order
|
|
depth_stats_list = [None] * len(current_boxes) # Initialize with None for all detections
|
|
if individual_masks and len(individual_masks) > 0 and depth_data is not None:
|
|
for mask_idx, (mask, track_id, box) in enumerate(zip(individual_masks, active_track_ids, active_boxes)):
|
|
depth_stats = calculate_fish_depth_stats(depth_data, mask)
|
|
if depth_stats:
|
|
# Map to original detection index
|
|
original_idx = active_detections[mask_idx]
|
|
depth_stats_list[original_idx] = depth_stats
|
|
|
|
frame_data = {
|
|
"frame_index": idx + 1,
|
|
"frame_name": frame_name,
|
|
"fish_index": active_detections[mask_idx],
|
|
"track_id": int(track_id),
|
|
"bbox": box,
|
|
"depth_stats": depth_stats
|
|
}
|
|
if track_id not in fish_tracks:
|
|
fish_tracks[track_id] = []
|
|
fish_tracks[track_id].append(frame_data)
|
|
|
|
previous_boxes = active_boxes
|
|
current_boxes = active_boxes
|
|
num_dets = len(active_detections)
|
|
|
|
if individual_masks and len(individual_masks) > 0:
|
|
right_display = create_segmentation_overlay(img.copy(), individual_masks)
|
|
cv2.putText(right_display, "Segmentation", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
|
else:
|
|
right_display = img.copy()
|
|
cv2.putText(right_display, "No detections" if num_dets == 0 else "Segmentation (failed)",
|
|
(10, right_display.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
|
|
(128, 128, 128) if num_dets == 0 else (0, 0, 255), 2, cv2.LINE_AA)
|
|
else:
|
|
right_display = img.copy()
|
|
cv2.putText(right_display, "No detections", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)
|
|
previous_boxes = None
|
|
depth_stats_list = []
|
|
except Exception as e:
|
|
print(f" ERROR in processing: {e}")
|
|
right_display = img.copy()
|
|
previous_boxes = None
|
|
depth_stats_list = []
|
|
|
|
# Draw detections with depth info
|
|
left_display = draw_detections(img.copy(), results, class_names, depth_stats_list)
|
|
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
|
|
cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
|
|
cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
|
|
|
# Combine and save
|
|
combined_display = np.hstack([left_display, right_display])
|
|
|
|
if num_dets > 0:
|
|
# Save image or collect for video
|
|
if save_images:
|
|
# Save individual image file
|
|
image_path = output_images_folder / f"{frame_name}_detection.png"
|
|
cv2.imwrite(str(image_path), combined_display)
|
|
if idx % 30 == 0:
|
|
print(f" Saved image: {image_path.name}")
|
|
else:
|
|
# Collect for video
|
|
video_frames.append(combined_display.copy())
|
|
|
|
# Save point clouds
|
|
if depth_data is not None and individual_masks and len(individual_masks) > 0:
|
|
for fish_idx, mask in enumerate(individual_masks):
|
|
# Verify mask is being used correctly
|
|
if mask is None:
|
|
print(f" Warning: Mask {fish_idx + 1} is None, skipping point cloud generation")
|
|
continue
|
|
|
|
# Check mask statistics
|
|
mask_pixels = np.sum(mask.astype(bool))
|
|
if mask_pixels == 0:
|
|
print(f" Warning: Mask {fish_idx + 1} is empty, skipping point cloud generation")
|
|
continue
|
|
|
|
# Validate depth values in mask before generating point cloud
|
|
depth_valid, reason, depth_stats = validate_depth_in_mask(
|
|
depth_data, mask, max_depth_mm=1200.0, min_depth_mm=600.0
|
|
)
|
|
if not depth_valid:
|
|
print(f" Skipped point cloud {fish_idx + 1}: depth validation failed - {reason}")
|
|
continue
|
|
|
|
points, colors = depth_mask_to_pointcloud(img, depth_data, mask, camera_intrinsics)
|
|
if points is not None and len(points) > 0:
|
|
# Apply filtering if enabled
|
|
original_count = len(points)
|
|
if filter_pointcloud:
|
|
points, colors = filter_point_cloud(
|
|
points, colors,
|
|
use_clustering_filter=use_clustering_filter,
|
|
use_density_filter=use_density_filter
|
|
)
|
|
filtered_count = len(points)
|
|
# Skip if less than 500 points
|
|
if filtered_count < 500:
|
|
print(f" Skipped point cloud {fish_idx + 1}: {original_count} -> {filtered_count} points (minimum: 500)")
|
|
continue
|
|
elif filtered_count == 0:
|
|
print(f" Warning: All points filtered out for fish {fish_idx + 1} (original: {original_count})")
|
|
continue
|
|
else:
|
|
# Check minimum point count
|
|
if len(points) < 500:
|
|
print(f" Skipped point cloud {fish_idx + 1}: only {len(points)} points (minimum: 500)")
|
|
continue
|
|
|
|
# Apply largest cluster filtering before classification (if classifier is enabled)
|
|
# This removes outliers and keeps only the main fish body cluster
|
|
if use_pointcloud_classifier:
|
|
cluster_before_count = len(points)
|
|
try:
|
|
points, colors, cluster_info = keep_largest_cluster_with_colors(
|
|
points, colors, eps=10.0, min_points=30
|
|
)
|
|
cluster_after_count = len(points)
|
|
|
|
if "error" in cluster_info:
|
|
print(f" Point cloud {fish_idx + 1}: Clustering failed - {cluster_info.get('error', 'Unknown error')}")
|
|
continue
|
|
|
|
if cluster_after_count < 500:
|
|
print(f" Point cloud {fish_idx + 1}: After clustering: {cluster_before_count} -> {cluster_after_count} points (minimum: 500) - SKIPPED")
|
|
continue
|
|
|
|
if cluster_after_count < cluster_before_count:
|
|
print(f" Point cloud {fish_idx + 1}: Clustering removed {cluster_before_count - cluster_after_count} points "
|
|
f"({cluster_info.get('num_clusters', 0)} clusters found, kept largest with {cluster_info.get('largest_cluster_size', 0)} points)")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Largest cluster filtering failed: {e}")
|
|
# Continue with original points if clustering fails
|
|
|
|
# Save raw point cloud before classification (if flag is enabled)
|
|
if save_raw_pointclouds and output_raw_pc_folder:
|
|
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
|
|
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
|
|
write_ply_file(raw_ply_path, points, colors)
|
|
print(f" Saved raw point cloud {fish_idx + 1} (before classifier): {raw_ply_path.name} ({len(points)} points)")
|
|
|
|
# Classify point cloud quality if classifier is available
|
|
if use_pointcloud_classifier:
|
|
if pointcloud_classifier is not None:
|
|
# Classifier is available - perform classification
|
|
is_good, confidence, class_result = classify_pointcloud_array(
|
|
pointcloud_classifier, points, colors, num_point=1024, vote_num=3,
|
|
confidence_threshold=pointcloud_classifier_threshold,
|
|
use_cpu=(sam_device.type == "cpu")
|
|
)
|
|
|
|
# Always print the prediction result
|
|
prediction = class_result.get("prediction", "unknown")
|
|
class_id = class_result.get("class_id", -1)
|
|
|
|
if not is_good:
|
|
# Point cloud is bad - don't save it
|
|
if class_id == 1:
|
|
# Predicted as good but confidence too low
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (confidence {confidence:.3f} < threshold {pointcloud_classifier_threshold:.3f})")
|
|
else:
|
|
# Predicted as bad
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (BAD quality)")
|
|
continue
|
|
else:
|
|
# Point cloud is good - proceed to save
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SAVING (GOOD quality, confidence >= {pointcloud_classifier_threshold:.3f})")
|
|
else:
|
|
# Classifier requested but not available
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Classifier requested but not loaded, saving without classification")
|
|
else:
|
|
# Classifier not enabled - this should not happen if use_pointcloud_classifier was True
|
|
# This means either the flag was not set, or classifier loading failed
|
|
if use_pointcloud_classifier:
|
|
print(f" Point cloud {fish_idx + 1}: ERROR - Classifier flag was set but classifier is None. Check startup logs for loading errors.")
|
|
else:
|
|
print(f" Point cloud {fish_idx + 1}: Classifier not enabled")
|
|
|
|
# Evaluate flatness if enabled
|
|
if use_flatness_filter:
|
|
try:
|
|
flatness_score, flatness_info = evaluate_flatness_ransac(
|
|
points,
|
|
distance_threshold=5.0,
|
|
ransac_n=3,
|
|
num_iterations=1000
|
|
)
|
|
|
|
if flatness_score < flatness_threshold:
|
|
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% < threshold {flatness_threshold:.2f}% - SKIPPED (not flat enough)")
|
|
continue
|
|
else:
|
|
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% >= threshold {flatness_threshold:.2f}% - PASSED")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Flatness evaluation failed: {e}")
|
|
# If flatness check is required and fails, skip saving
|
|
if use_flatness_filter:
|
|
print(f" Point cloud {fish_idx + 1}: Skipping due to flatness evaluation error")
|
|
continue
|
|
|
|
# Correct tail rotation if enabled (before saving, after all checks)
|
|
if correct_tail_rotation:
|
|
try:
|
|
tail_correction_before_count = len(points)
|
|
points, colors, tail_correction_applied, tail_correction_info = correct_tail_rotation_array(
|
|
points, colors,
|
|
distance_threshold=tail_rotation_distance_threshold,
|
|
min_tail_ratio=tail_rotation_min_tail_ratio,
|
|
min_angle_threshold=tail_rotation_min_angle,
|
|
verbose=True
|
|
)
|
|
if tail_correction_applied:
|
|
print(f" Point cloud {fish_idx + 1}: Tail rotation corrected (angle: {tail_correction_info.get('rotation_angle_degrees', 0):.2f}°)")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}")
|
|
# Continue with original points if correction fails
|
|
|
|
# Save point cloud (passed all checks)
|
|
filtered_count = len(points)
|
|
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
|
|
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
|
|
write_ply_file(ply_path, points, colors)
|
|
|
|
# Track point clouds that passed PointNet++ classifier (if enabled)
|
|
# If classifier is enabled, only track those that passed classifier
|
|
# If classifier is not enabled, track all saved point clouds
|
|
if use_pointcloud_classifier and pointcloud_classifier is not None:
|
|
# This point cloud passed all checks including classifier
|
|
kept_pointclouds.append(str(ply_path))
|
|
elif not use_pointcloud_classifier:
|
|
# Classifier not enabled, track all saved point clouds
|
|
kept_pointclouds.append(str(ply_path))
|
|
|
|
if filter_pointcloud:
|
|
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({original_count} -> {filtered_count} points)")
|
|
else:
|
|
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({filtered_count} points)")
|
|
|
|
idx += 1
|
|
|
|
# Create video (only if not saving individual images)
|
|
if not save_images and video_frames:
|
|
video_path = output_images_folder / f"{svo_name}_preview.mp4"
|
|
h, w = video_frames[0].shape[:2]
|
|
fps = 10.0
|
|
video_writer = _open_video_writer(video_path, fps, (w, h))
|
|
for frame in video_frames:
|
|
video_writer.write(frame)
|
|
video_writer.release()
|
|
print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames)")
|
|
elif save_images:
|
|
print(f"✓ Saved {idx} frames as individual images")
|
|
|
|
# Save tracking stats
|
|
if fish_tracks:
|
|
stats_path = output_base / "fish_depth_tracking.json"
|
|
tracking_data = {
|
|
"source": str(svo_path),
|
|
"total_frames_processed": idx,
|
|
"total_tracks": len(fish_tracks),
|
|
"tracks": {}
|
|
}
|
|
for track_id, frames in fish_tracks.items():
|
|
tracking_data["tracks"][str(track_id)] = {
|
|
"track_id": track_id,
|
|
"num_frames": len(frames),
|
|
"frames": frames,
|
|
"depth_summary": {
|
|
"mean_depth_overall_mm": float(np.mean([f["depth_stats"]["mean_depth_mm"] for f in frames])),
|
|
"std_depth_overall_mm": float(np.std([f["depth_stats"]["mean_depth_mm"] for f in frames])),
|
|
"min_depth_mm": float(np.min([f["depth_stats"]["min_depth_mm"] for f in frames])),
|
|
"max_depth_mm": float(np.max([f["depth_stats"]["max_depth_mm"] for f in frames])),
|
|
}
|
|
}
|
|
with open(stats_path, 'w', encoding='utf-8') as f:
|
|
json.dump(tracking_data, f, indent=2)
|
|
print(f"✓ Saved tracking stats: {stats_path.name}")
|
|
|
|
# Save list of point clouds kept for template matching
|
|
if kept_pointclouds:
|
|
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
|
|
with open(pointcloud_list_path, 'w', encoding='utf-8') as f:
|
|
for ply_path in kept_pointclouds:
|
|
f.write(f"{ply_path}\n")
|
|
if use_pointcloud_classifier:
|
|
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds that passed PointNet++ classifier: {pointcloud_list_path.name}")
|
|
else:
|
|
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds for template matching: {pointcloud_list_path.name}")
|
|
else:
|
|
if use_pointcloud_classifier:
|
|
print(f" Note: PointNet++ classifier was enabled but no point clouds passed the filter")
|
|
else:
|
|
print(f" Note: No point clouds were saved")
|
|
|
|
# Run template matching if requested
|
|
if run_template_matching and kept_pointclouds:
|
|
print(f"\n{'='*60}")
|
|
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
|
|
print(f"{'='*60}")
|
|
|
|
# Create output directory for template matching
|
|
template_output_dir = output_base / "template_matching"
|
|
template_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Import fish_align_cli function
|
|
try:
|
|
# Add template_matching directory to path
|
|
template_matching_dir = Path(__file__).parent / "template_matching"
|
|
if str(template_matching_dir) not in sys.path:
|
|
sys.path.insert(0, str(template_matching_dir))
|
|
|
|
from fish_align_cli import process_folder_mode
|
|
|
|
# Process all kept point clouds using folder mode
|
|
# This will align all point clouds and use the one with maximum length for weight calculation
|
|
process_folder_mode(
|
|
input_folder=str(output_cloud_folder),
|
|
template_path=template_path,
|
|
template_folder=template_folder,
|
|
output_dir=str(template_output_dir),
|
|
template_scale_factor=template_scale_factor,
|
|
debug=False,
|
|
debug_dir=None
|
|
)
|
|
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
|
|
except Exception as e:
|
|
print(f" ERROR: Failed to run template matching: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
elif run_template_matching and not kept_pointclouds:
|
|
if use_pointcloud_classifier:
|
|
print(f" Warning: Template matching requested but no point clouds passed PointNet++ classifier")
|
|
else:
|
|
print(f" Warning: Template matching requested but no point clouds were saved")
|
|
|
|
print(f"✓ Processed {idx} frames from {svo_path.name}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"ERROR processing {svo_path.name}: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
return False
|
|
finally:
|
|
if zed_reader:
|
|
zed_reader.close()
|
|
|
|
|
|
def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor, sam_device,
                              conf=0.25, imgsz=640, max_frames=0, save_images=False, filter_pointcloud=False,
                              use_clustering_filter=False, use_density_filter=False,
                              pointcloud_classifier=None, use_pointcloud_classifier=False,
                              pointcloud_classifier_threshold=0.7,
                              use_flatness_filter=False, flatness_threshold=90.0,
                              run_template_matching=False, template_path=None, template_folder=None,
                              template_scale_factor=1.0, save_raw_pointclouds=False,
                              correct_tail_rotation=False, tail_rotation_distance_threshold=5.0,
                              tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0):
    """Process every SVO2 recording in a folder with pre-loaded YOLO and SAM models.

    Each recording ``<name>.svo2`` (suffix matched case-insensitively) is passed
    to ``process_single_svo2`` with ``output_base`` as its base directory.  A
    recording is skipped when ``output_base / <name>`` already exists as a
    directory.

    Args:
        svo_folder: Folder containing SVO2 files (``~`` is expanded).
        output_base: Base output directory; created if missing. Per-recording
            subfolders are named after each SVO2 file's stem.
        yolo_model: Pre-loaded YOLO model.
        sam_predictor: Pre-loaded SAM predictor.
        sam_device: SAM device (torch.device).
        conf: YOLO confidence threshold.
        imgsz: YOLO image size.
        max_frames: Maximum frames to process per SVO2 (0 = all).
        save_images: If True, save individual images instead of a video.
        filter_pointcloud: If True, apply outlier filtering to point clouds.
        use_clustering_filter: If True, keep only the largest cluster when filtering.
        use_density_filter: If True, apply density filtering.
        pointcloud_classifier: Loaded classifier model or None.
        use_pointcloud_classifier: If True, use the classifier to filter point clouds.
        pointcloud_classifier_threshold: Confidence threshold for the classifier.
        use_flatness_filter: If True, evaluate flatness before saving.
        flatness_threshold: Minimum flatness score (percent) required to save.
        run_template_matching: If True, run template matching after processing.
        template_path: Template mesh file for template matching.
        template_folder: Folder of template meshes (alternative to ``template_path``).
        template_scale_factor: Scale applied to the template mesh before weight calculation.
        save_raw_pointclouds: If True, also save point clouds before classification.
        correct_tail_rotation: If True, apply tail rotation correction before saving.
        tail_rotation_distance_threshold: Max point-to-plane distance (mm) for tail correction.
        tail_rotation_min_tail_ratio: Minimum outlier ratio on the tail plane to detect rotation.
        tail_rotation_min_angle: Minimum angle (degrees) between planes to apply correction.

        All processing options are forwarded unchanged to ``process_single_svo2``.

    Returns:
        On success: ``{"success": True, "total", "success_count", "skipped_count",
        "failed_count"}``. ``"success"`` is True even if some recordings failed;
        inspect ``failed_count`` for that.
        On early failure: ``{"success": False, "error": <message>}`` when the
        folder is missing or contains no SVO2 files.
    """
    svo_folder = Path(svo_folder).expanduser().resolve()
    if not svo_folder.is_dir():
        print(f"ERROR: SVO folder not found: {svo_folder}")
        return {"success": False, "error": "Folder not found"}

    # Match the suffix case-insensitively (so "clip.SVO2" is found too) and keep
    # only regular files: a directory that happens to be named "x.svo2" is not a
    # recording and would only fail later when opened.
    svo_files = sorted(
        path for path in svo_folder.iterdir()
        if path.is_file() and path.suffix.lower() == ".svo2"
    )
    if not svo_files:
        print(f"ERROR: No SVO2 files found in {svo_folder}")
        return {"success": False, "error": "No SVO2 files found"}

    # Normalise before logging so the printed location is the real absolute path.
    output_base = Path(output_base).expanduser().resolve()

    total = len(svo_files)
    print(f"Found {total} SVO2 file(s) to process")
    print(f"Output base folder: {output_base}")
    print("=" * 60)

    output_base.mkdir(parents=True, exist_ok=True)

    success_count = 0
    skipped_count = 0
    failed_count = 0

    for idx, svo_file in enumerate(svo_files):
        svo_name = svo_file.stem
        output_subfolder = output_base / svo_name

        # NOTE: an existing output folder skips the recording entirely, so
        # process_single_svo2's own handling of pre-existing output (e.g. running
        # template matching on already-saved point clouds) is not reached here.
        if output_subfolder.is_dir():
            print(f"\n{'='*60}")
            print(f"[{idx + 1}/{total}] Skipping: {svo_name}")
            print(f"  Output folder already exists: {output_subfolder}")
            print(f"{'='*60}")
            skipped_count += 1
            continue

        print(f"\n{'='*60}")
        print(f"[{idx + 1}/{total}] Processing: {svo_name}")
        print(f"{'='*60}")

        try:
            success = process_single_svo2(
                svo_path=str(svo_file),
                output_base=str(output_base),
                yolo_model=yolo_model,
                sam_predictor=sam_predictor,
                sam_device=sam_device,
                conf=conf,
                imgsz=imgsz,
                max_frames=max_frames,
                save_images=save_images,
                filter_pointcloud=filter_pointcloud,
                use_clustering_filter=use_clustering_filter,
                use_density_filter=use_density_filter,
                pointcloud_classifier=pointcloud_classifier,
                use_pointcloud_classifier=use_pointcloud_classifier,
                pointcloud_classifier_threshold=pointcloud_classifier_threshold,
                flatness_threshold=flatness_threshold,
                use_flatness_filter=use_flatness_filter,
                run_template_matching=run_template_matching,
                template_path=template_path,
                template_folder=template_folder,
                template_scale_factor=template_scale_factor,
                save_raw_pointclouds=save_raw_pointclouds,
                correct_tail_rotation=correct_tail_rotation,
                tail_rotation_distance_threshold=tail_rotation_distance_threshold,
                tail_rotation_min_tail_ratio=tail_rotation_min_tail_ratio,
                tail_rotation_min_angle=tail_rotation_min_angle
            )
        except Exception as e:
            # One broken recording must not abort the whole batch.
            failed_count += 1
            print(f"✗ Error processing {svo_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

        if success:
            success_count += 1
            print(f"✓ Successfully processed: {svo_name}")
        else:
            failed_count += 1
            print(f"✗ Failed to process: {svo_name}")

    print(f"\n{'='*60}")
    print("Batch Processing Summary")
    print(f"{'='*60}")
    print(f"Total files: {total}")
    print(f"Successfully processed: {success_count}")
    print(f"Skipped (already exists): {skipped_count}")
    print(f"Failed: {failed_count}")
    print(f"Output directory: {output_base}")
    print(f"{'='*60}")

    return {
        "success": True,
        "total": total,
        "success_count": success_count,
        "skipped_count": skipped_count,
        "failed_count": failed_count
    }
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Preview images with YOLO detection and SAM segmentation")
|
|
input_group = parser.add_mutually_exclusive_group()
|
|
input_group.add_argument("--image-folder", type=str, default="", help="Folder containing images")
|
|
input_group.add_argument("--svo", type=str, default="", help="Path to SVO2 file (alternative to --image-folder)")
|
|
input_group.add_argument("--batch-svo-folder", type=str, default="",
|
|
help="Folder containing SVO2 files to process in batch mode. Each SVO2 will be processed and saved to output/<svo_name>/")
|
|
parser.add_argument("--yolo-model",
|
|
default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
|
|
help="YOLO model path")
|
|
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
|
|
parser.add_argument("--imgsz", type=int, default=640, help="Image size")
|
|
parser.add_argument("--scale", type=float, default=1.0, help="Display scale (0.5 = half size)")
|
|
parser.add_argument("--sam-device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
|
|
help="Device for SAM (cuda or cpu)")
|
|
parser.add_argument("--save-output", type=str, default="",
|
|
help="Save preview images to this folder instead of displaying. If empty, display window. For batch mode, this is the base output folder where subfolders will be created for each SVO2 file.")
|
|
parser.add_argument("--save-images", action="store_true",
|
|
help="Save individual image files instead of creating a video (only works with --save-output)")
|
|
parser.add_argument("--filter-pointcloud", action="store_true",
|
|
help="Apply filtering to remove outliers from point clouds (default: no filtering)")
|
|
parser.add_argument("--use-clustering-filter", action="store_true",
|
|
help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
|
|
parser.add_argument("--use-density-filter", action="store_true",
|
|
help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")
|
|
parser.add_argument("--max-frames", type=int, default=0,
|
|
help="Maximum frames to process from SVO2 (0 = all frames)")
|
|
parser.add_argument("--pointcloud-classifier", type=str, default=None,
|
|
help="Path to point cloud quality classifier checkpoint (e.g., log/classification/fish_pointnet2_finetune/checkpoints/best_model.pth)")
|
|
parser.add_argument("--use-pointcloud-classifier", action="store_true",
|
|
help="Use point cloud classifier to filter out bad quality point clouds (requires --pointcloud-classifier)")
|
|
parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7,
|
|
help="Confidence threshold for point cloud classifier (default: 0.7). Only point clouds classified as 'good' with confidence >= threshold will be saved.")
|
|
parser.add_argument("--use-flatness-filter", action="store_true",
|
|
help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
|
|
parser.add_argument("--flatness-threshold", type=float, default=90.0,
|
|
help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")
|
|
parser.add_argument("--run-template-matching", action="store_true",
|
|
help="After processing, run template matching on point clouds that passed PointNet++ classifier")
|
|
parser.add_argument("--template", type=str, default=None,
|
|
help="Path to template mesh file for template matching (required if --run-template-matching is set)")
|
|
parser.add_argument("--template-folder", type=str, default=None,
|
|
help="Folder containing multiple template meshes. Will try all templates and select the one with best alignment (alternative to --template)")
|
|
parser.add_argument("--template-scale-factor", type=float, default=1.0,
|
|
help="Scale factor applied to template mesh before weight calculation (default: 1.0)")
|
|
parser.add_argument("--save-raw-pointclouds", action="store_true",
|
|
help="Save point clouds to raw_pc folder before passing to PointNet++ classifier. Useful for debugging why some point clouds fail classification.")
|
|
parser.add_argument("--correct-tail-rotation", action="store_true",
|
|
help="Correct fish tail rotation using dual RANSAC plane fitting. This should be applied before template matching as it affects height/length ratio.")
|
|
parser.add_argument("--tail-rotation-distance-threshold", type=float, default=5.0,
|
|
help="Maximum distance from point to plane for tail rotation correction (mm, default: 5.0)")
|
|
parser.add_argument("--tail-rotation-min-tail-ratio", type=float, default=0.7,
|
|
help="Minimum ratio of outliers on tail plane to detect rotation (default: 0.7)")
|
|
parser.add_argument("--tail-rotation-min-angle", type=float, default=5.0,
|
|
help="Minimum angle (degrees) between planes to apply tail rotation correction (default: 5.0)")
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Validate input source
|
|
input_count = sum([bool(args.image_folder), bool(args.svo), bool(args.batch_svo_folder)])
|
|
if input_count == 0:
|
|
parser.error("One of --image-folder, --svo, or --batch-svo-folder must be provided")
|
|
if input_count > 1:
|
|
parser.error("Cannot specify multiple input sources. Use only one of --image-folder, --svo, or --batch-svo-folder")
|
|
|
|
# Determine input source
|
|
use_batch_mode = bool(args.batch_svo_folder)
|
|
use_svo = bool(args.svo) and not use_batch_mode
|
|
|
|
# Handle batch mode first
|
|
if use_batch_mode:
|
|
if not ZED_AVAILABLE:
|
|
print("ERROR: pyzed not available. Cannot read SVO2 files.")
|
|
return
|
|
|
|
svo_folder = Path(args.batch_svo_folder).expanduser().resolve()
|
|
if not svo_folder.exists() or not svo_folder.is_dir():
|
|
print(f"ERROR: SVO folder not found: {svo_folder}")
|
|
return
|
|
|
|
# Load YOLO
|
|
print(f"Loading YOLO: {args.yolo_model}")
|
|
yolo_model = YOLO(args.yolo_model)
|
|
print(f"✓ YOLO loaded")
|
|
|
|
# Load SAM
|
|
print(f"Loading SAM (device: {args.sam_device})...")
|
|
sam_predictor = init_models(device=args.sam_device, seg_model="sam")
|
|
sam_device_obj = torch.device(args.sam_device)
|
|
print(f"✓ SAM loaded")
|
|
|
|
# Load point cloud classifier if specified
|
|
pointcloud_classifier = None
|
|
if args.use_pointcloud_classifier:
|
|
classifier_path = args.pointcloud_classifier
|
|
if classifier_path is None:
|
|
default_path = Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
|
|
if default_path.exists():
|
|
classifier_path = str(default_path)
|
|
print(f"Using default classifier path: {classifier_path}")
|
|
else:
|
|
print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
|
|
args.use_pointcloud_classifier = False
|
|
|
|
if args.use_pointcloud_classifier and classifier_path:
|
|
print(f"Loading point cloud classifier from: {classifier_path}")
|
|
pointcloud_classifier = load_pointcloud_classifier(
|
|
classifier_path,
|
|
num_classes=2,
|
|
use_normals=False,
|
|
device=args.sam_device
|
|
)
|
|
if pointcloud_classifier is None:
|
|
print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
|
|
args.use_pointcloud_classifier = False
|
|
else:
|
|
print(f"✓ Point cloud classifier loaded")
|
|
|
|
# Validate template matching arguments
|
|
if args.run_template_matching:
|
|
if not args.template and not args.template_folder:
|
|
parser.error("--run-template-matching requires either --template or --template-folder to be specified")
|
|
|
|
# Process batch
|
|
output_base = args.save_output if args.save_output else "output_preview"
|
|
process_batch_svo2_folder(
|
|
svo_folder=str(svo_folder),
|
|
output_base=output_base,
|
|
yolo_model=yolo_model,
|
|
sam_predictor=sam_predictor,
|
|
sam_device=sam_device_obj,
|
|
conf=args.conf,
|
|
imgsz=args.imgsz,
|
|
max_frames=args.max_frames,
|
|
save_images=args.save_images,
|
|
filter_pointcloud=args.filter_pointcloud,
|
|
use_clustering_filter=args.use_clustering_filter,
|
|
use_density_filter=args.use_density_filter,
|
|
pointcloud_classifier=pointcloud_classifier,
|
|
use_pointcloud_classifier=args.use_pointcloud_classifier,
|
|
pointcloud_classifier_threshold=args.pointcloud_classifier_threshold,
|
|
use_flatness_filter=args.use_flatness_filter,
|
|
flatness_threshold=args.flatness_threshold,
|
|
run_template_matching=args.run_template_matching,
|
|
template_path=args.template,
|
|
template_folder=args.template_folder,
|
|
template_scale_factor=args.template_scale_factor,
|
|
save_raw_pointclouds=args.save_raw_pointclouds,
|
|
correct_tail_rotation=args.correct_tail_rotation,
|
|
tail_rotation_distance_threshold=args.tail_rotation_distance_threshold,
|
|
tail_rotation_min_tail_ratio=args.tail_rotation_min_tail_ratio,
|
|
tail_rotation_min_angle=args.tail_rotation_min_angle
|
|
)
|
|
return
|
|
|
|
if use_svo:
|
|
if not ZED_AVAILABLE:
|
|
print("ERROR: pyzed not available. Cannot read SVO2 files.")
|
|
return
|
|
svo_path = Path(args.svo).expanduser().resolve()
|
|
if not svo_path.exists():
|
|
print(f"ERROR: SVO2 file not found: {svo_path}")
|
|
return
|
|
print(f"Reading from SVO2 file: {svo_path}")
|
|
else:
|
|
image_folder = Path(args.image_folder)
|
|
if not image_folder.is_dir():
|
|
print(f"ERROR: Folder not found: {image_folder}")
|
|
return
|
|
|
|
image_files = []
|
|
for ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
|
image_files.extend(sorted(image_folder.glob(f'*{ext}')))
|
|
image_files.extend(sorted(image_folder.glob(f'*{ext.upper()}')))
|
|
|
|
if not image_files:
|
|
print(f"ERROR: No images found in {image_folder}")
|
|
return
|
|
|
|
print(f"Found {len(image_files)} images")
|
|
|
|
# Load YOLO
|
|
print(f"Loading YOLO: {args.yolo_model}")
|
|
yolo_model = YOLO(args.yolo_model)
|
|
class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
|
|
print(f"✓ YOLO loaded. Classes: {class_names}")
|
|
|
|
# Load SAM
|
|
print(f"Loading SAM (device: {args.sam_device})...")
|
|
sam_predictor = init_models(device=args.sam_device, seg_model="sam")
|
|
sam_device = torch.device(args.sam_device)
|
|
print(f"✓ SAM loaded")
|
|
|
|
# Load point cloud classifier if specified
|
|
pointcloud_classifier = None
|
|
if args.use_pointcloud_classifier:
|
|
# Determine classifier path
|
|
classifier_path = args.pointcloud_classifier
|
|
if classifier_path is None:
|
|
# Try default path
|
|
default_path = Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
|
|
if default_path.exists():
|
|
classifier_path = str(default_path)
|
|
print(f"Using default classifier path: {classifier_path}")
|
|
else:
|
|
print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
|
|
print(f" Default path checked: {default_path}")
|
|
args.use_pointcloud_classifier = False
|
|
|
|
# Load classifier if path is available
|
|
if classifier_path is not None:
|
|
print(f"Loading point cloud classifier from: {classifier_path}")
|
|
pointcloud_classifier = load_pointcloud_classifier(
|
|
classifier_path,
|
|
num_classes=2,
|
|
use_normals=False,
|
|
device=args.sam_device
|
|
)
|
|
if pointcloud_classifier is None:
|
|
print("ERROR: Failed to load point cloud classifier from:", classifier_path)
|
|
print(" Continuing without quality filtering. Check the path and model file.")
|
|
args.use_pointcloud_classifier = False
|
|
else:
|
|
print(f"✓ Point cloud classifier loaded successfully from: {classifier_path}")
|
|
print(f" Confidence threshold: {args.pointcloud_classifier_threshold}")
|
|
print(f" Only point clouds classified as 'good' with confidence >= {args.pointcloud_classifier_threshold} will be saved")
|
|
|
|
# Setup output folders if saving
|
|
save_mode = bool(args.save_output)
|
|
output_images_folder = None
|
|
output_cloud_folder = None
|
|
video_frames = [] # Store frames for video creation
|
|
video_writer = None
|
|
if save_mode:
|
|
output_base = Path(args.save_output)
|
|
|
|
# Create subfolder based on input source name
|
|
if use_svo:
|
|
# Extract filename without extension from SVO2 path
|
|
svo_name = svo_path.stem # e.g., "HD1080_SN43186771_16-41-37" from "HD1080_SN43186771_16-41-37.svo2"
|
|
output_base = output_base / svo_name
|
|
else:
|
|
# For image folders, use folder name
|
|
folder_name = image_folder.name
|
|
output_base = output_base / folder_name
|
|
|
|
output_images_folder = output_base / "images"
|
|
output_cloud_folder = output_base / "cloud"
|
|
output_raw_pc_folder = output_base / "raw_pc" if args.save_raw_pointclouds else None
|
|
output_images_folder.mkdir(parents=True, exist_ok=True)
|
|
output_cloud_folder.mkdir(parents=True, exist_ok=True)
|
|
if args.save_raw_pointclouds and output_raw_pc_folder:
|
|
output_raw_pc_folder.mkdir(parents=True, exist_ok=True)
|
|
print(f"✓ Output base folder: {output_base.resolve()}")
|
|
print(f"✓ Output images folder (for video): {output_images_folder.resolve()}")
|
|
print(f"✓ Output cloud folder: {output_cloud_folder.resolve()}")
|
|
if args.save_raw_pointclouds:
|
|
print(f"✓ Output raw point cloud folder (before classifier): {output_raw_pc_folder.resolve()}")
|
|
|
|
# Check if output folder already exists and contains point clouds
|
|
# If so, skip data generation and directly run template matching
|
|
if output_base.exists() and output_cloud_folder.exists():
|
|
# Check if there are point cloud files
|
|
point_cloud_files = list(output_cloud_folder.glob("*.ply"))
|
|
if point_cloud_files and args.run_template_matching:
|
|
print(f"\n{'='*60}")
|
|
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
|
|
print(f"Skipping data generation, directly running template matching...")
|
|
print(f"{'='*60}")
|
|
|
|
# Load kept point clouds from file if it exists, otherwise use all point clouds
|
|
kept_pointclouds = []
|
|
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
|
|
if pointcloud_list_path.exists():
|
|
print(f" Loading point cloud list from: {pointcloud_list_path.name}")
|
|
with open(pointcloud_list_path, 'r', encoding='utf-8') as f:
|
|
kept_pointclouds = [line.strip() for line in f if line.strip()]
|
|
print(f" Found {len(kept_pointclouds)} point clouds in list")
|
|
else:
|
|
# If no list file, use all point clouds in the folder
|
|
kept_pointclouds = [str(f) for f in point_cloud_files]
|
|
print(f" No point cloud list file found, using all {len(kept_pointclouds)} point clouds")
|
|
|
|
if kept_pointclouds:
|
|
# Validate template matching arguments
|
|
if not args.template and not args.template_folder:
|
|
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
|
|
return
|
|
|
|
# Run template matching directly
|
|
print(f"\n{'='*60}")
|
|
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
|
|
print(f"{'='*60}")
|
|
|
|
# Create output directory for template matching
|
|
template_output_dir = output_base / "template_matching"
|
|
template_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Import fish_align_cli function
|
|
try:
|
|
# Add template_matching directory to path
|
|
template_matching_dir = Path(__file__).parent / "template_matching"
|
|
if str(template_matching_dir) not in sys.path:
|
|
sys.path.insert(0, str(template_matching_dir))
|
|
|
|
from fish_align_cli import process_folder_mode
|
|
|
|
# Process all point clouds using folder mode
|
|
process_folder_mode(
|
|
input_folder=str(output_cloud_folder),
|
|
template_path=args.template,
|
|
template_folder=args.template_folder,
|
|
output_dir=str(template_output_dir),
|
|
template_scale_factor=args.template_scale_factor,
|
|
debug=False,
|
|
debug_dir=None
|
|
)
|
|
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
|
|
return
|
|
except Exception as e:
|
|
print(f" ERROR: Failed to run template matching: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
return
|
|
else:
|
|
print(f" Warning: No point clouds found to process")
|
|
return
|
|
elif point_cloud_files and not args.run_template_matching:
|
|
print(f"\n{'='*60}")
|
|
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
|
|
print(f"Template matching not requested (--run-template-matching not set)")
|
|
print(f"Skipping processing...")
|
|
print(f"{'='*60}")
|
|
return
|
|
# If folder exists but no point clouds, continue with normal processing
|
|
else:
|
|
window_name = "Fish Detection & Segmentation Preview"
|
|
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
|
|
|
|
print("\n" + "="*60)
|
|
if save_mode:
|
|
print("Image Processing Started (Saving to disk)")
|
|
else:
|
|
print("Image Preview Started")
|
|
print("Controls:")
|
|
print(" Press any key - Next image")
|
|
print(" Press 'q' - Quit")
|
|
print("="*60 + "\n")
|
|
|
|
# Process frames
|
|
idx = 0
|
|
zed_reader = None
|
|
runtime_params = None
|
|
left_image_mat = None
|
|
depth_mat = None
|
|
camera_intrinsics = None
|
|
|
|
# Tracking for consecutive frames with fish
|
|
previous_boxes = None
|
|
fish_tracks: Dict[int, List[Dict[str, Any]]] = {} # track_id -> list of frame data
|
|
track_bbox_history: Dict[int, List[List[float]]] = {} # track_id -> list of bboxes
|
|
track_stationary_count: Dict[int, int] = {} # track_id -> consecutive stationary frames
|
|
next_track_id = 0
|
|
STATIONARY_THRESHOLD = 10 # Number of consecutive frames without movement
|
|
MOVEMENT_THRESHOLD = 5.0 # Minimum pixel movement to consider as moved
|
|
|
|
# List to track point clouds that passed PointNet++ classifier (if enabled)
|
|
kept_pointclouds = []
|
|
|
|
if use_svo:
|
|
# Initialize ZED reader for SVO2
|
|
zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
|
|
if not zed_reader.open():
|
|
print("ERROR: Failed to open SVO2 file")
|
|
return
|
|
|
|
# Get camera intrinsics from ZED
|
|
calib_params = zed_reader.zed.get_camera_information().camera_configuration.calibration_parameters
|
|
camera_intrinsics = {
|
|
"fx": float(calib_params.left_cam.fx),
|
|
"fy": float(calib_params.left_cam.fy),
|
|
"cx": float(calib_params.left_cam.cx),
|
|
"cy": float(calib_params.left_cam.cy),
|
|
}
|
|
print(f"✓ SVO2 file opened. Camera intrinsics: fx={camera_intrinsics['fx']:.1f}, fy={camera_intrinsics['fy']:.1f}")
|
|
|
|
runtime_params = sl.RuntimeParameters()
|
|
left_image_mat = sl.Mat()
|
|
depth_mat = sl.Mat()
|
|
else:
|
|
# For image folder, use default/estimated intrinsics
|
|
# These are typical values for HD1080 ZED camera
|
|
camera_intrinsics = {
|
|
"fx": 1066.8, # Approximate for HD1080
|
|
"fy": 1066.8,
|
|
"cx": 960.0,
|
|
"cy": 540.0,
|
|
}
|
|
print(f"Using default camera intrinsics: fx={camera_intrinsics['fx']:.1f}, fy={camera_intrinsics['fy']:.1f}")
|
|
|
|
while True:
|
|
if use_svo:
|
|
# Read frame from SVO2
|
|
if args.max_frames > 0 and idx >= args.max_frames:
|
|
break
|
|
|
|
err = zed_reader.zed.grab(runtime_params)
|
|
if err != sl.ERROR_CODE.SUCCESS:
|
|
print(f"\nEnd of SVO2 file or error: {err}")
|
|
break
|
|
|
|
# Retrieve images
|
|
zed_reader.zed.retrieve_image(left_image_mat, sl.VIEW.LEFT)
|
|
zed_reader.zed.retrieve_measure(depth_mat, sl.MEASURE.DEPTH)
|
|
|
|
# Convert to numpy
|
|
left_np = left_image_mat.get_data()
|
|
depth_np = depth_mat.get_data()
|
|
|
|
# Ensure BGR format (drop alpha if present)
|
|
if left_np.shape[2] > 3:
|
|
img = left_np[:, :, :3].copy()
|
|
else:
|
|
img = left_np.copy()
|
|
|
|
# Convert depth to float32 if needed
|
|
if depth_np.dtype != np.float32:
|
|
depth_data = depth_np.astype(np.float32)
|
|
else:
|
|
depth_data = depth_np.copy()
|
|
|
|
frame_name = f"frame_{idx+1:06d}"
|
|
print(f"[{idx + 1}] {frame_name}")
|
|
else:
|
|
# Read from image folder
|
|
if idx >= len(image_files):
|
|
break
|
|
|
|
img_path = image_files[idx]
|
|
print(f"[{idx + 1}/{len(image_files)}] {img_path.name}")
|
|
|
|
# Load image
|
|
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
|
|
if img is None:
|
|
print(f" ERROR: Failed to load")
|
|
idx += 1
|
|
continue
|
|
|
|
# Try to find corresponding depth file
|
|
depth_data = None
|
|
# Look for .npy depth files (common format from ZED)
|
|
depth_candidates = [
|
|
img_path.parent / f"{img_path.stem}_depth.npy",
|
|
img_path.parent / f"{img_path.stem.replace('_rgb', '_depth')}.npy",
|
|
img_path.parent / f"{img_path.stem.replace('rgb', 'depth')}.npy",
|
|
]
|
|
for depth_npy_path in depth_candidates:
|
|
if depth_npy_path.exists():
|
|
try:
|
|
depth_data = np.load(str(depth_npy_path))
|
|
print(f" Found depth file: {depth_npy_path.name}")
|
|
break
|
|
except Exception as e:
|
|
print(f" Warning: Could not load depth file {depth_npy_path.name}: {e}")
|
|
|
|
frame_name = img_path.stem
|
|
|
|
# Run YOLO tracking
|
|
results = yolo_model.track(img, conf=args.conf, imgsz=args.imgsz, verbose=False, persist=True)[0]
|
|
num_dets = len(results.boxes) if results.boxes is not None else 0
|
|
print(f" Detections: {num_dets}")
|
|
|
|
# Right side: Segmentation overlay and tracking
|
|
individual_masks = []
|
|
current_boxes = None
|
|
track_ids = []
|
|
depth_stats_list = [] # Store depth stats for each detection
|
|
|
|
try:
|
|
if num_dets > 0:
|
|
boxes_xyxy = results.boxes.xyxy.cpu().numpy()
|
|
current_boxes = boxes_xyxy.tolist()
|
|
|
|
# Get track IDs from YOLO tracking results
|
|
if hasattr(results.boxes, 'id') and results.boxes.id is not None:
|
|
track_ids = results.boxes.id.cpu().numpy().astype(int).tolist()
|
|
else:
|
|
# Fallback: assign sequential IDs if tracking not available
|
|
track_ids = list(range(next_track_id, next_track_id + len(current_boxes)))
|
|
next_track_id += len(current_boxes)
|
|
|
|
# Check for stationary fish and filter them out
|
|
active_detections = []
|
|
active_masks = []
|
|
active_track_ids = []
|
|
active_boxes = []
|
|
|
|
for fish_idx, (box, track_id) in enumerate(zip(current_boxes, track_ids)):
|
|
# Initialize tracking history if new track
|
|
if track_id not in track_bbox_history:
|
|
track_bbox_history[track_id] = []
|
|
track_stationary_count[track_id] = 0
|
|
|
|
# Check if bbox is stationary
|
|
is_stationary = False
|
|
if len(track_bbox_history[track_id]) > 0:
|
|
is_stationary = is_bbox_stationary(box, track_bbox_history[track_id], MOVEMENT_THRESHOLD)
|
|
|
|
if is_stationary:
|
|
track_stationary_count[track_id] += 1
|
|
if track_stationary_count[track_id] >= STATIONARY_THRESHOLD:
|
|
print(f" Fish {fish_idx+1} (Track {track_id}): STATIONARY for {track_stationary_count[track_id]} frames - FILTERED OUT")
|
|
continue # Skip this detection
|
|
else:
|
|
print(f" Fish {fish_idx+1} (Track {track_id}): stationary for {track_stationary_count[track_id]} frames (will filter at {STATIONARY_THRESHOLD})")
|
|
else:
|
|
# Fish moved, reset stationary count
|
|
track_stationary_count[track_id] = 0
|
|
|
|
# Add to active detections
|
|
active_detections.append(fish_idx)
|
|
active_boxes.append(box)
|
|
active_track_ids.append(track_id)
|
|
|
|
# Update bbox history (keep last 10 for checking)
|
|
track_bbox_history[track_id].append(box)
|
|
if len(track_bbox_history[track_id]) > 10:
|
|
track_bbox_history[track_id].pop(0)
|
|
|
|
# Only process active (moving) detections
|
|
# Initialize depth_stats_list for all detections (will be filled for active ones)
|
|
depth_stats_list = [None] * len(current_boxes) if current_boxes else []
|
|
|
|
if len(active_detections) == 0:
|
|
print(f" All {num_dets} detection(s) are stationary - skipping frame")
|
|
num_dets = 0 # Treat as no detections
|
|
individual_masks = []
|
|
previous_boxes = None # Reset tracking
|
|
else:
|
|
print(f" Processing {len(active_detections)} active detection(s) (filtered {num_dets - len(active_detections)} stationary)")
|
|
# Get masks only for active detections
|
|
active_boxes_array = np.array(active_boxes)
|
|
print(f" Running SAM segmentation...")
|
|
all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device)
|
|
|
|
# Map masks to active detections
|
|
individual_masks = all_masks if all_masks else []
|
|
|
|
# Calculate depth statistics for each active fish and map to original detection order
|
|
if individual_masks and len(individual_masks) > 0 and depth_data is not None:
|
|
for mask_idx, (mask, track_id, box) in enumerate(zip(individual_masks, active_track_ids, active_boxes)):
|
|
depth_stats = calculate_fish_depth_stats(depth_data, mask)
|
|
if depth_stats:
|
|
# Map to original detection index
|
|
original_idx = active_detections[mask_idx]
|
|
depth_stats_list[original_idx] = depth_stats
|
|
|
|
frame_data = {
|
|
"frame_index": idx + 1,
|
|
"frame_name": frame_name,
|
|
"fish_index": active_detections[mask_idx],
|
|
"track_id": int(track_id),
|
|
"bbox": box,
|
|
"depth_stats": depth_stats
|
|
}
|
|
|
|
if track_id not in fish_tracks:
|
|
fish_tracks[track_id] = []
|
|
fish_tracks[track_id].append(frame_data)
|
|
|
|
print(f" Fish {active_detections[mask_idx]+1} (Track {track_id}): "
|
|
f"mean_depth={depth_stats['mean_depth_mm']:.1f}mm, "
|
|
f"std={depth_stats['std_depth_mm']:.1f}mm")
|
|
|
|
# Update previous boxes for next frame (only active ones)
|
|
previous_boxes = active_boxes
|
|
current_boxes = active_boxes # Update for saving
|
|
num_dets = len(active_detections) # Update count
|
|
|
|
if individual_masks and len(individual_masks) > 0:
|
|
right_display = create_segmentation_overlay(img.copy(), individual_masks)
|
|
cv2.putText(right_display, "Segmentation", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
|
else:
|
|
right_display = img.copy()
|
|
cv2.putText(right_display, "Segmentation (failed)", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
|
|
else:
|
|
right_display = img.copy()
|
|
cv2.putText(right_display, "No detections", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)
|
|
# Reset tracking when no detections
|
|
previous_boxes = None
|
|
depth_stats_list = []
|
|
except Exception as e:
|
|
print(f" ERROR in SAM segmentation: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
right_display = img.copy()
|
|
cv2.putText(right_display, f"Error: {str(e)[:30]}", (10, right_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2, cv2.LINE_AA)
|
|
previous_boxes = None
|
|
depth_stats_list = []
|
|
|
|
# Left side: Original image with detection boxes and depth info
|
|
left_display = draw_detections(img.copy(), results, class_names, depth_stats_list)
|
|
if use_svo:
|
|
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
|
|
else:
|
|
info = f"[{idx + 1}/{len(image_files)}] {img_path.name} | Detections: {num_dets}"
|
|
cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
|
|
cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
|
|
|
# Combine left and right side by side
|
|
combined_display = np.hstack([left_display, right_display])
|
|
|
|
# Scale if needed
|
|
if args.scale != 1.0:
|
|
h, w = combined_display.shape[:2]
|
|
combined_display = cv2.resize(combined_display, (int(w * args.scale), int(h * args.scale)))
|
|
|
|
# Save or display
|
|
if save_mode:
|
|
# Save image or collect for video only when fish detected
|
|
if num_dets > 0:
|
|
if args.save_images:
|
|
# Save individual image file
|
|
if use_svo:
|
|
image_path = output_images_folder / f"{frame_name}_detection.png"
|
|
else:
|
|
image_path = output_images_folder / f"{img_path.stem}_detection.png"
|
|
cv2.imwrite(str(image_path), combined_display)
|
|
print(f" Saved image: {image_path.name}")
|
|
else:
|
|
# Collect for video
|
|
video_frames.append(combined_display.copy())
|
|
print(f" Added frame {len(video_frames)} to video")
|
|
|
|
# Generate and save point clouds for each detected fish
|
|
if num_dets > 0 and depth_data is not None and individual_masks and len(individual_masks) > 0:
|
|
print(f" Generating point clouds for {len(individual_masks)} fish...")
|
|
for fish_idx, mask in enumerate(individual_masks):
|
|
# Verify mask is being used correctly
|
|
if mask is None:
|
|
print(f" Warning: Mask {fish_idx + 1} is None, skipping point cloud generation")
|
|
continue
|
|
|
|
# Check mask statistics
|
|
mask_pixels = np.sum(mask.astype(bool))
|
|
if mask_pixels == 0:
|
|
print(f" Warning: Mask {fish_idx + 1} is empty, skipping point cloud generation")
|
|
continue
|
|
|
|
# Validate depth values in mask before generating point cloud
|
|
depth_valid, reason, depth_stats = validate_depth_in_mask(
|
|
depth_data, mask, max_depth_mm=1200.0, min_depth_mm=100.0
|
|
)
|
|
if not depth_valid:
|
|
print(f" Skipped point cloud {fish_idx + 1}: depth validation failed - {reason}")
|
|
continue
|
|
|
|
points, colors = depth_mask_to_pointcloud(img, depth_data, mask, camera_intrinsics)
|
|
if points is not None and len(points) > 0:
|
|
# Apply filtering if enabled
|
|
original_count = len(points)
|
|
if args.filter_pointcloud:
|
|
points, colors = filter_point_cloud(
|
|
points, colors,
|
|
use_clustering_filter=args.use_clustering_filter,
|
|
use_density_filter=args.use_density_filter
|
|
)
|
|
filtered_count = len(points)
|
|
|
|
# Skip if less than 500 points
|
|
if filtered_count < 500:
|
|
print(f" Skipped point cloud {fish_idx + 1}: {original_count} -> {filtered_count} points (minimum: 500)")
|
|
continue
|
|
elif filtered_count == 0:
|
|
print(f" Warning: All points filtered out for fish {fish_idx + 1} (original: {original_count} points)")
|
|
continue
|
|
else:
|
|
# Check minimum point count
|
|
if len(points) < 500:
|
|
print(f" Skipped point cloud {fish_idx + 1}: only {len(points)} points (minimum: 500)")
|
|
continue
|
|
|
|
# Apply largest cluster filtering before classification (if classifier is enabled)
|
|
# This removes outliers and keeps only the main fish body cluster
|
|
if args.use_pointcloud_classifier:
|
|
cluster_before_count = len(points)
|
|
try:
|
|
points, colors, cluster_info = keep_largest_cluster_with_colors(
|
|
points, colors, eps=10.0, min_points=30
|
|
)
|
|
cluster_after_count = len(points)
|
|
|
|
if "error" in cluster_info:
|
|
print(f" Point cloud {fish_idx + 1}: Clustering failed - {cluster_info.get('error', 'Unknown error')}")
|
|
continue
|
|
|
|
if cluster_after_count < 500:
|
|
print(f" Point cloud {fish_idx + 1}: After clustering: {cluster_before_count} -> {cluster_after_count} points (minimum: 500) - SKIPPED")
|
|
continue
|
|
|
|
if cluster_after_count < cluster_before_count:
|
|
print(f" Point cloud {fish_idx + 1}: Clustering removed {cluster_before_count - cluster_after_count} points "
|
|
f"({cluster_info.get('num_clusters', 0)} clusters found, kept largest with {cluster_info.get('largest_cluster_size', 0)} points)")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Largest cluster filtering failed: {e}")
|
|
# Continue with original points if clustering fails
|
|
|
|
# Save raw point cloud before classification (if flag is enabled)
|
|
if args.save_raw_pointclouds and output_raw_pc_folder:
|
|
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
|
|
if use_svo:
|
|
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
|
|
else:
|
|
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{img_path.stem}{postfix}.ply"
|
|
write_ply_file(raw_ply_path, points, colors)
|
|
print(f" Saved raw point cloud {fish_idx + 1} (before classifier): {raw_ply_path.name} ({len(points)} points)")
|
|
|
|
# Classify point cloud quality if classifier is available
|
|
# mport pdb; pdb.set_trace()
|
|
if args.use_pointcloud_classifier:
|
|
if pointcloud_classifier is not None:
|
|
# Classifier is available - perform classification
|
|
is_good, confidence, class_result = classify_pointcloud_array(
|
|
pointcloud_classifier, points, colors, num_point=1024, vote_num=3,
|
|
confidence_threshold=args.pointcloud_classifier_threshold,
|
|
use_cpu=(args.sam_device == "cpu")
|
|
)
|
|
|
|
# Always print the prediction result
|
|
prediction = class_result.get("prediction", "unknown")
|
|
class_id = class_result.get("class_id", -1)
|
|
|
|
if not is_good:
|
|
# Point cloud is bad - don't save it
|
|
if class_id == 1:
|
|
# Predicted as good but confidence too low
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (confidence {confidence:.3f} < threshold {args.pointcloud_classifier_threshold:.3f})")
|
|
else:
|
|
# Predicted as bad
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (BAD quality)")
|
|
continue
|
|
else:
|
|
# Point cloud is good - proceed to save
|
|
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SAVING (GOOD quality, confidence >= {args.pointcloud_classifier_threshold:.3f})")
|
|
else:
|
|
# Classifier requested but not available
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Classifier requested but not loaded, saving without classification")
|
|
else:
|
|
# Classifier not enabled - this should not happen if --use-pointcloud-classifier was used
|
|
# This means either the flag was not set, or classifier loading failed
|
|
if args.use_pointcloud_classifier:
|
|
print(f" Point cloud {fish_idx + 1}: ERROR - Classifier flag was set but classifier is None. Check startup logs for loading errors.")
|
|
else:
|
|
print(f" Point cloud {fish_idx + 1}: Classifier not enabled (use --use-pointcloud-classifier to enable)")
|
|
|
|
# Evaluate flatness if enabled
|
|
if args.use_flatness_filter:
|
|
try:
|
|
flatness_score, flatness_info = evaluate_flatness_ransac(
|
|
points,
|
|
distance_threshold=5.0,
|
|
ransac_n=3,
|
|
num_iterations=1000
|
|
)
|
|
|
|
if flatness_score < args.flatness_threshold:
|
|
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% < threshold {args.flatness_threshold:.2f}% - SKIPPED (not flat enough)")
|
|
continue
|
|
else:
|
|
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% >= threshold {args.flatness_threshold:.2f}% - PASSED")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Flatness evaluation failed: {e}")
|
|
# If flatness check is required and fails, skip saving
|
|
if args.use_flatness_filter:
|
|
print(f" Point cloud {fish_idx + 1}: Skipping due to flatness evaluation error")
|
|
continue
|
|
|
|
# Correct tail rotation if enabled (before saving, after all checks)
|
|
if args.correct_tail_rotation:
|
|
try:
|
|
tail_correction_before_count = len(points)
|
|
points, colors, tail_correction_applied, tail_correction_info = correct_tail_rotation_array(
|
|
points, colors,
|
|
distance_threshold=args.tail_rotation_distance_threshold,
|
|
min_tail_ratio=args.tail_rotation_min_tail_ratio,
|
|
min_angle_threshold=args.tail_rotation_min_angle,
|
|
verbose=True
|
|
)
|
|
if tail_correction_applied:
|
|
print(f" Point cloud {fish_idx + 1}: Tail rotation corrected (angle: {tail_correction_info.get('rotation_angle_degrees', 0):.2f}°)")
|
|
except Exception as e:
|
|
print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}")
|
|
# Continue with original points if correction fails
|
|
|
|
# Save point cloud (passed all checks)
|
|
filtered_count = len(points)
|
|
# Create filename with postfix for multiple fish
|
|
if len(individual_masks) > 1:
|
|
postfix = f"_{fish_idx + 1}"
|
|
else:
|
|
postfix = ""
|
|
|
|
if use_svo:
|
|
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
|
|
else:
|
|
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{img_path.stem}{postfix}.ply"
|
|
|
|
write_ply_file(ply_path, points, colors)
|
|
|
|
# Track point clouds that passed PointNet++ classifier (if enabled)
|
|
# If classifier is enabled, only track those that passed classifier
|
|
# If classifier is not enabled, track all saved point clouds
|
|
if args.use_pointcloud_classifier and pointcloud_classifier is not None:
|
|
# This point cloud passed all checks including classifier
|
|
kept_pointclouds.append(str(ply_path))
|
|
elif not args.use_pointcloud_classifier:
|
|
# Classifier not enabled, track all saved point clouds
|
|
kept_pointclouds.append(str(ply_path))
|
|
|
|
if args.filter_pointcloud:
|
|
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({original_count} -> {filtered_count} points)")
|
|
else:
|
|
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({filtered_count} points)")
|
|
else:
|
|
print(f" Warning: No valid points for fish {fish_idx + 1}")
|
|
|
|
idx += 1
|
|
else:
|
|
# Display window
|
|
print(f" Displaying window...")
|
|
try:
|
|
cv2.imshow(window_name, combined_display)
|
|
print(f" Window displayed. Waiting for keypress...")
|
|
except Exception as e:
|
|
print(f" ERROR displaying window: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
idx += 1
|
|
continue
|
|
|
|
# Wait for keypress
|
|
key = cv2.waitKey(0) & 0xFF
|
|
|
|
if key == ord('q'):
|
|
print("Quit")
|
|
break
|
|
|
|
# Any other key - next image
|
|
idx += 1
|
|
|
|
if not save_mode:
|
|
cv2.destroyAllWindows()
|
|
|
|
if use_svo and zed_reader:
|
|
zed_reader.close()
|
|
|
|
if save_mode:
|
|
total_frames = idx
|
|
output_base_path = output_images_folder.parent
|
|
|
|
# Create video from collected frames and save in images folder (only if not saving individual images)
|
|
if not args.save_images and video_frames:
|
|
if use_svo:
|
|
video_path = output_images_folder / f"{svo_name}_preview.mp4"
|
|
else:
|
|
video_path = output_images_folder / f"{folder_name}_preview.mp4"
|
|
|
|
print(f"\nCreating video from {len(video_frames)} frames...")
|
|
if len(video_frames) > 0:
|
|
h, w = video_frames[0].shape[:2]
|
|
fps = 10.0
|
|
video_writer = _open_video_writer(video_path, fps, (w, h))
|
|
|
|
for frame in video_frames:
|
|
video_writer.write(frame)
|
|
|
|
video_writer.release()
|
|
print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames, {fps} fps)")
|
|
elif args.save_images:
|
|
print(f"\n✓ Saved {total_frames} frames as individual images")
|
|
|
|
# Save fish tracking and depth statistics
|
|
if fish_tracks:
|
|
stats_path = output_base_path / "fish_depth_tracking.json"
|
|
tracking_data = {
|
|
"source": str(svo_path) if use_svo else str(image_folder),
|
|
"total_frames_processed": total_frames,
|
|
"total_tracks": len(fish_tracks),
|
|
"tracks": {}
|
|
}
|
|
|
|
for track_id, frames in fish_tracks.items():
|
|
tracking_data["tracks"][str(track_id)] = {
|
|
"track_id": track_id,
|
|
"num_frames": len(frames),
|
|
"frames": frames,
|
|
"depth_summary": {
|
|
"mean_depth_overall_mm": float(np.mean([f["depth_stats"]["mean_depth_mm"] for f in frames])),
|
|
"std_depth_overall_mm": float(np.std([f["depth_stats"]["mean_depth_mm"] for f in frames])),
|
|
"min_depth_mm": float(np.min([f["depth_stats"]["min_depth_mm"] for f in frames])),
|
|
"max_depth_mm": float(np.max([f["depth_stats"]["max_depth_mm"] for f in frames])),
|
|
}
|
|
}
|
|
|
|
with open(stats_path, 'w', encoding='utf-8') as f:
|
|
json.dump(tracking_data, f, indent=2)
|
|
print(f"\n✓ Saved fish depth tracking to: {stats_path}")
|
|
|
|
# Save list of point clouds kept for template matching
|
|
if kept_pointclouds:
|
|
pointcloud_list_path = output_base_path / "pointclouds_kept_for_template_matching.txt"
|
|
with open(pointcloud_list_path, 'w', encoding='utf-8') as f:
|
|
for ply_path in kept_pointclouds:
|
|
f.write(f"{ply_path}\n")
|
|
if args.use_pointcloud_classifier:
|
|
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds that passed PointNet++ classifier: {pointcloud_list_path.name}")
|
|
else:
|
|
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds for template matching: {pointcloud_list_path.name}")
|
|
else:
|
|
if args.use_pointcloud_classifier:
|
|
print(f" Note: PointNet++ classifier was enabled but no point clouds passed the filter")
|
|
else:
|
|
print(f" Note: No point clouds were saved")
|
|
|
|
# Run template matching if requested
|
|
if args.run_template_matching and kept_pointclouds:
|
|
print(f"\n{'='*60}")
|
|
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
|
|
print(f"{'='*60}")
|
|
|
|
# Validate template matching arguments
|
|
if not args.template and not args.template_folder:
|
|
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
|
|
else:
|
|
# Create output directory for template matching
|
|
template_output_dir = output_base_path / "template_matching"
|
|
template_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Import fish_align_cli function
|
|
try:
|
|
# Add template_matching directory to path
|
|
template_matching_dir = Path(__file__).parent / "template_matching"
|
|
if str(template_matching_dir) not in sys.path:
|
|
sys.path.insert(0, str(template_matching_dir))
|
|
|
|
from fish_align_cli import process_folder_mode
|
|
|
|
# Process all kept point clouds using folder mode
|
|
# This will align all point clouds and use the one with maximum length for weight calculation
|
|
process_folder_mode(
|
|
input_folder=str(output_cloud_folder),
|
|
template_path=args.template,
|
|
template_folder=args.template_folder,
|
|
output_dir=str(template_output_dir),
|
|
template_scale_factor=args.template_scale_factor,
|
|
debug=False,
|
|
debug_dir=None
|
|
)
|
|
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
|
|
except Exception as e:
|
|
print(f" ERROR: Failed to run template matching: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
elif args.run_template_matching and not kept_pointclouds:
|
|
if args.use_pointcloud_classifier:
|
|
print(f" Warning: Template matching requested but no point clouds passed PointNet++ classifier")
|
|
else:
|
|
print(f" Warning: Template matching requested but no point clouds were saved")
|
|
|
|
if use_svo:
|
|
print(f"\nProcessing ended. Saved {total_frames} frames to {output_base_path.resolve()}")
|
|
else:
|
|
print(f"\nProcessing ended. Saved {total_frames}/{len(image_files)} images to {output_base_path.resolve()}")
|
|
else:
|
|
if use_svo:
|
|
print(f"\nPreview ended. Viewed {idx} frames.")
|
|
else:
|
|
print(f"\nPreview ended. Viewed {idx}/{len(image_files)} images.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the preview / processing pipeline only when this
    # file is executed directly, never as a side effect of importing it.
    main()
|
|
|