Files
FishServer/FishMeasure/fish_video_weight_evaluation__v1.py
2026-04-16 14:53:01 +08:00

2434 lines
122 KiB
Python

#!/usr/bin/env python3
"""
Standalone image preview with YOLO detection and SAM segmentation.
Pure OpenCV + YOLO + SAM for viewing images from a folder.
"""
import argparse
import cv2
import json
import numpy as np
import torch
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
def _open_video_writer(path: Path, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Create a VideoWriter for ``path``, trying a Jetson NVENC GStreamer pipeline first.

    The hardware pipeline (``nvvidconv`` + ``nvv4l2h264enc`` into an MP4
    muxer) is attempted only when OpenCV exposes ``CAP_GSTREAMER``. If it
    cannot be opened, or building it raises, the function falls back to
    OpenCV's software ``mp4v`` encoder.

    Args:
        path: Output video file.
        fps: Frame rate written to the container.
        size: ``(width, height)`` of the frames that will be written.

    Returns:
        A ``cv2.VideoWriter``. The software fallback is returned without an
        ``isOpened()`` check, so callers should verify it themselves.
    """
    width, height = size
    frame_size = (width, height)

    if hasattr(cv2, "CAP_GSTREAMER"):
        try:
            # Escape double quotes so the path survives inside location="...".
            escaped_location = str(path).replace('"', '\\"')
            pipeline = " ! ".join([
                "appsrc",
                "videoconvert",
                "video/x-raw,format=BGRx",
                "nvvidconv",
                "video/x-raw(memory:NVMM)",
                "nvv4l2h264enc bitrate=4000000",
                "h264parse",
                "mp4mux",
                f'filesink location="{escaped_location}"',
            ])
            hw_writer = cv2.VideoWriter(pipeline, cv2.CAP_GSTREAMER, 0, fps, frame_size)
            if hw_writer.isOpened():
                return hw_writer
            hw_writer.release()
        except Exception:
            # Best effort: any GStreamer problem just means "use software encoding".
            pass

    software_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(str(path), software_fourcc, fps, frame_size)
from ultralytics import YOLO
from seg import init_models
from pointcloud_filter import filter_point_cloud
import os
import sys
import importlib
import tempfile
import subprocess
try:
import pyzed.sl as sl
ZED_AVAILABLE = True
except ImportError:
ZED_AVAILABLE = False
print("Warning: pyzed not available. SVO2 file support disabled.")
try:
import open3d as o3d
O3D_AVAILABLE = True
except ImportError:
O3D_AVAILABLE = False
print("Warning: open3d not available. Point cloud classification may fail.")
from dataset.zed_reader import ZEDReader
from utils.evaluate_flatness import evaluate_flatness_ransac
from utils.keep_largest_cluster import keep_largest_cluster_with_colors
from utils.correct_tail_rotation import correct_tail_rotation_array
def draw_detections(image, results, class_names, depth_stats_list=None):
    """Annotate ``image`` in place with YOLO boxes, labels and optional depth text.

    Args:
        image: Image array that OpenCV draws on in place; it is also returned.
        results: A single YOLO result exposing ``.boxes`` with ``xyxy``,
            ``conf``, ``cls`` and, when tracking was used, ``id``.
        class_names: Mapping from class id to display name. Ids missing from
            it (or an empty/None mapping) are labelled ``"fish"``.
        depth_stats_list: Optional list aligned with the detections. An entry
            may be ``None`` or a dict; its ``'median_depth_mm'`` value, when
            present, is printed in a cyan box just below the top edge of the
            bounding box.

    Returns:
        The same ``image`` object.
    """
    if results is None or results.boxes is None:
        return image

    det = results.boxes
    boxes = det.xyxy.cpu().numpy()
    confidences = det.conf.cpu().numpy()
    if det.cls is not None:
        class_ids = det.cls.cpu().numpy().astype(int)
    else:
        class_ids = np.zeros(len(boxes), dtype=int)

    # Tracker ids only exist when the model was run in tracking mode.
    track_ids = None
    if hasattr(det, 'id') and det.id is not None:
        track_ids = det.id.cpu().numpy().astype(int)

    n_tracks = 0 if track_ids is None else len(track_ids)
    n_depth = len(depth_stats_list) if depth_stats_list else 0
    font = cv2.FONT_HERSHEY_SIMPLEX

    for idx, (box, score, cls_id) in enumerate(zip(boxes, confidences, class_ids)):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        name = class_names[cls_id] if class_names and cls_id in class_names else "fish"
        prefix = f"ID:{track_ids[idx]} " if idx < n_tracks else ""
        label = f"{prefix}{name}: {score:.2f}"

        # Label sits on a filled green strip directly above the box.
        (tw, th), base = cv2.getTextSize(label, font, 0.6, 1)
        cv2.rectangle(image, (x1, y1 - th - base - 5), (x1 + tw, y1), (0, 255, 0), -1)
        cv2.putText(image, label, (x1, y1 - base - 2), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)

        if idx >= n_depth or depth_stats_list[idx] is None:
            continue
        median = depth_stats_list[idx].get('median_depth_mm', None)
        if median is None:
            continue

        depth_text = f"Median: {median:.1f}mm"
        (dw, dh), dbase = cv2.getTextSize(depth_text, font, 0.5, 1)
        top = y1 + th + base + 5
        bottom = top + dh
        # Skip the depth strip if it would fall off the bottom of the frame.
        if bottom < image.shape[0]:
            cv2.rectangle(image, (x1, top), (x1 + dw, bottom), (255, 255, 0), -1)
            cv2.putText(image, depth_text, (x1, bottom - dbase), font, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    return image
def segment_with_sam(sam_predictor, image_bgr, boxes_xyxy, device):
    """Prompt SAM with detection boxes and return one boolean mask per box.

    Args:
        sam_predictor: A SAM predictor object exposing ``set_image``,
            ``transform.apply_boxes_torch`` and ``predict_torch`` (presumably
            ``segment_anything.SamPredictor`` -- confirm against ``seg.init_models``).
        image_bgr: BGR image; converted to RGB before being handed to SAM.
        boxes_xyxy: Sequence/array of ``[x1, y1, x2, y2]`` boxes in image pixels.
        device: Torch device on which the box prompts are created.

    Returns:
        List of NumPy boolean masks, one per input box (first and only
        mask channel, since ``multimask_output=False``). Empty list when no
        boxes are given or SAM returns nothing.
    """
    if len(boxes_xyxy) == 0:
        return []

    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    sam_predictor.set_image(rgb)

    # Boxes must be mapped into SAM's resized input coordinate frame.
    box_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=device)
    sam_boxes = sam_predictor.transform.apply_boxes_torch(box_tensor, rgb.shape[:2])

    masks, _scores, _logits = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    if masks is None or masks.shape[0] == 0:
        return []
    return [per_box[0].cpu().numpy().astype(bool) for per_box in masks]
def create_segmentation_overlay(image, masks, alpha=0.5, color=(0, 255, 0)):
    """Return a copy of ``image`` with the union of ``masks`` tinted and outlined.

    Only pixels covered by at least one mask are alpha-blended with ``color``.
    Pixels outside every mask keep their original values, so the background
    is not darkened. The outer contour of the union is drawn in ``color``
    with thickness 2.

    Args:
        image: 3-channel image array; it is not modified.
        masks: Iterable of masks the same height/width as ``image`` (any dtype,
            coerced to bool), a stacked NumPy array of masks, or ``None``.
        alpha: Blend weight of ``color`` inside the mask (default 0.5).
        color: Tint/outline colour in the image's channel order
            (default green ``(0, 255, 0)``).

    Returns:
        New image array; an unmodified copy when there is nothing to draw.
    """
    overlay = image.copy()
    # Explicit None/len test: ``if masks`` raises for multi-element NumPy arrays.
    if masks is None or len(masks) == 0:
        return overlay

    combined_mask = np.zeros(image.shape[:2], dtype=bool)
    for mask in masks:
        combined_mask |= np.asarray(mask, dtype=bool)
    if not combined_mask.any():
        return overlay

    colored_mask = np.zeros_like(overlay)
    colored_mask[combined_mask] = color
    blended = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
    # Copy back only the masked pixels; blending the whole frame against a
    # black canvas would scale the background by (1 - alpha).
    overlay[combined_mask] = blended[combined_mask]

    mask_uint8 = combined_mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, color, 2)
    return overlay
def validate_depth_in_mask(depth_mm, mask, max_depth_mm=1200.0, min_depth_mm=100.0,
                           outlier_std_factor=3.0, min_valid_ratio=0.5,
                           max_outlier_ratio=0.3):
    """Check whether the depth values under ``mask`` look trustworthy.

    A mask passes when all of the following hold:

    - at least one masked pixel has a finite, positive depth within
      ``[min_depth_mm, max_depth_mm]``;
    - at least ``min_valid_ratio`` of the masked pixels have such a depth;
    - at most ``max_outlier_ratio`` of those valid depths lie further than
      ``outlier_std_factor`` robust standard deviations (``1.4826 * MAD``)
      from their median. This test is skipped when the MAD is 0.

    Args:
        depth_mm: Depth map in millimetres, shape ``(H, W)``.
        mask: Mask of shape ``(H, W)``; any dtype, coerced to bool.
        max_depth_mm: Largest accepted depth (default 1200.0).
        min_depth_mm: Smallest accepted depth (default 100.0).
        outlier_std_factor: Outlier band half-width in robust std units
            (default 3.0).
        min_valid_ratio: Minimum fraction of masked pixels with valid depth
            (default 0.5).
        max_outlier_ratio: Maximum tolerated fraction of outliers among the
            valid depths (default 0.3, the previously hard-coded limit).

    Returns:
        ``(is_valid, reason, stats)``. ``reason`` is ``"OK"`` on success,
        otherwise a human-readable explanation. ``stats`` holds diagnostics
        for the stage that decided the outcome (possibly empty).
    """
    if mask is None:
        return False, "Mask is None", {}
    mask = np.asarray(mask, dtype=bool)
    if mask.shape != depth_mm.shape:
        return False, f"Mask shape {mask.shape} != depth shape {depth_mm.shape}", {}

    mask_depth = depth_mm[mask]
    in_range = ((mask_depth > 0)
                & np.isfinite(mask_depth)
                & (mask_depth >= min_depth_mm)
                & (mask_depth <= max_depth_mm))
    valid_mask_depth = mask_depth[in_range]
    valid_count = len(valid_mask_depth)
    if valid_count == 0:
        return False, "No valid depth values in mask", {}

    total_mask_pixels = np.sum(mask)
    valid_ratio = valid_count / total_mask_pixels if total_mask_pixels > 0 else 0.0
    if valid_ratio < min_valid_ratio:
        return False, f"Too few valid depth pixels: {valid_ratio:.1%} < {min_valid_ratio:.1%}", {
            "valid_ratio": valid_ratio,
            "total_mask_pixels": total_mask_pixels,
            "valid_depth_count": valid_count,
        }

    # Median/MAD are robust to the very outliers we are trying to count.
    median_depth = np.median(valid_mask_depth)
    mad = np.median(np.abs(valid_mask_depth - median_depth))
    if mad > 0:
        std_approx = 1.4826 * mad  # MAD -> std for normally distributed data
        lower_bound = median_depth - outlier_std_factor * std_approx
        upper_bound = median_depth + outlier_std_factor * std_approx
        outliers = np.sum((valid_mask_depth < lower_bound) | (valid_mask_depth > upper_bound))
        outlier_ratio = outliers / valid_count
        if outlier_ratio > max_outlier_ratio:
            return False, (f"Too many depth outliers: {outlier_ratio:.1%} > "
                           f"{max_outlier_ratio * 100:g}%"), {
                "median_depth": median_depth,
                "mad": mad,
                "outlier_ratio": outlier_ratio,
                "lower_bound": lower_bound,
                "upper_bound": upper_bound,
            }

    return True, "OK", {
        "median_depth": median_depth,
        "mean_depth": np.mean(valid_mask_depth),
        "std_depth": np.std(valid_mask_depth),
        "min_depth": np.min(valid_mask_depth),
        "max_depth": np.max(valid_mask_depth),
        "valid_ratio": valid_ratio,
        "valid_count": valid_count,
    }
def depth_mask_to_pointcloud(image_bgr, depth_mm, mask, intrinsics, max_depth_mm=1200.0):
    """Back-project masked depth pixels into a coloured 3D point cloud.

    Uses the pinhole model ``X = (u - cx) * Z / fx``, ``Y = (v - cy) * Z / fy``,
    where ``(u, v)`` is the pixel column/row and ``Z`` its depth.

    Args:
        image_bgr: BGR(A) image with the same height/width as ``depth_mm``.
        depth_mm: Depth map in millimetres, shape ``(H, W)``.
        mask: Mask of shape ``(H, W)`` or ``(H, W, 1)``; any dtype, coerced to
            bool. A mask with a different height/width is resized with
            nearest-neighbour interpolation (a warning is printed).
        intrinsics: Dict with ``fx``, ``fy``, ``cx``, ``cy`` in pixels.
        max_depth_mm: Pixels deeper than this are dropped (default 1200.0).

    Returns:
        ``(points, colors)``: ``points`` is a float32 ``(N, 3)`` array in mm
        and ``colors`` a uint8 ``(N, 3)`` RGB array, both in row-major pixel
        order. ``(None, None)`` if ``mask`` is None or no pixel is both masked
        and has a finite depth in ``(0, max_depth_mm]``.
    """
    fx = intrinsics["fx"]
    fy = intrinsics["fy"]
    cx = intrinsics["cx"]
    cy = intrinsics["cy"]
    height, width = depth_mm.shape

    if mask is None:
        return None, None
    mask = np.asarray(mask, dtype=bool)
    # A trailing singleton channel would otherwise broadcast against (H, W).
    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[..., 0]
    if mask.shape[:2] != (height, width):
        original_shape = mask.shape
        mask = cv2.resize(mask.astype(np.uint8), (width, height),
                          interpolation=cv2.INTER_NEAREST).astype(bool)
        print(f" Warning: Resized mask from {original_shape} to ({height}, {width})")

    valid = mask & np.isfinite(depth_mm) & (depth_mm > 0) & (depth_mm <= max_depth_mm)
    if not np.any(valid):
        return None, None

    # np.nonzero yields coordinates in the same C order as boolean indexing,
    # so they line up with depth_mm[valid] and image_bgr[valid].
    y_coords, x_coords = np.nonzero(valid)
    z_vals = depth_mm[valid]
    x_vals = (x_coords - cx) * z_vals / fx
    y_vals = (y_coords - cy) * z_vals / fy
    points = np.stack([x_vals, y_vals, z_vals], axis=1).astype(np.float32)

    # BGR -> RGB (any alpha channel is dropped).
    colors = image_bgr[valid][:, [2, 1, 0]].astype(np.uint8)
    return points, colors
def write_ply_file(path, points, colors):
    """Write an ASCII PLY point cloud with per-vertex RGB colours.

    Args:
        path: Destination path (``str`` or ``Path``); overwritten if present.
        points: Sequence of XYZ rows. The first three values of each row are
            written with three decimals.
        colors: Sequence of RGB rows aligned with ``points``. Each channel is
            written via ``int()`` and declared as ``uchar`` in the header.

    Raises:
        ValueError: If ``points`` and ``colors`` differ in length. The
            header's vertex count must equal the number of body rows, or PLY
            readers will reject or mis-parse the file. Nothing is written in
            that case.
    """
    if len(points) != len(colors):
        raise ValueError(
            f"points and colors must have the same length "
            f"(got {len(points)} and {len(colors)})"
        )
    header_lines = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(points)}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.write("\n".join(header_lines) + "\n")
        # One buffered writelines call instead of one write per vertex.
        f.writelines(
            f"{pt[0]:.3f} {pt[1]:.3f} {pt[2]:.3f} {int(c[0])} {int(c[1])} {int(c[2])}\n"
            for pt, c in zip(points, colors)
        )
def _resolve_pointnet_model_name(experiment_dir: Path) -> str:
    """Return the PointNet model module name to import for ``experiment_dir``.

    Lookup order: stem of the first ``logs/*.txt`` file, then stem of the
    first ``pointnet*.py`` file in ``experiment_dir``, then
    ``'pointnet2_cls_ssg'``. Candidates are sorted so the choice does not
    depend on filesystem listing order.
    """
    logs_dir = experiment_dir / 'logs'
    if logs_dir.exists():
        # Mirrors test_classification.py, which derives the model name from
        # the file stored under <experiment>/logs.
        txt_files = sorted(logs_dir.glob('*.txt'))
        if txt_files:
            model_name = txt_files[0].stem
            print(f"Found model name from logs directory: {model_name}")
            return model_name
    py_files = sorted(experiment_dir.glob('pointnet*.py'))
    if py_files:
        model_name = py_files[0].stem
        print(f"Found model file in experiment directory: {model_name}.py")
        return model_name
    model_name = 'pointnet2_cls_ssg'
    print(f"Using default model name: {model_name}")
    return model_name


def _add_pointnet_import_paths(pointnet_dir: Path) -> None:
    """Prepend ``pointnet_dir/models`` (if it exists) and ``pointnet_dir`` to ``sys.path``."""
    models_dir = pointnet_dir / 'models'
    if models_dir.exists():
        if str(models_dir) not in sys.path:
            sys.path.insert(0, str(models_dir))
            print(f"Added to sys.path: {models_dir}")
    else:
        print(f"WARNING: Models directory not found: {models_dir}")
    # The repository root itself lets ``models.<name>`` be imported as a package.
    if str(pointnet_dir) not in sys.path:
        sys.path.insert(0, str(pointnet_dir))
        print(f"Added to sys.path: {pointnet_dir}")


def _import_pointnet_module(model_name: str):
    """Import ``model_name``, falling back to ``models.<model_name>``; return None if both fail."""
    print(f"Attempting to import model: {model_name}")
    try:
        module = importlib.import_module(model_name)
    except ImportError as e1:
        print(f"Failed to import {model_name}: {e1}")
        print(f"Attempting to import from models.{model_name}")
        try:
            module = importlib.import_module(f'models.{model_name}')
        except ImportError as e2:
            print(f"ERROR: Could not import model {model_name} or models.{model_name}")
            print(f" First error: {e1}")
            print(f" Second error: {e2}")
            print(f" Current sys.path: {[p for p in sys.path if 'pointcloud' in p.lower() or 'pointnet' in p.lower()]}")
            return None
        print(f"✓ Successfully imported models.{model_name}")
        return module
    print(f"✓ Successfully imported {model_name}")
    return module


def load_pointcloud_classifier(checkpoint_path: str, num_classes: int = 2, use_normals: bool = False,
                               device: Optional[str] = None) -> Optional[torch.nn.Module]:
    """Load a PointNet/PointNet++ point-cloud quality classifier in eval mode.

    The checkpoint is expected at
    ``<Pointnet_Pointnet2_pytorch>/log/classification/<experiment>/checkpoints/<file>.pth``.
    The model module name is resolved from the experiment directory, the
    repository's ``models`` directory and root are added to ``sys.path``, and
    the module's ``get_model(num_classes, normal_channel=use_normals)`` is
    instantiated and filled from the checkpoint. The checkpoint may hold a
    ``'model_state_dict'`` entry or be a raw state dict.

    Args:
        checkpoint_path: Path to the ``.pth`` checkpoint (``~`` is expanded).
        num_classes: Number of output classes (2 for good/bad).
        use_normals: Passed to ``get_model`` as ``normal_channel``.
        device: ``'cuda'``/``'cpu'``; defaults to CUDA when available.

    Returns:
        The loaded model, or ``None`` if the checkpoint is missing, the model
        module cannot be imported, or model construction / checkpoint loading
        fails. Errors are printed rather than raised.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    checkpoint_path = Path(checkpoint_path).expanduser().resolve()
    if not checkpoint_path.exists():
        print(f"ERROR: Point cloud classifier checkpoint not found: {checkpoint_path}")
        return None

    # checkpoints -> <experiment> -> classification -> log -> Pointnet_Pointnet2_pytorch
    experiment_dir = checkpoint_path.parent.parent
    model_name = _resolve_pointnet_model_name(experiment_dir)
    pointnet_dir = experiment_dir.parent.parent.parent
    _add_pointnet_import_paths(pointnet_dir)

    model_module = _import_pointnet_module(model_name)
    if model_module is None:
        return None

    try:
        print(f"Creating model with num_classes={num_classes}, use_normals={use_normals}")
        classifier = model_module.get_model(num_classes, normal_channel=use_normals)
        classifier = classifier.to(device)

        print(f"Loading checkpoint from: {checkpoint_path}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(str(checkpoint_path), map_location=device)
        if 'model_state_dict' in checkpoint:
            classifier.load_state_dict(checkpoint['model_state_dict'])
            print("✓ Loaded model_state_dict from checkpoint")
        else:
            classifier.load_state_dict(checkpoint)
            print("✓ Loaded state dict directly from checkpoint")
    except Exception as e:
        # Honour the documented "None on failure" contract instead of letting
        # construction/loading errors escape to the caller.
        print(f"ERROR: Failed to create or load point cloud classifier '{model_name}': {e}")
        import traceback
        traceback.print_exc()
        return None

    classifier.eval()
    print(f"✓ Loaded point cloud classifier from {checkpoint_path}")
    return classifier
def classify_pointcloud_array(classifier: torch.nn.Module, points: np.ndarray, colors: np.ndarray,
                              num_point: int = 1024, vote_num: int = 3,
                              use_cpu: bool = False, confidence_threshold: float = 0.5) -> Tuple[bool, float, Dict]:
    """Classify a point cloud with ``test_classification.classify_single_pointcloud``.

    The points (and colours, when they match in length) are written to a
    temporary PLY file via open3d and classified by the same function that
    ``test_classification.py`` uses, so results match that script. The
    temporary file is always removed.

    Args:
        classifier: Loaded classifier model; ``None`` short-circuits.
        points: Point array ``(N, 3)``.
        colors: Colour array ``(N, 3)`` in 0-255; ignored if ``None`` or its
            length differs from ``points``.
        num_point: Number of points sampled by the classifier (default 1024).
        vote_num: Number of votes (default 3).
        use_cpu: Forwarded to ``classify_single_pointcloud``.
        confidence_threshold: Minimum confidence for a "good" verdict
            (default 0.5).

    Returns:
        ``(is_good, confidence, result_dict)``. ``is_good`` is True only when
        the predicted ``class_id`` is 1 and ``confidence >= confidence_threshold``.
        On success ``result_dict`` is the classifier's dict with ``is_good`` and
        ``meets_threshold`` added; on failure it is ``{}`` (no classifier) or
        ``{"error": ...}``, and ``(False, 0.0, ...)`` is returned.
    """
    if classifier is None:
        return False, 0.0, {}
    if len(points) == 0:
        return False, 0.0, {"error": "Empty point cloud"}
    if not O3D_AVAILABLE:
        return False, 0.0, {"error": "open3d not available"}

    temp_ply = None
    try:
        possible_paths = [
            Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "test_classification.py",
            Path(__file__).parent.parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "test_classification.py",
            Path("/home/ubuntu/projects/FishMeasure/pointcloud_classifier/Pointnet_Pointnet2_pytorch/test_classification.py"),
        ]
        test_classification_path = next((p for p in possible_paths if p.exists()), None)
        if test_classification_path is None:
            error_msg = f"test_classification.py not found. Tried paths: {[str(p) for p in possible_paths]}"
            print(f" ERROR: {error_msg}")
            return False, 0.0, {"error": error_msg}

        test_classification_dir = test_classification_path.parent
        if str(test_classification_dir) not in sys.path:
            sys.path.insert(0, str(test_classification_dir))

        try:
            from test_classification import classify_single_pointcloud
        except ImportError as e:
            error_msg = f"Failed to import classify_single_pointcloud from test_classification: {e}"
            print(f" ERROR: {error_msg}")
            print(f" sys.path includes: {[p for p in sys.path if 'pointcloud' in p or 'Pointnet' in p]}")
            return False, 0.0, {"error": error_msg}

        # delete=False: the file must outlive this handle so open3d and the
        # classifier can reopen it by name; removal happens in ``finally``.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.ply', delete=False) as f:
            temp_ply = f.name

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points.astype(np.float64))
        if colors is not None and len(colors) == len(points):
            pcd.colors = o3d.utility.Vector3dVector(colors.astype(np.float64) / 255.0)
        o3d.io.write_point_cloud(temp_ply, pcd)

        result, error = classify_single_pointcloud(
            classifier, temp_ply, num_point=num_point, vote_num=vote_num,
            use_cpu=use_cpu, use_normals=False
        )
        if result is None:
            return False, 0.0, {"error": error if error else "Classification failed"}

        class_id = result.get('class_id', 0)
        confidence = result.get('confidence', 0.0)
        # Class 1 is "good"; it must also clear the confidence threshold.
        is_good = (class_id == 1) and (confidence >= confidence_threshold)
        result['is_good'] = is_good
        result['meets_threshold'] = is_good
        return is_good, confidence, result
    except Exception as e:
        print(f" Warning: Point cloud classification failed: {e}")
        import traceback
        traceback.print_exc()
        return False, 0.0, {"error": str(e)}
    finally:
        if temp_ply is not None:
            try:
                os.unlink(temp_ply)
            except OSError:
                # Best-effort cleanup of a temp file; a leftover is harmless.
                pass
def calculate_fish_depth_stats(depth_mm, mask, max_depth_mm=1200.0):
    """Summarise the valid depths under a fish mask.

    A pixel counts when it is inside ``mask`` and its depth is finite, > 0
    and <= ``max_depth_mm``.

    Args:
        depth_mm: Depth map in millimetres, shape ``(H, W)``.
        mask: Mask of shape ``(H, W)``; any dtype (e.g. uint8 0/1 or 0/255),
            coerced to bool. Without this coercion an integer mask would make
            ``depth_mm[valid]`` do integer (fancy) indexing instead of
            boolean selection.
        max_depth_mm: Largest accepted depth (default 1200.0).

    Returns:
        Dict with ``mean_depth_mm``, ``median_depth_mm``, ``min_depth_mm``,
        ``max_depth_mm``, ``std_depth_mm`` (floats) and ``num_points`` (int),
        or ``None`` if ``mask`` is None or no pixel qualifies.
    """
    if mask is None:
        return None
    mask = np.asarray(mask, dtype=bool)
    valid = mask & np.isfinite(depth_mm) & (depth_mm > 0) & (depth_mm <= max_depth_mm)
    if not np.any(valid):
        return None
    valid_depths = depth_mm[valid]
    return {
        "mean_depth_mm": float(np.mean(valid_depths)),
        "median_depth_mm": float(np.median(valid_depths)),
        "min_depth_mm": float(np.min(valid_depths)),
        "max_depth_mm": float(np.max(valid_depths)),
        "std_depth_mm": float(np.std(valid_depths)),
        "num_points": int(valid_depths.size),
    }
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Boxes that merely touch (zero-width or zero-height overlap) score 0.0,
    as does any pair whose union area is not positive.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - overlap
    if union > 0:
        return overlap / union
    return 0.0
def calculate_bbox_center(box):
    """Return the ``(cx, cy)`` midpoint of an ``[x1, y1, x2, y2, ...]`` box.

    Only the first four entries are read, so rows carrying extra fields
    (e.g. a score) are accepted.
    """
    mid_x = (box[0] + box[2]) / 2
    mid_y = (box[1] + box[3]) / 2
    return (mid_x, mid_y)
def calculate_bbox_distance(box1, box2):
    """Return the Euclidean distance between the centres of two ``[x1, y1, x2, y2]`` boxes."""
    # Centre deltas computed inline: ((a0 + a2) / 2) - ((b0 + b2) / 2), etc.
    dx = (box1[0] + box1[2]) / 2 - (box2[0] + box2[2]) / 2
    dy = (box1[1] + box1[3]) / 2 - (box2[1] + box2[3]) / 2
    return np.sqrt(dx**2 + dy**2)
def is_bbox_stationary(current_box, previous_boxes, movement_threshold=5.0):
    """Tell whether a track's box barely moved since its last recorded position.

    Args:
        current_box: Current ``[x1, y1, x2, y2]`` box.
        previous_boxes: History of this track's boxes, oldest first; only the
            last entry is compared.
        movement_threshold: Centre displacement (pixels) below which the box
            counts as stationary (default 5.0).

    Returns:
        False when there is no history; otherwise whether the centre moved
        less than ``movement_threshold``.
    """
    if not previous_boxes:
        return False
    moved_by = calculate_bbox_distance(current_box, previous_boxes[-1])
    return moved_by < movement_threshold
def track_fish_across_frames(current_boxes, previous_boxes, iou_threshold=0.3):
    """Greedily match current boxes to previous-frame boxes by IoU.

    Current boxes are processed in order. Each one takes the not-yet-claimed
    previous box with the highest IoU, provided that IoU is positive and at
    least ``iou_threshold``.

    Returns:
        List aligned with ``current_boxes``. Each entry is the *index into
        ``previous_boxes``* of the matched box, or ``None`` when unmatched
        (including when there are no previous boxes).
    """
    if previous_boxes is None or len(previous_boxes) == 0:
        return [None] * len(current_boxes)

    matches = [None] * len(current_boxes)
    claimed = set()
    for cur_idx, cur_box in enumerate(current_boxes):
        best_prev, best_score = None, 0.0
        for prev_idx, prev_box in enumerate(previous_boxes):
            if prev_idx in claimed:
                continue
            score = calculate_iou(cur_box, prev_box)
            if score >= iou_threshold and score > best_score:
                best_prev, best_score = prev_idx, score
        if best_prev is not None:
            matches[cur_idx] = best_prev
            claimed.add(best_prev)
    return matches
def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_device,
conf=0.25, imgsz=640, max_frames=0, save_images=False, filter_pointcloud=False,
use_clustering_filter=False, use_density_filter=False,
pointcloud_classifier=None, use_pointcloud_classifier=False,
pointcloud_classifier_threshold=0.7,
flatness_threshold=0.0, use_flatness_filter=False,
run_template_matching=False, template_path=None, template_folder=None,
template_scale_factor=1.0, save_raw_pointclouds=False,
correct_tail_rotation=False, tail_rotation_distance_threshold=5.0,
tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0):
"""Process a single SVO2 file with pre-loaded YOLO and SAM models.
Args:
svo_path: Path to SVO2 file
output_base: Base output directory (will create subfolder with SVO name)
yolo_model: Pre-loaded YOLO model
sam_predictor: Pre-loaded SAM predictor
sam_device: SAM device (torch.device)
conf: YOLO confidence threshold
imgsz: YOLO image size
max_frames: Maximum frames to process (0 = all)
save_images: If True, save individual images instead of video
filter_pointcloud: If True, apply filtering to remove outliers from point clouds
use_clustering_filter: If True, use clustering to keep only largest cluster
use_density_filter: If True, use density filtering (requires at least 200 points within 100mm radius)
Returns:
True if successful, False otherwise
"""
if not ZED_AVAILABLE:
print("ERROR: pyzed not available. Cannot read SVO2 files.")
return False
svo_path = Path(svo_path).expanduser().resolve()
if not svo_path.exists():
print(f"ERROR: SVO2 file not found: {svo_path}")
return False
svo_name = svo_path.stem
class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
# Setup output folders
output_base = Path(output_base) / svo_name
output_images_folder = output_base / "images"
output_cloud_folder = output_base / "cloud"
output_raw_pc_folder = output_base / "raw_pc" if save_raw_pointclouds else None
output_images_folder.mkdir(parents=True, exist_ok=True)
output_cloud_folder.mkdir(parents=True, exist_ok=True)
if save_raw_pointclouds and output_raw_pc_folder:
output_raw_pc_folder.mkdir(parents=True, exist_ok=True)
print(f"Reading from SVO2 file: {svo_path.name}")
print(f"Output folder: {output_base.resolve()}")
# Check if output folder already exists and contains point clouds
# If so, skip data generation and directly run template matching
if output_base.exists() and output_cloud_folder.exists():
# Check if there are point cloud files
point_cloud_files = list(output_cloud_folder.glob("*.ply"))
if point_cloud_files and run_template_matching:
print(f"\n{'='*60}")
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
print(f"Skipping data generation, directly running template matching...")
print(f"{'='*60}")
# Load kept point clouds from file if it exists, otherwise use all point clouds
kept_pointclouds = []
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
if pointcloud_list_path.exists():
print(f" Loading point cloud list from: {pointcloud_list_path.name}")
with open(pointcloud_list_path, 'r', encoding='utf-8') as f:
kept_pointclouds = [line.strip() for line in f if line.strip()]
print(f" Found {len(kept_pointclouds)} point clouds in list")
else:
# If no list file, use all point clouds in the folder
kept_pointclouds = [str(f) for f in point_cloud_files]
print(f" No point cloud list file found, using all {len(kept_pointclouds)} point clouds")
if kept_pointclouds:
# Run template matching directly
print(f"\n{'='*60}")
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
print(f"{'='*60}")
# Validate template matching arguments
if not template_path and not template_folder:
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
return False
# Create output directory for template matching
template_output_dir = output_base / "template_matching"
template_output_dir.mkdir(parents=True, exist_ok=True)
# Import fish_align_cli function
try:
# Add template_matching directory to path
template_matching_dir = Path(__file__).parent / "template_matching"
if str(template_matching_dir) not in sys.path:
sys.path.insert(0, str(template_matching_dir))
from fish_align_cli import process_folder_mode
# Process all point clouds using folder mode
process_folder_mode(
input_folder=str(output_cloud_folder),
template_path=template_path,
template_folder=template_folder,
output_dir=str(template_output_dir),
template_scale_factor=template_scale_factor,
debug=False,
debug_dir=None
)
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
return True
except Exception as e:
print(f" ERROR: Failed to run template matching: {e}")
import traceback
traceback.print_exc()
return False
else:
print(f" Warning: No point clouds found to process")
return False
elif point_cloud_files and not run_template_matching:
print(f"\n{'='*60}")
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
print(f"Template matching not requested (--run-template-matching not set)")
print(f"Skipping processing...")
print(f"{'='*60}")
return True
# If folder exists but no point clouds, continue with normal processing
# Initialize ZED reader
zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
if not zed_reader.open():
print("ERROR: Failed to open SVO2 file")
return False
# Get camera intrinsics
calib_params = zed_reader.zed.get_camera_information().camera_configuration.calibration_parameters
camera_intrinsics = {
"fx": float(calib_params.left_cam.fx),
"fy": float(calib_params.left_cam.fy),
"cx": float(calib_params.left_cam.cx),
"cy": float(calib_params.left_cam.cy),
}
runtime_params = sl.RuntimeParameters()
left_image_mat = sl.Mat()
depth_mat = sl.Mat()
# Tracking variables
previous_boxes = None
fish_tracks: Dict[int, List[Dict[str, Any]]] = {}
track_bbox_history: Dict[int, List[List[float]]] = {}
track_stationary_count: Dict[int, int] = {}
next_track_id = 0
STATIONARY_THRESHOLD = 10
MOVEMENT_THRESHOLD = 5.0
video_frames = []
idx = 0
# List to track point clouds that passed PointNet++ classifier (if enabled)
kept_pointclouds = []
try:
while True:
if max_frames > 0 and idx >= max_frames:
break
err = zed_reader.zed.grab(runtime_params)
if err != sl.ERROR_CODE.SUCCESS:
break
# Retrieve images
zed_reader.zed.retrieve_image(left_image_mat, sl.VIEW.LEFT)
zed_reader.zed.retrieve_measure(depth_mat, sl.MEASURE.DEPTH)
# Convert to numpy
left_np = left_image_mat.get_data()
depth_np = depth_mat.get_data()
if left_np.shape[2] > 3:
img = left_np[:, :, :3].copy()
else:
img = left_np.copy()
if depth_np.dtype != np.float32:
depth_data = depth_np.astype(np.float32)
else:
depth_data = depth_np.copy()
frame_name = f"frame_{idx+1:06d}"
if idx % 30 == 0:
print(f"[{idx + 1}] {frame_name}")
# Run YOLO tracking
results = yolo_model.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
num_dets = len(results.boxes) if results.boxes is not None else 0
# Process detections (same logic as main function)
individual_masks = []
current_boxes = None
track_ids = []
depth_stats_list = [] # Store depth stats for each detection
try:
if num_dets > 0:
boxes_xyxy = results.boxes.xyxy.cpu().numpy()
current_boxes = boxes_xyxy.tolist()
# Get track IDs from YOLO tracking results
if hasattr(results.boxes, 'id') and results.boxes.id is not None:
track_ids = results.boxes.id.cpu().numpy().astype(int).tolist()
else:
# Fallback: assign sequential IDs if tracking not available
track_ids = list(range(next_track_id, next_track_id + len(current_boxes)))
next_track_id += len(current_boxes)
# Filter stationary fish
active_detections = []
active_masks = []
active_track_ids = []
active_boxes = []
for fish_idx, (box, track_id) in enumerate(zip(current_boxes, track_ids)):
if track_id not in track_bbox_history:
track_bbox_history[track_id] = []
track_stationary_count[track_id] = 0
is_stationary = False
if len(track_bbox_history[track_id]) > 0:
is_stationary = is_bbox_stationary(box, track_bbox_history[track_id], MOVEMENT_THRESHOLD)
if is_stationary:
track_stationary_count[track_id] += 1
if track_stationary_count[track_id] >= STATIONARY_THRESHOLD:
continue
else:
track_stationary_count[track_id] = 0
active_detections.append(fish_idx)
active_boxes.append(box)
active_track_ids.append(track_id)
track_bbox_history[track_id].append(box)
if len(track_bbox_history[track_id]) > 10:
track_bbox_history[track_id].pop(0)
if len(active_detections) == 0:
num_dets = 0
individual_masks = []
previous_boxes = None
depth_stats_list = []
else:
active_boxes_array = np.array(active_boxes)
all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device)
individual_masks = all_masks if all_masks else []
# Calculate depth stats and map to original detection order
depth_stats_list = [None] * len(current_boxes) # Initialize with None for all detections
if individual_masks and len(individual_masks) > 0 and depth_data is not None:
for mask_idx, (mask, track_id, box) in enumerate(zip(individual_masks, active_track_ids, active_boxes)):
depth_stats = calculate_fish_depth_stats(depth_data, mask)
if depth_stats:
# Map to original detection index
original_idx = active_detections[mask_idx]
depth_stats_list[original_idx] = depth_stats
frame_data = {
"frame_index": idx + 1,
"frame_name": frame_name,
"fish_index": active_detections[mask_idx],
"track_id": int(track_id),
"bbox": box,
"depth_stats": depth_stats
}
if track_id not in fish_tracks:
fish_tracks[track_id] = []
fish_tracks[track_id].append(frame_data)
previous_boxes = active_boxes
current_boxes = active_boxes
num_dets = len(active_detections)
if individual_masks and len(individual_masks) > 0:
right_display = create_segmentation_overlay(img.copy(), individual_masks)
cv2.putText(right_display, "Segmentation", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
else:
right_display = img.copy()
cv2.putText(right_display, "No detections" if num_dets == 0 else "Segmentation (failed)",
(10, right_display.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
(128, 128, 128) if num_dets == 0 else (0, 0, 255), 2, cv2.LINE_AA)
else:
right_display = img.copy()
cv2.putText(right_display, "No detections", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)
previous_boxes = None
depth_stats_list = []
except Exception as e:
print(f" ERROR in processing: {e}")
right_display = img.copy()
previous_boxes = None
depth_stats_list = []
# Draw detections with depth info
left_display = draw_detections(img.copy(), results, class_names, depth_stats_list)
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
# Combine and save
combined_display = np.hstack([left_display, right_display])
if num_dets > 0:
# Save image or collect for video
if save_images:
# Save individual image file
image_path = output_images_folder / f"{frame_name}_detection.png"
cv2.imwrite(str(image_path), combined_display)
if idx % 30 == 0:
print(f" Saved image: {image_path.name}")
else:
# Collect for video
video_frames.append(combined_display.copy())
# Save point clouds
if depth_data is not None and individual_masks and len(individual_masks) > 0:
for fish_idx, mask in enumerate(individual_masks):
# Verify mask is being used correctly
if mask is None:
print(f" Warning: Mask {fish_idx + 1} is None, skipping point cloud generation")
continue
# Check mask statistics
mask_pixels = np.sum(mask.astype(bool))
if mask_pixels == 0:
print(f" Warning: Mask {fish_idx + 1} is empty, skipping point cloud generation")
continue
# Validate depth values in mask before generating point cloud
depth_valid, reason, depth_stats = validate_depth_in_mask(
depth_data, mask, max_depth_mm=1200.0, min_depth_mm=600.0
)
if not depth_valid:
print(f" Skipped point cloud {fish_idx + 1}: depth validation failed - {reason}")
continue
points, colors = depth_mask_to_pointcloud(img, depth_data, mask, camera_intrinsics)
if points is not None and len(points) > 0:
# Apply filtering if enabled
original_count = len(points)
if filter_pointcloud:
points, colors = filter_point_cloud(
points, colors,
use_clustering_filter=use_clustering_filter,
use_density_filter=use_density_filter
)
filtered_count = len(points)
# Skip if less than 500 points
if filtered_count < 500:
print(f" Skipped point cloud {fish_idx + 1}: {original_count} -> {filtered_count} points (minimum: 500)")
continue
elif filtered_count == 0:
print(f" Warning: All points filtered out for fish {fish_idx + 1} (original: {original_count})")
continue
else:
# Check minimum point count
if len(points) < 500:
print(f" Skipped point cloud {fish_idx + 1}: only {len(points)} points (minimum: 500)")
continue
# Apply largest cluster filtering before classification (if classifier is enabled)
# This removes outliers and keeps only the main fish body cluster
if use_pointcloud_classifier:
cluster_before_count = len(points)
try:
points, colors, cluster_info = keep_largest_cluster_with_colors(
points, colors, eps=10.0, min_points=30
)
cluster_after_count = len(points)
if "error" in cluster_info:
print(f" Point cloud {fish_idx + 1}: Clustering failed - {cluster_info.get('error', 'Unknown error')}")
continue
if cluster_after_count < 500:
print(f" Point cloud {fish_idx + 1}: After clustering: {cluster_before_count} -> {cluster_after_count} points (minimum: 500) - SKIPPED")
continue
if cluster_after_count < cluster_before_count:
print(f" Point cloud {fish_idx + 1}: Clustering removed {cluster_before_count - cluster_after_count} points "
f"({cluster_info.get('num_clusters', 0)} clusters found, kept largest with {cluster_info.get('largest_cluster_size', 0)} points)")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Largest cluster filtering failed: {e}")
# Continue with original points if clustering fails
# Save raw point cloud before classification (if flag is enabled)
if save_raw_pointclouds and output_raw_pc_folder:
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
write_ply_file(raw_ply_path, points, colors)
print(f" Saved raw point cloud {fish_idx + 1} (before classifier): {raw_ply_path.name} ({len(points)} points)")
# Classify point cloud quality if classifier is available
if use_pointcloud_classifier:
if pointcloud_classifier is not None:
# Classifier is available - perform classification
is_good, confidence, class_result = classify_pointcloud_array(
pointcloud_classifier, points, colors, num_point=1024, vote_num=3,
confidence_threshold=pointcloud_classifier_threshold,
use_cpu=(sam_device.type == "cpu")
)
# Always print the prediction result
prediction = class_result.get("prediction", "unknown")
class_id = class_result.get("class_id", -1)
if not is_good:
# Point cloud is bad - don't save it
if class_id == 1:
# Predicted as good but confidence too low
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (confidence {confidence:.3f} < threshold {pointcloud_classifier_threshold:.3f})")
else:
# Predicted as bad
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (BAD quality)")
continue
else:
# Point cloud is good - proceed to save
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SAVING (GOOD quality, confidence >= {pointcloud_classifier_threshold:.3f})")
else:
# Classifier requested but not available
print(f" Point cloud {fish_idx + 1}: WARNING - Classifier requested but not loaded, saving without classification")
else:
# Classifier not enabled - this should not happen if use_pointcloud_classifier was True
# This means either the flag was not set, or classifier loading failed
if use_pointcloud_classifier:
print(f" Point cloud {fish_idx + 1}: ERROR - Classifier flag was set but classifier is None. Check startup logs for loading errors.")
else:
print(f" Point cloud {fish_idx + 1}: Classifier not enabled")
# Evaluate flatness if enabled
if use_flatness_filter:
try:
flatness_score, flatness_info = evaluate_flatness_ransac(
points,
distance_threshold=5.0,
ransac_n=3,
num_iterations=1000
)
if flatness_score < flatness_threshold:
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% < threshold {flatness_threshold:.2f}% - SKIPPED (not flat enough)")
continue
else:
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% >= threshold {flatness_threshold:.2f}% - PASSED")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Flatness evaluation failed: {e}")
# If flatness check is required and fails, skip saving
if use_flatness_filter:
print(f" Point cloud {fish_idx + 1}: Skipping due to flatness evaluation error")
continue
# Correct tail rotation if enabled (before saving, after all checks)
if correct_tail_rotation:
try:
tail_correction_before_count = len(points)
points, colors, tail_correction_applied, tail_correction_info = correct_tail_rotation_array(
points, colors,
distance_threshold=tail_rotation_distance_threshold,
min_tail_ratio=tail_rotation_min_tail_ratio,
min_angle_threshold=tail_rotation_min_angle,
verbose=True
)
if tail_correction_applied:
print(f" Point cloud {fish_idx + 1}: Tail rotation corrected (angle: {tail_correction_info.get('rotation_angle_degrees', 0):.2f}°)")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}")
# Continue with original points if correction fails
# Save point cloud (passed all checks)
filtered_count = len(points)
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
write_ply_file(ply_path, points, colors)
# Track point clouds that passed PointNet++ classifier (if enabled)
# If classifier is enabled, only track those that passed classifier
# If classifier is not enabled, track all saved point clouds
if use_pointcloud_classifier and pointcloud_classifier is not None:
# This point cloud passed all checks including classifier
kept_pointclouds.append(str(ply_path))
elif not use_pointcloud_classifier:
# Classifier not enabled, track all saved point clouds
kept_pointclouds.append(str(ply_path))
if filter_pointcloud:
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({original_count} -> {filtered_count} points)")
else:
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({filtered_count} points)")
idx += 1
# Create video (only if not saving individual images)
if not save_images and video_frames:
video_path = output_images_folder / f"{svo_name}_preview.mp4"
h, w = video_frames[0].shape[:2]
fps = 10.0
video_writer = _open_video_writer(video_path, fps, (w, h))
for frame in video_frames:
video_writer.write(frame)
video_writer.release()
print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames)")
elif save_images:
print(f"✓ Saved {idx} frames as individual images")
# Save tracking stats
if fish_tracks:
stats_path = output_base / "fish_depth_tracking.json"
tracking_data = {
"source": str(svo_path),
"total_frames_processed": idx,
"total_tracks": len(fish_tracks),
"tracks": {}
}
for track_id, frames in fish_tracks.items():
tracking_data["tracks"][str(track_id)] = {
"track_id": track_id,
"num_frames": len(frames),
"frames": frames,
"depth_summary": {
"mean_depth_overall_mm": float(np.mean([f["depth_stats"]["mean_depth_mm"] for f in frames])),
"std_depth_overall_mm": float(np.std([f["depth_stats"]["mean_depth_mm"] for f in frames])),
"min_depth_mm": float(np.min([f["depth_stats"]["min_depth_mm"] for f in frames])),
"max_depth_mm": float(np.max([f["depth_stats"]["max_depth_mm"] for f in frames])),
}
}
with open(stats_path, 'w', encoding='utf-8') as f:
json.dump(tracking_data, f, indent=2)
print(f"✓ Saved tracking stats: {stats_path.name}")
# Save list of point clouds kept for template matching
if kept_pointclouds:
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
with open(pointcloud_list_path, 'w', encoding='utf-8') as f:
for ply_path in kept_pointclouds:
f.write(f"{ply_path}\n")
if use_pointcloud_classifier:
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds that passed PointNet++ classifier: {pointcloud_list_path.name}")
else:
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds for template matching: {pointcloud_list_path.name}")
else:
if use_pointcloud_classifier:
print(f" Note: PointNet++ classifier was enabled but no point clouds passed the filter")
else:
print(f" Note: No point clouds were saved")
# Run template matching if requested
if run_template_matching and kept_pointclouds:
print(f"\n{'='*60}")
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
print(f"{'='*60}")
# Create output directory for template matching
template_output_dir = output_base / "template_matching"
template_output_dir.mkdir(parents=True, exist_ok=True)
# Import fish_align_cli function
try:
# Add template_matching directory to path
template_matching_dir = Path(__file__).parent / "template_matching"
if str(template_matching_dir) not in sys.path:
sys.path.insert(0, str(template_matching_dir))
from fish_align_cli import process_folder_mode
# Process all kept point clouds using folder mode
# This will align all point clouds and use the one with maximum length for weight calculation
process_folder_mode(
input_folder=str(output_cloud_folder),
template_path=template_path,
template_folder=template_folder,
output_dir=str(template_output_dir),
template_scale_factor=template_scale_factor,
debug=False,
debug_dir=None
)
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
except Exception as e:
print(f" ERROR: Failed to run template matching: {e}")
import traceback
traceback.print_exc()
elif run_template_matching and not kept_pointclouds:
if use_pointcloud_classifier:
print(f" Warning: Template matching requested but no point clouds passed PointNet++ classifier")
else:
print(f" Warning: Template matching requested but no point clouds were saved")
print(f"✓ Processed {idx} frames from {svo_path.name}")
return True
except Exception as e:
print(f"ERROR processing {svo_path.name}: {e}")
import traceback
traceback.print_exc()
return False
finally:
if zed_reader:
zed_reader.close()
def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor, sam_device,
                              conf=0.25, imgsz=640, max_frames=0, save_images=False, filter_pointcloud=False,
                              use_clustering_filter=False, use_density_filter=False,
                              pointcloud_classifier=None, use_pointcloud_classifier=False,
                              pointcloud_classifier_threshold=0.7,
                              use_flatness_filter=False, flatness_threshold=90.0,
                              run_template_matching=False, template_path=None, template_folder=None,
                              template_scale_factor=1.0, save_raw_pointclouds=False,
                              correct_tail_rotation=False, tail_rotation_distance_threshold=5.0,
                              tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0):
    """Process all SVO2 files in a folder with pre-loaded YOLO and SAM models.

    Every regular file in ``svo_folder`` whose extension is ``.svo2`` (matched
    case-insensitively) is forwarded to ``process_single_svo2`` in sorted order,
    together with all processing options. A recording is skipped without being
    processed when ``output_base / <file stem>`` already exists as a directory.

    Args:
        svo_folder: Folder containing SVO2 files.
        output_base: Base output directory; created if missing. Each recording's
            results are expected under ``output_base / <svo stem>``.
        yolo_model: Pre-loaded YOLO model.
        sam_predictor: Pre-loaded SAM predictor.
        sam_device: SAM device (``torch.device``).
        conf: YOLO confidence threshold.
        imgsz: YOLO inference image size.
        max_frames: Maximum frames to process per SVO2 (0 = all).
        save_images: If True, save individual images instead of a preview video.
        filter_pointcloud: If True, filter outliers from point clouds.
        use_clustering_filter: If True, keep only the largest cluster when filtering.
        use_density_filter: If True, apply density filtering when filtering.
        pointcloud_classifier: Loaded point cloud quality classifier, or None.
        use_pointcloud_classifier: If True, use the classifier to reject point clouds.
        pointcloud_classifier_threshold: Minimum classifier confidence for "good".
        use_flatness_filter: If True, evaluate RANSAC flatness before saving.
        flatness_threshold: Minimum flatness score (percent) required to save.
        run_template_matching: If True, run template matching after processing.
        template_path: Template mesh file for template matching.
        template_folder: Folder of template meshes (alternative to ``template_path``).
        template_scale_factor: Scale applied to the template before weight calculation.
        save_raw_pointclouds: If True, also save point clouds before classification.
        correct_tail_rotation: If True, apply tail-rotation correction before saving.
        tail_rotation_distance_threshold: Point-to-plane distance for tail correction (mm).
        tail_rotation_min_tail_ratio: Minimum tail outlier ratio to detect rotation.
        tail_rotation_min_angle: Minimum plane angle (degrees) to apply correction.

    Returns:
        ``{"success": False, "error": <message>}`` when the folder is missing or
        contains no SVO2 files. Otherwise ``{"success": True, "total": ...,
        "success_count": ..., "skipped_count": ..., "failed_count": ...}``.
        ``"success": True`` only means the batch ran; per-file failures are
        reported in ``failed_count``.
    """
    import traceback

    svo_folder = Path(svo_folder).expanduser().resolve()
    if not svo_folder.is_dir():
        print(f"ERROR: SVO folder not found: {svo_folder}")
        return {"success": False, "error": "Folder not found"}

    # Match the extension case-insensitively: glob("*.svo2") is case-sensitive on
    # Linux and would silently miss e.g. "REC.SVO2". Directories are ignored so
    # they are not handed to the SVO reader and counted as failures.
    svo_files = sorted(
        p for p in svo_folder.iterdir()
        if p.is_file() and p.suffix.lower() == ".svo2"
    )
    if not svo_files:
        print(f"ERROR: No SVO2 files found in {svo_folder}")
        return {"success": False, "error": "No SVO2 files found"}

    # Resolve before logging so the header and the summary report the same path.
    output_base = Path(output_base).expanduser().resolve()
    print(f"Found {len(svo_files)} SVO2 file(s) to process")
    print(f"Output base folder: {output_base}")
    print("="*60)
    output_base.mkdir(parents=True, exist_ok=True)

    success_count = 0
    skipped_count = 0
    failed_count = 0
    total = len(svo_files)

    for idx, svo_file in enumerate(svo_files):
        svo_name = svo_file.stem
        output_subfolder = output_base / svo_name

        # Skip recordings whose output folder already exists.
        # NOTE(review): process_single_svo2 appears to support re-running template
        # matching on existing output; this batch-level skip bypasses that path.
        # Confirm that this is intended.
        if output_subfolder.is_dir():
            print(f"\n{'='*60}")
            print(f"[{idx + 1}/{total}] Skipping: {svo_name}")
            print(f"  Output folder already exists: {output_subfolder}")
            print(f"{'='*60}")
            skipped_count += 1
            continue

        print(f"\n{'='*60}")
        print(f"[{idx + 1}/{total}] Processing: {svo_name}")
        print(f"{'='*60}")

        try:
            success = process_single_svo2(
                svo_path=str(svo_file),
                output_base=str(output_base),
                yolo_model=yolo_model,
                sam_predictor=sam_predictor,
                sam_device=sam_device,
                conf=conf,
                imgsz=imgsz,
                max_frames=max_frames,
                save_images=save_images,
                filter_pointcloud=filter_pointcloud,
                use_clustering_filter=use_clustering_filter,
                use_density_filter=use_density_filter,
                pointcloud_classifier=pointcloud_classifier,
                use_pointcloud_classifier=use_pointcloud_classifier,
                pointcloud_classifier_threshold=pointcloud_classifier_threshold,
                flatness_threshold=flatness_threshold,
                use_flatness_filter=use_flatness_filter,
                run_template_matching=run_template_matching,
                template_path=template_path,
                template_folder=template_folder,
                template_scale_factor=template_scale_factor,
                save_raw_pointclouds=save_raw_pointclouds,
                correct_tail_rotation=correct_tail_rotation,
                tail_rotation_distance_threshold=tail_rotation_distance_threshold,
                tail_rotation_min_tail_ratio=tail_rotation_min_tail_ratio,
                tail_rotation_min_angle=tail_rotation_min_angle
            )
            if success:
                success_count += 1
                print(f"✓ Successfully processed: {svo_name}")
            else:
                failed_count += 1
                print(f"✗ Failed to process: {svo_name}")
        except Exception as e:
            # Batch boundary: report the failure and continue with the next file.
            failed_count += 1
            print(f"✗ Error processing {svo_name}: {e}")
            traceback.print_exc()

    print(f"\n{'='*60}")
    print("Batch Processing Summary")
    print(f"{'='*60}")
    print(f"Total files: {total}")
    print(f"Successfully processed: {success_count}")
    print(f"Skipped (already exists): {skipped_count}")
    print(f"Failed: {failed_count}")
    print(f"Output directory: {output_base}")
    print(f"{'='*60}")

    return {
        "success": True,
        "total": total,
        "success_count": success_count,
        "skipped_count": skipped_count,
        "failed_count": failed_count
    }
def main():
parser = argparse.ArgumentParser(description="Preview images with YOLO detection and SAM segmentation")
input_group = parser.add_mutually_exclusive_group()
input_group.add_argument("--image-folder", type=str, default="", help="Folder containing images")
input_group.add_argument("--svo", type=str, default="", help="Path to SVO2 file (alternative to --image-folder)")
input_group.add_argument("--batch-svo-folder", type=str, default="",
help="Folder containing SVO2 files to process in batch mode. Each SVO2 will be processed and saved to output/<svo_name>/")
parser.add_argument("--yolo-model",
default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
help="YOLO model path")
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
parser.add_argument("--imgsz", type=int, default=640, help="Image size")
parser.add_argument("--scale", type=float, default=1.0, help="Display scale (0.5 = half size)")
parser.add_argument("--sam-device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for SAM (cuda or cpu)")
parser.add_argument("--save-output", type=str, default="",
help="Save preview images to this folder instead of displaying. If empty, display window. For batch mode, this is the base output folder where subfolders will be created for each SVO2 file.")
parser.add_argument("--save-images", action="store_true",
help="Save individual image files instead of creating a video (only works with --save-output)")
parser.add_argument("--filter-pointcloud", action="store_true",
help="Apply filtering to remove outliers from point clouds (default: no filtering)")
parser.add_argument("--use-clustering-filter", action="store_true",
help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
parser.add_argument("--use-density-filter", action="store_true",
help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")
parser.add_argument("--max-frames", type=int, default=0,
help="Maximum frames to process from SVO2 (0 = all frames)")
parser.add_argument("--pointcloud-classifier", type=str, default=None,
help="Path to point cloud quality classifier checkpoint (e.g., log/classification/fish_pointnet2_finetune/checkpoints/best_model.pth)")
parser.add_argument("--use-pointcloud-classifier", action="store_true",
help="Use point cloud classifier to filter out bad quality point clouds (requires --pointcloud-classifier)")
parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7,
help="Confidence threshold for point cloud classifier (default: 0.7). Only point clouds classified as 'good' with confidence >= threshold will be saved.")
parser.add_argument("--use-flatness-filter", action="store_true",
help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
parser.add_argument("--flatness-threshold", type=float, default=90.0,
help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")
parser.add_argument("--run-template-matching", action="store_true",
help="After processing, run template matching on point clouds that passed PointNet++ classifier")
parser.add_argument("--template", type=str, default=None,
help="Path to template mesh file for template matching (required if --run-template-matching is set)")
parser.add_argument("--template-folder", type=str, default=None,
help="Folder containing multiple template meshes. Will try all templates and select the one with best alignment (alternative to --template)")
parser.add_argument("--template-scale-factor", type=float, default=1.0,
help="Scale factor applied to template mesh before weight calculation (default: 1.0)")
parser.add_argument("--save-raw-pointclouds", action="store_true",
help="Save point clouds to raw_pc folder before passing to PointNet++ classifier. Useful for debugging why some point clouds fail classification.")
parser.add_argument("--correct-tail-rotation", action="store_true",
help="Correct fish tail rotation using dual RANSAC plane fitting. This should be applied before template matching as it affects height/length ratio.")
parser.add_argument("--tail-rotation-distance-threshold", type=float, default=5.0,
help="Maximum distance from point to plane for tail rotation correction (mm, default: 5.0)")
parser.add_argument("--tail-rotation-min-tail-ratio", type=float, default=0.7,
help="Minimum ratio of outliers on tail plane to detect rotation (default: 0.7)")
parser.add_argument("--tail-rotation-min-angle", type=float, default=5.0,
help="Minimum angle (degrees) between planes to apply tail rotation correction (default: 5.0)")
args = parser.parse_args()
# Validate input source
input_count = sum([bool(args.image_folder), bool(args.svo), bool(args.batch_svo_folder)])
if input_count == 0:
parser.error("One of --image-folder, --svo, or --batch-svo-folder must be provided")
if input_count > 1:
parser.error("Cannot specify multiple input sources. Use only one of --image-folder, --svo, or --batch-svo-folder")
# Determine input source
use_batch_mode = bool(args.batch_svo_folder)
use_svo = bool(args.svo) and not use_batch_mode
# Handle batch mode first
if use_batch_mode:
if not ZED_AVAILABLE:
print("ERROR: pyzed not available. Cannot read SVO2 files.")
return
svo_folder = Path(args.batch_svo_folder).expanduser().resolve()
if not svo_folder.exists() or not svo_folder.is_dir():
print(f"ERROR: SVO folder not found: {svo_folder}")
return
# Load YOLO
print(f"Loading YOLO: {args.yolo_model}")
yolo_model = YOLO(args.yolo_model)
print(f"✓ YOLO loaded")
# Load SAM
print(f"Loading SAM (device: {args.sam_device})...")
sam_predictor = init_models(device=args.sam_device, seg_model="sam")
sam_device_obj = torch.device(args.sam_device)
print(f"✓ SAM loaded")
# Load point cloud classifier if specified
pointcloud_classifier = None
if args.use_pointcloud_classifier:
classifier_path = args.pointcloud_classifier
if classifier_path is None:
default_path = Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
if default_path.exists():
classifier_path = str(default_path)
print(f"Using default classifier path: {classifier_path}")
else:
print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
args.use_pointcloud_classifier = False
if args.use_pointcloud_classifier and classifier_path:
print(f"Loading point cloud classifier from: {classifier_path}")
pointcloud_classifier = load_pointcloud_classifier(
classifier_path,
num_classes=2,
use_normals=False,
device=args.sam_device
)
if pointcloud_classifier is None:
print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
args.use_pointcloud_classifier = False
else:
print(f"✓ Point cloud classifier loaded")
# Validate template matching arguments
if args.run_template_matching:
if not args.template and not args.template_folder:
parser.error("--run-template-matching requires either --template or --template-folder to be specified")
# Process batch
output_base = args.save_output if args.save_output else "output_preview"
process_batch_svo2_folder(
svo_folder=str(svo_folder),
output_base=output_base,
yolo_model=yolo_model,
sam_predictor=sam_predictor,
sam_device=sam_device_obj,
conf=args.conf,
imgsz=args.imgsz,
max_frames=args.max_frames,
save_images=args.save_images,
filter_pointcloud=args.filter_pointcloud,
use_clustering_filter=args.use_clustering_filter,
use_density_filter=args.use_density_filter,
pointcloud_classifier=pointcloud_classifier,
use_pointcloud_classifier=args.use_pointcloud_classifier,
pointcloud_classifier_threshold=args.pointcloud_classifier_threshold,
use_flatness_filter=args.use_flatness_filter,
flatness_threshold=args.flatness_threshold,
run_template_matching=args.run_template_matching,
template_path=args.template,
template_folder=args.template_folder,
template_scale_factor=args.template_scale_factor,
save_raw_pointclouds=args.save_raw_pointclouds,
correct_tail_rotation=args.correct_tail_rotation,
tail_rotation_distance_threshold=args.tail_rotation_distance_threshold,
tail_rotation_min_tail_ratio=args.tail_rotation_min_tail_ratio,
tail_rotation_min_angle=args.tail_rotation_min_angle
)
return
if use_svo:
if not ZED_AVAILABLE:
print("ERROR: pyzed not available. Cannot read SVO2 files.")
return
svo_path = Path(args.svo).expanduser().resolve()
if not svo_path.exists():
print(f"ERROR: SVO2 file not found: {svo_path}")
return
print(f"Reading from SVO2 file: {svo_path}")
else:
image_folder = Path(args.image_folder)
if not image_folder.is_dir():
print(f"ERROR: Folder not found: {image_folder}")
return
image_files = []
for ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
image_files.extend(sorted(image_folder.glob(f'*{ext}')))
image_files.extend(sorted(image_folder.glob(f'*{ext.upper()}')))
if not image_files:
print(f"ERROR: No images found in {image_folder}")
return
print(f"Found {len(image_files)} images")
# Load YOLO
print(f"Loading YOLO: {args.yolo_model}")
yolo_model = YOLO(args.yolo_model)
class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
print(f"✓ YOLO loaded. Classes: {class_names}")
# Load SAM
print(f"Loading SAM (device: {args.sam_device})...")
sam_predictor = init_models(device=args.sam_device, seg_model="sam")
sam_device = torch.device(args.sam_device)
print(f"✓ SAM loaded")
# Load point cloud classifier if specified
pointcloud_classifier = None
if args.use_pointcloud_classifier:
# Determine classifier path
classifier_path = args.pointcloud_classifier
if classifier_path is None:
# Try default path
default_path = Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
if default_path.exists():
classifier_path = str(default_path)
print(f"Using default classifier path: {classifier_path}")
else:
print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
print(f" Default path checked: {default_path}")
args.use_pointcloud_classifier = False
# Load classifier if path is available
if classifier_path is not None:
print(f"Loading point cloud classifier from: {classifier_path}")
pointcloud_classifier = load_pointcloud_classifier(
classifier_path,
num_classes=2,
use_normals=False,
device=args.sam_device
)
if pointcloud_classifier is None:
print("ERROR: Failed to load point cloud classifier from:", classifier_path)
print(" Continuing without quality filtering. Check the path and model file.")
args.use_pointcloud_classifier = False
else:
print(f"✓ Point cloud classifier loaded successfully from: {classifier_path}")
print(f" Confidence threshold: {args.pointcloud_classifier_threshold}")
print(f" Only point clouds classified as 'good' with confidence >= {args.pointcloud_classifier_threshold} will be saved")
# Setup output folders if saving
save_mode = bool(args.save_output)
output_images_folder = None
output_cloud_folder = None
video_frames = [] # Store frames for video creation
video_writer = None
if save_mode:
output_base = Path(args.save_output)
# Create subfolder based on input source name
if use_svo:
# Extract filename without extension from SVO2 path
svo_name = svo_path.stem # e.g., "HD1080_SN43186771_16-41-37" from "HD1080_SN43186771_16-41-37.svo2"
output_base = output_base / svo_name
else:
# For image folders, use folder name
folder_name = image_folder.name
output_base = output_base / folder_name
output_images_folder = output_base / "images"
output_cloud_folder = output_base / "cloud"
output_raw_pc_folder = output_base / "raw_pc" if args.save_raw_pointclouds else None
output_images_folder.mkdir(parents=True, exist_ok=True)
output_cloud_folder.mkdir(parents=True, exist_ok=True)
if args.save_raw_pointclouds and output_raw_pc_folder:
output_raw_pc_folder.mkdir(parents=True, exist_ok=True)
print(f"✓ Output base folder: {output_base.resolve()}")
print(f"✓ Output images folder (for video): {output_images_folder.resolve()}")
print(f"✓ Output cloud folder: {output_cloud_folder.resolve()}")
if args.save_raw_pointclouds:
print(f"✓ Output raw point cloud folder (before classifier): {output_raw_pc_folder.resolve()}")
# Check if output folder already exists and contains point clouds
# If so, skip data generation and directly run template matching
if output_base.exists() and output_cloud_folder.exists():
# Check if there are point cloud files
point_cloud_files = list(output_cloud_folder.glob("*.ply"))
if point_cloud_files and args.run_template_matching:
print(f"\n{'='*60}")
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
print(f"Skipping data generation, directly running template matching...")
print(f"{'='*60}")
# Load kept point clouds from file if it exists, otherwise use all point clouds
kept_pointclouds = []
pointcloud_list_path = output_base / "pointclouds_kept_for_template_matching.txt"
if pointcloud_list_path.exists():
print(f" Loading point cloud list from: {pointcloud_list_path.name}")
with open(pointcloud_list_path, 'r', encoding='utf-8') as f:
kept_pointclouds = [line.strip() for line in f if line.strip()]
print(f" Found {len(kept_pointclouds)} point clouds in list")
else:
# If no list file, use all point clouds in the folder
kept_pointclouds = [str(f) for f in point_cloud_files]
print(f" No point cloud list file found, using all {len(kept_pointclouds)} point clouds")
if kept_pointclouds:
# Validate template matching arguments
if not args.template and not args.template_folder:
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
return
# Run template matching directly
print(f"\n{'='*60}")
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
print(f"{'='*60}")
# Create output directory for template matching
template_output_dir = output_base / "template_matching"
template_output_dir.mkdir(parents=True, exist_ok=True)
# Import fish_align_cli function
try:
# Add template_matching directory to path
template_matching_dir = Path(__file__).parent / "template_matching"
if str(template_matching_dir) not in sys.path:
sys.path.insert(0, str(template_matching_dir))
from fish_align_cli import process_folder_mode
# Process all point clouds using folder mode
process_folder_mode(
input_folder=str(output_cloud_folder),
template_path=args.template,
template_folder=args.template_folder,
output_dir=str(template_output_dir),
template_scale_factor=args.template_scale_factor,
debug=False,
debug_dir=None
)
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
return
except Exception as e:
print(f" ERROR: Failed to run template matching: {e}")
import traceback
traceback.print_exc()
return
else:
print(f" Warning: No point clouds found to process")
return
elif point_cloud_files and not args.run_template_matching:
print(f"\n{'='*60}")
print(f"Output folder already exists with {len(point_cloud_files)} point cloud files")
print(f"Template matching not requested (--run-template-matching not set)")
print(f"Skipping processing...")
print(f"{'='*60}")
return
# If folder exists but no point clouds, continue with normal processing
else:
window_name = "Fish Detection & Segmentation Preview"
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
print("\n" + "="*60)
if save_mode:
print("Image Processing Started (Saving to disk)")
else:
print("Image Preview Started")
print("Controls:")
print(" Press any key - Next image")
print(" Press 'q' - Quit")
print("="*60 + "\n")
# Process frames
idx = 0
zed_reader = None
runtime_params = None
left_image_mat = None
depth_mat = None
camera_intrinsics = None
# Tracking for consecutive frames with fish
previous_boxes = None
fish_tracks: Dict[int, List[Dict[str, Any]]] = {} # track_id -> list of frame data
track_bbox_history: Dict[int, List[List[float]]] = {} # track_id -> list of bboxes
track_stationary_count: Dict[int, int] = {} # track_id -> consecutive stationary frames
next_track_id = 0
STATIONARY_THRESHOLD = 10 # Number of consecutive frames without movement
MOVEMENT_THRESHOLD = 5.0 # Minimum pixel movement to consider as moved
# List to track point clouds that passed PointNet++ classifier (if enabled)
kept_pointclouds = []
if use_svo:
# Initialize ZED reader for SVO2
zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
if not zed_reader.open():
print("ERROR: Failed to open SVO2 file")
return
# Get camera intrinsics from ZED
calib_params = zed_reader.zed.get_camera_information().camera_configuration.calibration_parameters
camera_intrinsics = {
"fx": float(calib_params.left_cam.fx),
"fy": float(calib_params.left_cam.fy),
"cx": float(calib_params.left_cam.cx),
"cy": float(calib_params.left_cam.cy),
}
print(f"✓ SVO2 file opened. Camera intrinsics: fx={camera_intrinsics['fx']:.1f}, fy={camera_intrinsics['fy']:.1f}")
runtime_params = sl.RuntimeParameters()
left_image_mat = sl.Mat()
depth_mat = sl.Mat()
else:
# For image folder, use default/estimated intrinsics
# These are typical values for HD1080 ZED camera
camera_intrinsics = {
"fx": 1066.8, # Approximate for HD1080
"fy": 1066.8,
"cx": 960.0,
"cy": 540.0,
}
print(f"Using default camera intrinsics: fx={camera_intrinsics['fx']:.1f}, fy={camera_intrinsics['fy']:.1f}")
while True:
if use_svo:
# Read frame from SVO2
if args.max_frames > 0 and idx >= args.max_frames:
break
err = zed_reader.zed.grab(runtime_params)
if err != sl.ERROR_CODE.SUCCESS:
print(f"\nEnd of SVO2 file or error: {err}")
break
# Retrieve images
zed_reader.zed.retrieve_image(left_image_mat, sl.VIEW.LEFT)
zed_reader.zed.retrieve_measure(depth_mat, sl.MEASURE.DEPTH)
# Convert to numpy
left_np = left_image_mat.get_data()
depth_np = depth_mat.get_data()
# Ensure BGR format (drop alpha if present)
if left_np.shape[2] > 3:
img = left_np[:, :, :3].copy()
else:
img = left_np.copy()
# Convert depth to float32 if needed
if depth_np.dtype != np.float32:
depth_data = depth_np.astype(np.float32)
else:
depth_data = depth_np.copy()
frame_name = f"frame_{idx+1:06d}"
print(f"[{idx + 1}] {frame_name}")
else:
# Read from image folder
if idx >= len(image_files):
break
img_path = image_files[idx]
print(f"[{idx + 1}/{len(image_files)}] {img_path.name}")
# Load image
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
if img is None:
print(f" ERROR: Failed to load")
idx += 1
continue
# Try to find corresponding depth file
depth_data = None
# Look for .npy depth files (common format from ZED)
depth_candidates = [
img_path.parent / f"{img_path.stem}_depth.npy",
img_path.parent / f"{img_path.stem.replace('_rgb', '_depth')}.npy",
img_path.parent / f"{img_path.stem.replace('rgb', 'depth')}.npy",
]
for depth_npy_path in depth_candidates:
if depth_npy_path.exists():
try:
depth_data = np.load(str(depth_npy_path))
print(f" Found depth file: {depth_npy_path.name}")
break
except Exception as e:
print(f" Warning: Could not load depth file {depth_npy_path.name}: {e}")
frame_name = img_path.stem
# Run YOLO tracking
results = yolo_model.track(img, conf=args.conf, imgsz=args.imgsz, verbose=False, persist=True)[0]
num_dets = len(results.boxes) if results.boxes is not None else 0
print(f" Detections: {num_dets}")
# Right side: Segmentation overlay and tracking
individual_masks = []
current_boxes = None
track_ids = []
depth_stats_list = [] # Store depth stats for each detection
try:
if num_dets > 0:
boxes_xyxy = results.boxes.xyxy.cpu().numpy()
current_boxes = boxes_xyxy.tolist()
# Get track IDs from YOLO tracking results
if hasattr(results.boxes, 'id') and results.boxes.id is not None:
track_ids = results.boxes.id.cpu().numpy().astype(int).tolist()
else:
# Fallback: assign sequential IDs if tracking not available
track_ids = list(range(next_track_id, next_track_id + len(current_boxes)))
next_track_id += len(current_boxes)
# Check for stationary fish and filter them out
active_detections = []
active_masks = []
active_track_ids = []
active_boxes = []
for fish_idx, (box, track_id) in enumerate(zip(current_boxes, track_ids)):
# Initialize tracking history if new track
if track_id not in track_bbox_history:
track_bbox_history[track_id] = []
track_stationary_count[track_id] = 0
# Check if bbox is stationary
is_stationary = False
if len(track_bbox_history[track_id]) > 0:
is_stationary = is_bbox_stationary(box, track_bbox_history[track_id], MOVEMENT_THRESHOLD)
if is_stationary:
track_stationary_count[track_id] += 1
if track_stationary_count[track_id] >= STATIONARY_THRESHOLD:
print(f" Fish {fish_idx+1} (Track {track_id}): STATIONARY for {track_stationary_count[track_id]} frames - FILTERED OUT")
continue # Skip this detection
else:
print(f" Fish {fish_idx+1} (Track {track_id}): stationary for {track_stationary_count[track_id]} frames (will filter at {STATIONARY_THRESHOLD})")
else:
# Fish moved, reset stationary count
track_stationary_count[track_id] = 0
# Add to active detections
active_detections.append(fish_idx)
active_boxes.append(box)
active_track_ids.append(track_id)
# Update bbox history (keep last 10 for checking)
track_bbox_history[track_id].append(box)
if len(track_bbox_history[track_id]) > 10:
track_bbox_history[track_id].pop(0)
# Only process active (moving) detections
# Initialize depth_stats_list for all detections (will be filled for active ones)
depth_stats_list = [None] * len(current_boxes) if current_boxes else []
if len(active_detections) == 0:
print(f" All {num_dets} detection(s) are stationary - skipping frame")
num_dets = 0 # Treat as no detections
individual_masks = []
previous_boxes = None # Reset tracking
else:
print(f" Processing {len(active_detections)} active detection(s) (filtered {num_dets - len(active_detections)} stationary)")
# Get masks only for active detections
active_boxes_array = np.array(active_boxes)
print(f" Running SAM segmentation...")
all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device)
# Map masks to active detections
individual_masks = all_masks if all_masks else []
# Calculate depth statistics for each active fish and map to original detection order
if individual_masks and len(individual_masks) > 0 and depth_data is not None:
for mask_idx, (mask, track_id, box) in enumerate(zip(individual_masks, active_track_ids, active_boxes)):
depth_stats = calculate_fish_depth_stats(depth_data, mask)
if depth_stats:
# Map to original detection index
original_idx = active_detections[mask_idx]
depth_stats_list[original_idx] = depth_stats
frame_data = {
"frame_index": idx + 1,
"frame_name": frame_name,
"fish_index": active_detections[mask_idx],
"track_id": int(track_id),
"bbox": box,
"depth_stats": depth_stats
}
if track_id not in fish_tracks:
fish_tracks[track_id] = []
fish_tracks[track_id].append(frame_data)
print(f" Fish {active_detections[mask_idx]+1} (Track {track_id}): "
f"mean_depth={depth_stats['mean_depth_mm']:.1f}mm, "
f"std={depth_stats['std_depth_mm']:.1f}mm")
# Update previous boxes for next frame (only active ones)
previous_boxes = active_boxes
current_boxes = active_boxes # Update for saving
num_dets = len(active_detections) # Update count
if individual_masks and len(individual_masks) > 0:
right_display = create_segmentation_overlay(img.copy(), individual_masks)
cv2.putText(right_display, "Segmentation", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
else:
right_display = img.copy()
cv2.putText(right_display, "Segmentation (failed)", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
else:
right_display = img.copy()
cv2.putText(right_display, "No detections", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)
# Reset tracking when no detections
previous_boxes = None
depth_stats_list = []
except Exception as e:
print(f" ERROR in SAM segmentation: {e}")
import traceback
traceback.print_exc()
right_display = img.copy()
cv2.putText(right_display, f"Error: {str(e)[:30]}", (10, right_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2, cv2.LINE_AA)
previous_boxes = None
depth_stats_list = []
# Left side: Original image with detection boxes and depth info
left_display = draw_detections(img.copy(), results, class_names, depth_stats_list)
if use_svo:
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
else:
info = f"[{idx + 1}/{len(image_files)}] {img_path.name} | Detections: {num_dets}"
cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
# Combine left and right side by side
combined_display = np.hstack([left_display, right_display])
# Scale if needed
if args.scale != 1.0:
h, w = combined_display.shape[:2]
combined_display = cv2.resize(combined_display, (int(w * args.scale), int(h * args.scale)))
# Save or display
if save_mode:
# Save image or collect for video only when fish detected
if num_dets > 0:
if args.save_images:
# Save individual image file
if use_svo:
image_path = output_images_folder / f"{frame_name}_detection.png"
else:
image_path = output_images_folder / f"{img_path.stem}_detection.png"
cv2.imwrite(str(image_path), combined_display)
print(f" Saved image: {image_path.name}")
else:
# Collect for video
video_frames.append(combined_display.copy())
print(f" Added frame {len(video_frames)} to video")
# Generate and save point clouds for each detected fish
if num_dets > 0 and depth_data is not None and individual_masks and len(individual_masks) > 0:
print(f" Generating point clouds for {len(individual_masks)} fish...")
for fish_idx, mask in enumerate(individual_masks):
# Verify mask is being used correctly
if mask is None:
print(f" Warning: Mask {fish_idx + 1} is None, skipping point cloud generation")
continue
# Check mask statistics
mask_pixels = np.sum(mask.astype(bool))
if mask_pixels == 0:
print(f" Warning: Mask {fish_idx + 1} is empty, skipping point cloud generation")
continue
# Validate depth values in mask before generating point cloud
depth_valid, reason, depth_stats = validate_depth_in_mask(
depth_data, mask, max_depth_mm=1200.0, min_depth_mm=100.0
)
if not depth_valid:
print(f" Skipped point cloud {fish_idx + 1}: depth validation failed - {reason}")
continue
points, colors = depth_mask_to_pointcloud(img, depth_data, mask, camera_intrinsics)
if points is not None and len(points) > 0:
# Apply filtering if enabled
original_count = len(points)
if args.filter_pointcloud:
points, colors = filter_point_cloud(
points, colors,
use_clustering_filter=args.use_clustering_filter,
use_density_filter=args.use_density_filter
)
filtered_count = len(points)
# Skip if less than 500 points
if filtered_count < 500:
print(f" Skipped point cloud {fish_idx + 1}: {original_count} -> {filtered_count} points (minimum: 500)")
continue
elif filtered_count == 0:
print(f" Warning: All points filtered out for fish {fish_idx + 1} (original: {original_count} points)")
continue
else:
# Check minimum point count
if len(points) < 500:
print(f" Skipped point cloud {fish_idx + 1}: only {len(points)} points (minimum: 500)")
continue
# Apply largest cluster filtering before classification (if classifier is enabled)
# This removes outliers and keeps only the main fish body cluster
if args.use_pointcloud_classifier:
cluster_before_count = len(points)
try:
points, colors, cluster_info = keep_largest_cluster_with_colors(
points, colors, eps=10.0, min_points=30
)
cluster_after_count = len(points)
if "error" in cluster_info:
print(f" Point cloud {fish_idx + 1}: Clustering failed - {cluster_info.get('error', 'Unknown error')}")
continue
if cluster_after_count < 500:
print(f" Point cloud {fish_idx + 1}: After clustering: {cluster_before_count} -> {cluster_after_count} points (minimum: 500) - SKIPPED")
continue
if cluster_after_count < cluster_before_count:
print(f" Point cloud {fish_idx + 1}: Clustering removed {cluster_before_count - cluster_after_count} points "
f"({cluster_info.get('num_clusters', 0)} clusters found, kept largest with {cluster_info.get('largest_cluster_size', 0)} points)")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Largest cluster filtering failed: {e}")
# Continue with original points if clustering fails
# Save raw point cloud before classification (if flag is enabled)
if args.save_raw_pointclouds and output_raw_pc_folder:
postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else ""
if use_svo:
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
else:
raw_ply_path = output_raw_pc_folder / f"raw_cloud_{idx+1:04d}_{img_path.stem}{postfix}.ply"
write_ply_file(raw_ply_path, points, colors)
print(f" Saved raw point cloud {fish_idx + 1} (before classifier): {raw_ply_path.name} ({len(points)} points)")
# Classify point cloud quality if classifier is available
# mport pdb; pdb.set_trace()
if args.use_pointcloud_classifier:
if pointcloud_classifier is not None:
# Classifier is available - perform classification
is_good, confidence, class_result = classify_pointcloud_array(
pointcloud_classifier, points, colors, num_point=1024, vote_num=3,
confidence_threshold=args.pointcloud_classifier_threshold,
use_cpu=(args.sam_device == "cpu")
)
# Always print the prediction result
prediction = class_result.get("prediction", "unknown")
class_id = class_result.get("class_id", -1)
if not is_good:
# Point cloud is bad - don't save it
if class_id == 1:
# Predicted as good but confidence too low
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (confidence {confidence:.3f} < threshold {args.pointcloud_classifier_threshold:.3f})")
else:
# Predicted as bad
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (BAD quality)")
continue
else:
# Point cloud is good - proceed to save
print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SAVING (GOOD quality, confidence >= {args.pointcloud_classifier_threshold:.3f})")
else:
# Classifier requested but not available
print(f" Point cloud {fish_idx + 1}: WARNING - Classifier requested but not loaded, saving without classification")
else:
# Classifier not enabled - this should not happen if --use-pointcloud-classifier was used
# This means either the flag was not set, or classifier loading failed
if args.use_pointcloud_classifier:
print(f" Point cloud {fish_idx + 1}: ERROR - Classifier flag was set but classifier is None. Check startup logs for loading errors.")
else:
print(f" Point cloud {fish_idx + 1}: Classifier not enabled (use --use-pointcloud-classifier to enable)")
# Evaluate flatness if enabled
if args.use_flatness_filter:
try:
flatness_score, flatness_info = evaluate_flatness_ransac(
points,
distance_threshold=5.0,
ransac_n=3,
num_iterations=1000
)
if flatness_score < args.flatness_threshold:
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% < threshold {args.flatness_threshold:.2f}% - SKIPPED (not flat enough)")
continue
else:
print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% >= threshold {args.flatness_threshold:.2f}% - PASSED")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Flatness evaluation failed: {e}")
# If flatness check is required and fails, skip saving
if args.use_flatness_filter:
print(f" Point cloud {fish_idx + 1}: Skipping due to flatness evaluation error")
continue
# Correct tail rotation if enabled (before saving, after all checks)
if args.correct_tail_rotation:
try:
tail_correction_before_count = len(points)
points, colors, tail_correction_applied, tail_correction_info = correct_tail_rotation_array(
points, colors,
distance_threshold=args.tail_rotation_distance_threshold,
min_tail_ratio=args.tail_rotation_min_tail_ratio,
min_angle_threshold=args.tail_rotation_min_angle,
verbose=True
)
if tail_correction_applied:
print(f" Point cloud {fish_idx + 1}: Tail rotation corrected (angle: {tail_correction_info.get('rotation_angle_degrees', 0):.2f}°)")
except Exception as e:
print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}")
# Continue with original points if correction fails
# Save point cloud (passed all checks)
filtered_count = len(points)
# Create filename with postfix for multiple fish
if len(individual_masks) > 1:
postfix = f"_{fish_idx + 1}"
else:
postfix = ""
if use_svo:
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply"
else:
ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{img_path.stem}{postfix}.ply"
write_ply_file(ply_path, points, colors)
# Track point clouds that passed PointNet++ classifier (if enabled)
# If classifier is enabled, only track those that passed classifier
# If classifier is not enabled, track all saved point clouds
if args.use_pointcloud_classifier and pointcloud_classifier is not None:
# This point cloud passed all checks including classifier
kept_pointclouds.append(str(ply_path))
elif not args.use_pointcloud_classifier:
# Classifier not enabled, track all saved point clouds
kept_pointclouds.append(str(ply_path))
if args.filter_pointcloud:
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({original_count} -> {filtered_count} points)")
else:
print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({filtered_count} points)")
else:
print(f" Warning: No valid points for fish {fish_idx + 1}")
idx += 1
else:
# Display window
print(f" Displaying window...")
try:
cv2.imshow(window_name, combined_display)
print(f" Window displayed. Waiting for keypress...")
except Exception as e:
print(f" ERROR displaying window: {e}")
import traceback
traceback.print_exc()
idx += 1
continue
# Wait for keypress
key = cv2.waitKey(0) & 0xFF
if key == ord('q'):
print("Quit")
break
# Any other key - next image
idx += 1
if not save_mode:
cv2.destroyAllWindows()
if use_svo and zed_reader:
zed_reader.close()
if save_mode:
total_frames = idx
output_base_path = output_images_folder.parent
# Create video from collected frames and save in images folder (only if not saving individual images)
if not args.save_images and video_frames:
if use_svo:
video_path = output_images_folder / f"{svo_name}_preview.mp4"
else:
video_path = output_images_folder / f"{folder_name}_preview.mp4"
print(f"\nCreating video from {len(video_frames)} frames...")
if len(video_frames) > 0:
h, w = video_frames[0].shape[:2]
fps = 10.0
video_writer = _open_video_writer(video_path, fps, (w, h))
for frame in video_frames:
video_writer.write(frame)
video_writer.release()
print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames, {fps} fps)")
elif args.save_images:
print(f"\n✓ Saved {total_frames} frames as individual images")
# Save fish tracking and depth statistics
if fish_tracks:
stats_path = output_base_path / "fish_depth_tracking.json"
tracking_data = {
"source": str(svo_path) if use_svo else str(image_folder),
"total_frames_processed": total_frames,
"total_tracks": len(fish_tracks),
"tracks": {}
}
for track_id, frames in fish_tracks.items():
tracking_data["tracks"][str(track_id)] = {
"track_id": track_id,
"num_frames": len(frames),
"frames": frames,
"depth_summary": {
"mean_depth_overall_mm": float(np.mean([f["depth_stats"]["mean_depth_mm"] for f in frames])),
"std_depth_overall_mm": float(np.std([f["depth_stats"]["mean_depth_mm"] for f in frames])),
"min_depth_mm": float(np.min([f["depth_stats"]["min_depth_mm"] for f in frames])),
"max_depth_mm": float(np.max([f["depth_stats"]["max_depth_mm"] for f in frames])),
}
}
with open(stats_path, 'w', encoding='utf-8') as f:
json.dump(tracking_data, f, indent=2)
print(f"\n✓ Saved fish depth tracking to: {stats_path}")
# Save list of point clouds kept for template matching
if kept_pointclouds:
pointcloud_list_path = output_base_path / "pointclouds_kept_for_template_matching.txt"
with open(pointcloud_list_path, 'w', encoding='utf-8') as f:
for ply_path in kept_pointclouds:
f.write(f"{ply_path}\n")
if args.use_pointcloud_classifier:
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds that passed PointNet++ classifier: {pointcloud_list_path.name}")
else:
print(f"✓ Saved list of {len(kept_pointclouds)} point clouds for template matching: {pointcloud_list_path.name}")
else:
if args.use_pointcloud_classifier:
print(f" Note: PointNet++ classifier was enabled but no point clouds passed the filter")
else:
print(f" Note: No point clouds were saved")
# Run template matching if requested
if args.run_template_matching and kept_pointclouds:
print(f"\n{'='*60}")
print(f"Running template matching on {len(kept_pointclouds)} point clouds...")
print(f"{'='*60}")
# Validate template matching arguments
if not args.template and not args.template_folder:
print(f" ERROR: --run-template-matching requires either --template or --template-folder to be specified")
else:
# Create output directory for template matching
template_output_dir = output_base_path / "template_matching"
template_output_dir.mkdir(parents=True, exist_ok=True)
# Import fish_align_cli function
try:
# Add template_matching directory to path
template_matching_dir = Path(__file__).parent / "template_matching"
if str(template_matching_dir) not in sys.path:
sys.path.insert(0, str(template_matching_dir))
from fish_align_cli import process_folder_mode
# Process all kept point clouds using folder mode
# This will align all point clouds and use the one with maximum length for weight calculation
process_folder_mode(
input_folder=str(output_cloud_folder),
template_path=args.template,
template_folder=args.template_folder,
output_dir=str(template_output_dir),
template_scale_factor=args.template_scale_factor,
debug=False,
debug_dir=None
)
print(f"✓ Template matching completed. Results saved to: {template_output_dir}")
except Exception as e:
print(f" ERROR: Failed to run template matching: {e}")
import traceback
traceback.print_exc()
elif args.run_template_matching and not kept_pointclouds:
if args.use_pointcloud_classifier:
print(f" Warning: Template matching requested but no point clouds passed PointNet++ classifier")
else:
print(f" Warning: Template matching requested but no point clouds were saved")
if use_svo:
print(f"\nProcessing ended. Saved {total_frames} frames to {output_base_path.resolve()}")
else:
print(f"\nProcessing ended. Saved {total_frames}/{len(image_files)} images to {output_base_path.resolve()}")
else:
if use_svo:
print(f"\nPreview ended. Viewed {idx} frames.")
else:
print(f"\nPreview ended. Viewed {idx}/{len(image_files)} images.")
if __name__ == '__main__':
    # Only run when executed directly as a script (not when imported as a module):
    # hand control to the command-line entry point defined above.
    main()