# FishServer/FishMeasure/dataset/zed_reader.py
# Source snapshot: 2026-04-16 14:53:01 +08:00
# NOTE: the hosting web viewer flagged this file as containing ambiguous
# Unicode characters (presumably the CJK text in comments and messages);
# they are intentional.
import pyzed.sl as sl
import cv2
import numpy as np
import os
import time
from scipy.spatial import cKDTree
# Try to import fish detector (optional)
try:
    from detector import FishDetector
    DETECTOR_AVAILABLE = True
except ImportError:
    # detector.py (and its YOLO dependencies) are optional; ZEDReader turns
    # use_yolo_detector off and uses depth-only detection when it is missing.
    DETECTOR_AVAILABLE = False
    FishDetector = None
class ZEDReader:
    """Read frames from a ZED stereo camera or an SVO file and detect fish.

    Detection is depth based: pixels are back-projected into the left-camera
    frame and counted when they fall inside a depth / height window. It can
    optionally be gated by a YOLO ``FishDetector``. Detected fish are exported
    as ASCII PLY point clouds.
    """

    def __init__(self, svo_path=None, camera_mode=False, use_yolo_detector=False, yolo_model_path=None):
        """Create the camera handle and configure its initialisation parameters.

        Args:
            svo_path: SVO file to play back; ignored when ``camera_mode`` is True.
            camera_mode: Use the live camera instead of an SVO file.
            use_yolo_detector: Gate depth detections with the YOLO fish detector.
            yolo_model_path: Model path forwarded to ``FishDetector``.
        """
        # Create the ZED camera object
        self.zed = sl.Camera()
        self.camera_mode = camera_mode
        self.use_yolo_detector = use_yolo_detector
        self.detector = None
        # Output directory; assigned by process_frames(). Initialised here so
        # later code can read it without hasattr() checks.
        self.save_path = None

        # Initialise the YOLO detector if requested; fall back to depth-only
        # detection when it is unavailable or fails to load.
        if use_yolo_detector:
            if not DETECTOR_AVAILABLE:
                print("警告: YOLO检测器不可用。请安装detector.py和相关依赖")
                self.use_yolo_detector = False
            else:
                try:
                    self.detector = FishDetector(model_path=yolo_model_path)
                    print("✓ YOLO鱼类检测器已初始化")
                except Exception as e:
                    print(f"警告: 初始化YOLO检测器失败: {e}")
                    self.use_yolo_detector = False

        # Initialisation parameters
        self.init_params = sl.InitParameters()
        if svo_path and not camera_mode:
            # Play back from the SVO file, as fast as possible (not real time)
            print(f"从SVO文件读取: {svo_path}")
            self.init_params.set_from_svo_file(svo_path)
            self.init_params.svo_real_time_mode = False
        else:
            # Live camera: capture resolution and frame rate only apply here
            print("使用实时相机模式")
            self.init_params.camera_resolution = sl.RESOLUTION.HD1200
            self.init_params.camera_fps = 30  # frame rate
        # Depth settings apply to both sources: every threshold in this class
        # (max_depth, max_y, voxel sizes, KDTree radius) is in millimetres.
        self.init_params.depth_mode = sl.DEPTH_MODE.NEURAL
        self.init_params.coordinate_units = sl.UNIT.MILLIMETER

    def open(self):
        """Open the camera or SVO file and print its calibration.

        Returns:
            bool: True on success, False if ``zed.open`` reported an error.
        """
        err = self.zed.open(self.init_params)
        if err != sl.ERROR_CODE.SUCCESS:
            print(f"错误: {err}")
            return False
        self.print_calibration_params()
        return True

    @staticmethod
    def _calibration_lines(calibration_params, max_disto=None):
        """Format intrinsics and stereo extrinsics as a list of text lines.

        Args:
            calibration_params: ZED calibration parameters (``left_cam``,
                ``right_cam``, ``stereo_transform``).
            max_disto: If given, only the first ``max_disto`` distortion
                coefficients of each camera are listed; otherwise all of them.
        """
        lines = ["=== 相机标定参数 ==="]
        cameras = (
            ("左眼相机参数:", calibration_params.left_cam),
            ("右眼相机参数:", calibration_params.right_cam),
        )
        for title, cam in cameras:
            lines.append(f"\n{title}")
            lines.append(f"焦距: fx={cam.fx}, fy={cam.fy}")
            lines.append(f"主点: cx={cam.cx}, cy={cam.cy}")
            lines.append("畸变参数:")
            disto = list(cam.disto)
            if max_disto is not None:
                disto = disto[:max_disto]
            lines.extend(f"k{i}={d}" for i, d in enumerate(disto, start=1))
        # The 4x4 stereo transform: top-left 3x3 is R, last column is T.
        m = calibration_params.stereo_transform.m
        lines.append("\n立体参数:")
        lines.append("旋转矩阵 R:")
        lines.extend(f"[{m[row][0]}, {m[row][1]}, {m[row][2]}]" for row in range(3))
        lines.append("\n平移向量 T:")
        lines.append(f"[{m[0][3]}, {m[1][3]}, {m[2][3]}]")
        return lines

    def print_calibration_params(self):
        """Print the camera calibration (k1-k5 distortion only).

        Also writes it to ``<save_path>/calibration_params.txt`` when a save
        path is already set.
        """
        calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
        lines = self._calibration_lines(calibration_params, max_disto=5)
        # Leading blank line separates the block from earlier console output.
        print("\n" + lines[0])
        for line in lines[1:]:
            print(line)
        if getattr(self, "save_path", None):
            self.save_calibration_params(calibration_params)

    def save_calibration_params(self, calibration_params):
        """Write the calibration (all distortion coefficients) to ``save_path``.

        Does nothing when no save path is set.
        """
        save_path = getattr(self, "save_path", None)
        if not save_path:
            return
        calib_file = os.path.join(save_path, 'calibration_params.txt')
        # Explicit UTF-8: the labels are Chinese, which the platform default
        # encoding (e.g. cp1252 on Windows) may not be able to encode.
        with open(calib_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(self._calibration_lines(calibration_params)) + "\n")

    def start_recording(self, filename=None):
        """Start H264 SVO recording.

        Args:
            filename: Output SVO path; defaults to a timestamped name in the CWD.

        Returns:
            bool: True if recording started.
        """
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"recording_{timestamp}.svo"
        recording_param = sl.RecordingParameters(filename, sl.SVO_COMPRESSION_MODE.H264)
        err = self.zed.enable_recording(recording_param)
        if err != sl.ERROR_CODE.SUCCESS:
            print(f"无法开始录制: {err}")
            return False
        print(f"开始录制到文件: {filename}")
        return True

    def ensure_save_path(self, save_path):
        """Create ``save_path`` (and parents) if it does not exist.

        Returns:
            bool: True if a path was given, False for an empty/None path.
        """
        if not save_path:
            return False
        if not os.path.exists(save_path):
            # exist_ok guards against a concurrent creation between the check and makedirs.
            os.makedirs(save_path, exist_ok=True)
            print(f"创建保存目录: {save_path}")
        return True

    @staticmethod
    def _back_project(depth_np, calib_params):
        """Back-project a depth map into left-camera coordinates.

        Uses the pinhole model with the left-camera intrinsics. X grows with
        the column index and Y with the row index (i.e. downward in the image).

        Returns:
            tuple: (X, Y, Z) arrays with the shape of ``depth_np``.
        """
        left = calib_params.left_cam
        height, width = depth_np.shape
        cols, rows = np.meshgrid(np.arange(width), np.arange(height))
        Z = depth_np
        # Infinite depth times a zero pixel offset gives NaN; such pixels are
        # rejected by the callers' finiteness / range checks, so mute the warning.
        with np.errstate(invalid="ignore"):
            X = (cols - left.cx) * Z / left.fx
            Y = (rows - left.cy) * Z / left.fy
        return X, Y, Z

    def _fish_mask(self, depth_np, calib_params, max_depth, max_y):
        """Back-project ``depth_np`` and select pixels inside the fish window.

        Returns:
            tuple: (X, Y, Z, mask); mask is True where 0 < Z <= max_depth and Y < max_y.
        """
        X, Y, Z = self._back_project(depth_np, calib_params)
        # NaN / +-inf depths fail at least one comparison, so they drop out.
        with np.errstate(invalid="ignore"):
            mask = (Z > 0) & (Z <= max_depth) & (Y < max_y)
        return X, Y, Z, mask

    def detect_fish(self, depth_np, calib_params, min_points=100, max_depth=1200, max_y=150):
        """Decide whether the depth map contains a fish.

        A pixel counts when its depth is in (0, max_depth] mm and its camera
        Y coordinate is below ``max_y`` mm. Y grows downward, so this keeps
        points above a plane ``max_y`` mm below the optical axis.

        Args:
            depth_np: Depth map (H, W) in mm.
            calib_params: Camera calibration (left intrinsics are used).
            min_points: Minimum number of counted pixels for a detection (default 100).
            max_depth: Maximum depth in mm (default 1200).
            max_y: Maximum Y coordinate in mm (default 150).

        Returns:
            tuple: (fish detected (bool), number of counted pixels (int)).
        """
        _, _, _, fish_mask = self._fish_mask(depth_np, calib_params, max_depth, max_y)
        fish_point_count = int(np.count_nonzero(fish_mask))
        return fish_point_count >= min_points, fish_point_count

    def downsample_point_cloud(self, points, colors, voxel_size=5.0):
        """Voxel-grid downsample, keeping the first point of every voxel.

        Args:
            points: (N, 3) coordinates.
            colors: (N, 3) colours aligned with ``points``.
            voxel_size: Voxel edge length in mm (default 5.0).

        Returns:
            tuple: (points, colors) of the kept points, in their original order.
        """
        if len(points) == 0:
            return points, colors
        voxel_indices = np.floor(points / voxel_size).astype(np.int32)
        # return_index gives the first occurrence of each voxel; sorting those
        # indices restores the original point order.
        _, first_indices = np.unique(voxel_indices, axis=0, return_index=True)
        kept_indices = np.sort(first_indices)
        return points[kept_indices], colors[kept_indices]

    def filter_point_cloud_kdtree(self, points, colors, radius=5.0, min_neighbors=50):
        """Remove sparse points using a KDTree radius search.

        A point is kept when at least ``min_neighbors`` points (itself
        included) lie within ``radius``.

        Args:
            points: (N, 3) coordinates.
            colors: (N, 3) colours aligned with ``points``.
            radius: Search radius in mm (default 5.0).
            min_neighbors: Minimum neighbour count (default 50).

        Returns:
            tuple: (filtered points, filtered colors).
        """
        if len(points) == 0:
            return points, colors
        tree = cKDTree(points)
        # One vectorised query returning per-point neighbour counts.
        neighbor_counts = np.asarray(tree.query_ball_point(points, radius, return_length=True))
        valid_mask = neighbor_counts >= min_neighbors
        return points[valid_mask], colors[valid_mask]

    def _clean_point_cloud(self, points, colors):
        """Shared cleaning pipeline: 3 mm voxel, KDTree (5 mm, >=10 neighbours), 5 mm voxel.

        Returns:
            tuple: (points, colors, number of points left after the KDTree stage).
        """
        points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
        points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
        kdtree_count = len(points)
        points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
        return points, colors, kdtree_count

    @staticmethod
    def _write_ply(ply_file, points, colors):
        """Write an ASCII PLY with float xyz (3 decimals) and uchar RGB per vertex."""
        header = (
            "ply\n"
            "format ascii 1.0\n"
            f"element vertex {len(points)}\n"
            "property float x\n"
            "property float y\n"
            "property float z\n"
            "property uchar red\n"
            "property uchar green\n"
            "property uchar blue\n"
            "end_header\n"
        )
        with open(ply_file, 'w', encoding='ascii') as f:
            f.write(header)
            if len(points):
                data = np.hstack([np.asarray(points, dtype=np.float64),
                                  np.asarray(colors, dtype=np.float64)])
                np.savetxt(f, data, fmt="%.3f %.3f %.3f %d %d %d")

    def save_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name):
        """Save the whole scene as ``<save_path>/<base_name>_pointcloud.ply``.

        Args:
            left_np: Left image (H, W, C), BGR channel order in the first 3 channels.
            depth_np: Depth map (H, W) in mm.
            calib_params: Camera calibration (left intrinsics are used).
            save_path: Output directory.
            base_name: File name prefix.
        """
        X, Y, Z = self._back_project(depth_np, calib_params)
        # Only finite positive depths: +inf would otherwise reach the voxel
        # grid and the KDTree.
        valid_points = np.isfinite(Z) & (Z > 0)
        points = np.stack([X[valid_points], Y[valid_points], Z[valid_points]], axis=1)
        colors = left_np[valid_points][:, [2, 1, 0]]  # BGR to RGB
        points, colors, _ = self._clean_point_cloud(points, colors)
        print(f"下采样后点数: {len(points)} (原始点数: {np.sum(valid_points)})")
        ply_file = os.path.join(save_path, f"{base_name}_pointcloud.ply")
        self._write_ply(ply_file, points, colors)
        print(f"保存点云文件: {ply_file}")

    def save_images(self, save_path, left_np, right_np, depth_np, frame_count):
        """Save left/right images, raw depth (.npy) and a scene point cloud for one frame.

        Files are named ``<timestamp>_<frame_count:06d>_*``.
        """
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        base_name = f"{timestamp}_{frame_count:06d}"
        cv2.imwrite(os.path.join(save_path, f"{base_name}_left.png"), left_np)
        cv2.imwrite(os.path.join(save_path, f"{base_name}_right.png"), right_np)
        # Raw depth values, unmodified
        depth_file = os.path.join(save_path, f"{base_name}_depth.npy")
        np.save(depth_file, depth_np)
        calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
        self.save_point_cloud(left_np, depth_np, calibration_params, save_path, base_name)
        print(f"保存图像、深度数据和点云: {base_name}")

    def save_fish_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
        """Save the points of a detected fish to ``<save_path>/pointcloud``.

        Only points passing the same window as ``detect_fish`` are kept, then
        cleaned by the shared voxel/KDTree pipeline.

        Args:
            left_np: Left image (H, W, C), BGR channel order in the first 3 channels.
            depth_np: Depth map (H, W) in mm.
            calib_params: Camera calibration.
            save_path: Root output directory.
            base_name: File name prefix.
            max_depth: Maximum depth in mm (default 1000).
            max_y: Maximum Y coordinate in mm (default 100).
            min_points: Minimum points after cleaning; fewer means nothing is saved (default 500).

        Returns:
            bool: True if a PLY file was written.
        """
        X, Y, Z, fish_mask = self._fish_mask(depth_np, calib_params, max_depth, max_y)
        raw_count = np.sum(fish_mask)
        points = np.stack([X[fish_mask], Y[fish_mask], Z[fish_mask]], axis=1)
        colors = left_np[fish_mask][:, [2, 1, 0]]  # BGR to RGB
        points, colors, kdtree_count = self._clean_point_cloud(points, colors)
        print(f"KDTree过滤后点数: {kdtree_count} (原始点数: {raw_count})")
        final_point_count = len(points)
        print(f"下采样后点数: {final_point_count} (原始点数: {raw_count})")
        if final_point_count < min_points:
            print(f"跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
            return False
        pointcloud_dir = os.path.join(save_path, "pointcloud")
        self.ensure_save_path(pointcloud_dir)
        ply_file = os.path.join(pointcloud_dir, f"{base_name}_pointcloud.ply")
        self._write_ply(ply_file, points, colors)
        print(f"保存点云文件: {ply_file} (点数: {final_point_count})")
        return True

    def save_fish_point_cloud_debug(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
        """Debug-mode variant of ``save_fish_point_cloud``.

        Writes ``<save_path>/pointcloud/<base_name>_fish_detected.ply`` and
        prefixes logs with ``[DEBUG]``. Parameters match ``save_fish_point_cloud``.

        Returns:
            bool: True if a PLY file was written.
        """
        X, Y, Z, fish_mask = self._fish_mask(depth_np, calib_params, max_depth, max_y)
        raw_count = np.sum(fish_mask)
        points = np.stack([X[fish_mask], Y[fish_mask], Z[fish_mask]], axis=1)
        colors = left_np[fish_mask][:, [2, 1, 0]]  # BGR to RGB
        if len(points) == 0:
            print("[DEBUG] 跳过保存:满足过滤条件的原始点数为 0")
            return False
        points, colors, kdtree_count = self._clean_point_cloud(points, colors)
        print(f"[DEBUG] KDTree过滤后点数: {kdtree_count} (原始点数: {raw_count})")
        final_point_count = len(points)
        print(f"[DEBUG] 下采样后点数: {final_point_count} (原始点数: {raw_count})")
        if final_point_count < min_points:
            print(f"[DEBUG] 跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
            return False
        pointcloud_dir = os.path.join(save_path, "pointcloud")
        self.ensure_save_path(pointcloud_dir)
        ply_file = os.path.join(pointcloud_dir, f"{base_name}_fish_detected.ply")
        self._write_ply(ply_file, points, colors)
        print(f"[DEBUG] 保存检测到的鱼类点云: {ply_file} (点数: {final_point_count})")
        return True

    def process_frames(self, save_path=None, record=False, save_all=False, debug=False, max_depth=1000,
                       max_y=100, min_points=50, show_yolo_bbox=True, frame_skip_interval=5,
                       save_delay_frames=12):
        """Grab frames, detect fish and save point clouds / annotated images.

        Args:
            save_path: Output directory; None disables all saving.
            record: Also record an SVO file (live camera mode only).
            save_all: Save images, raw depth and a scene point cloud for every processed frame.
            debug: Save a fish point cloud on every processed frame with a depth
                detection (gated by YOLO when enabled). Otherwise each new fish
                is saved once, ``save_delay_frames`` frames after it first appears.
            max_depth: Maximum depth in mm (default 1000).
            max_y: Maximum Y coordinate in mm (default 100).
            min_points: Minimum points for a detection and for a saved cloud (default 50).
            show_yolo_bbox: Draw YOLO bounding boxes on saved images (default True).
            frame_skip_interval: Process one grabbed frame out of this many (default 5, min 1).
            save_delay_frames: Minimum grabbed frames between a fish's first
                detection and saving its point cloud (default 12).
        """
        self.save_path = save_path
        if save_path:
            self.ensure_save_path(save_path)

        # Start SVO recording (live camera only); remember whether it really
        # started so it is only disabled when it was enabled.
        recording = False
        if record and self.camera_mode:
            svo_name = f"recording_{time.strftime('%Y%m%d_%H%M%S')}.svo"
            svo_file = os.path.join(save_path, svo_name) if save_path else svo_name
            recording = self.start_recording(svo_file)

        calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
        # open() ran before save_path was known, so persist the calibration now.
        if save_path:
            self.save_calibration_params(calibration_params)

        runtime_parameters = sl.RuntimeParameters()
        left_image = sl.Mat()
        right_image = sl.Mat()
        depth_map = sl.Mat()

        if save_path:
            images_dir = os.path.join(save_path, "images")
            pointcloud_dir = os.path.join(save_path, "pointcloud")
            self.ensure_save_path(images_dir)
            self.ensure_save_path(pointcloud_dir)
            print(f"图像将保存到: {images_dir}")
            print(f"点云将保存到: {pointcloud_dir}")

        frame_skip_interval = max(1, int(frame_skip_interval))
        frame_count = 0
        start_time = time.time()
        fish_detection_count = 0
        # Non-debug tracking state
        previous_fish_detected = False
        active_fish_tracking = []  # (first_detection_frame, fish_number) pairs
        fish_counter = 0  # numbering for fish1, fish2, ...

        try:
            while True:
                if self.zed.grab(runtime_parameters) != sl.ERROR_CODE.SUCCESS:
                    if not self.camera_mode:
                        break  # SVO mode: end of file (or read error)
                    continue  # live camera: try the next grab
                frame_count += 1
                if frame_count % frame_skip_interval != 0:
                    continue

                self.zed.retrieve_image(left_image, sl.VIEW.LEFT)
                self.zed.retrieve_image(right_image, sl.VIEW.RIGHT)
                self.zed.retrieve_measure(depth_map, sl.MEASURE.DEPTH)
                left_np = left_image.get_data()
                right_np = right_image.get_data()
                depth_np = depth_map.get_data()

                # Whether a point cloud was saved for this frame (decides whether
                # the annotated image is saved); reset before any save branch.
                pointcloud_saved_this_frame = False

                yolo_detected = False
                yolo_bboxes = []
                if self.use_yolo_detector:
                    try:
                        yolo_bboxes = self.detector.detect_fish(left_np)
                        yolo_detected = len(yolo_bboxes) > 0
                        if yolo_detected:
                            print(f"[帧 {frame_count}] YOLO检测到 {len(yolo_bboxes)} 条鱼")
                    except Exception as e:
                        print(f"[帧 {frame_count}] YOLO检测错误: {e}")

                fish_detected, fish_point_count = self.detect_fish(
                    depth_np, calibration_params, min_points=min_points, max_depth=max_depth, max_y=max_y)
                # With YOLO enabled, a detection requires YOLO and depth to agree.
                if self.use_yolo_detector:
                    detection_triggered = yolo_detected and fish_detected
                else:
                    detection_triggered = fish_detected

                if not debug and save_path:
                    # A new fish appears on a False -> True transition
                    if detection_triggered and not previous_fish_detected:
                        fish_counter += 1
                        active_fish_tracking.append((frame_count, fish_counter))
                        detection_method = "YOLO+深度" if self.use_yolo_detector else "深度"
                        print(f"[帧 {frame_count}] 检测到新鱼类 (fish{fish_counter}){detection_method}检测,满足条件的点数: {fish_point_count}")

                    # Save each tracked fish once save_delay_frames have elapsed.
                    # ">=" rather than "==": only every frame_skip_interval-th frame
                    # reaches this point, so an exact offset may never be hit.
                    still_pending = []
                    for first_frame, fish_num in active_fish_tracking:
                        if frame_count - first_frame < save_delay_frames:
                            still_pending.append((first_frame, fish_num))
                            continue
                        # Re-check with YOLO using this frame's detection result
                        if self.use_yolo_detector and not yolo_detected:
                            print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (YOLO未检测到鱼)")
                            continue
                        timestamp = time.strftime("%Y%m%d_%H%M%S")
                        base_name = f"{timestamp}_{frame_count:06d}_fish{fish_num}"
                        saved = self.save_fish_point_cloud(left_np, depth_np, calibration_params, save_path, base_name,
                                                           max_depth=max_depth, max_y=max_y, min_points=min_points)
                        if saved:
                            pointcloud_saved_this_frame = True
                            print(f"[帧 {frame_count}] 保存鱼类点云: fish{fish_num} (首次检测于帧 {first_frame})")
                        else:
                            print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (过滤后点数不足)")
                    active_fish_tracking = still_pending

                    # Fish disappeared (True -> False); pending saves are kept
                    if not detection_triggered and previous_fish_detected:
                        print(f"[帧 {frame_count}] 鱼类消失,但保留跟踪直到保存完成")

                if fish_detected:
                    fish_detection_count += 1
                    if debug:
                        print(f"[帧 {frame_count}] 检测到鱼类!满足条件的点数: {fish_point_count}")
                    if debug and save_path:
                        # With YOLO enabled, only save when YOLO also sees a fish
                        if self.use_yolo_detector and not yolo_detected:
                            print(f"[帧 {frame_count}] [DEBUG] 跳过保存YOLO未检测到鱼")
                        else:
                            timestamp = time.strftime("%Y%m%d_%H%M%S")
                            base_name = f"{timestamp}_{frame_count:06d}_fish"
                            saved = self.save_fish_point_cloud_debug(left_np, depth_np, calibration_params, save_path, base_name,
                                                                     max_depth=max_depth, max_y=max_y, min_points=min_points)
                            if saved:
                                pointcloud_saved_this_frame = True
                            else:
                                print(f"[帧 {frame_count}] [DEBUG] 跳过保存:过滤后点数不足")

                previous_fish_detected = detection_triggered

                if save_all and save_path:
                    self.save_images(save_path, left_np, right_np, depth_np, frame_count)

                # Draw on a copy so the retrieved image buffer is not modified
                display_left = left_np.copy()
                if self.use_yolo_detector and show_yolo_bbox and yolo_detected:
                    try:
                        display_left = self.detector.draw_bboxes(display_left, yolo_bboxes)
                    except Exception as e:
                        print(f"[帧 {frame_count}] 绘制边界框错误: {e}")

                if self.camera_mode:
                    elapsed_time = time.time() - start_time
                    fps = frame_count / elapsed_time if elapsed_time > 0 else 0
                    status = f"Frame: {frame_count} (processed), FPS: {fps:.1f}"
                    cv2.putText(display_left, status, (10, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

                if self.use_yolo_detector:
                    detection_status = f"YOLO: {'YES' if yolo_detected else 'NO'} | Depth: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
                    if yolo_detected and fish_detected:
                        status_color = (0, 255, 0)
                    elif yolo_detected or fish_detected:
                        status_color = (0, 165, 255)
                    else:
                        status_color = (0, 0, 255)
                else:
                    detection_status = f"Fish: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
                    status_color = (0, 255, 0) if fish_detected else (0, 0, 255)
                cv2.putText(display_left, detection_status, (10, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)

                # Save the annotated image only when a point cloud was saved
                if save_path and pointcloud_saved_this_frame:
                    images_dir = os.path.join(save_path, "images")
                    timestamp = time.strftime("%Y%m%d_%H%M%S")
                    image_filename = f"{timestamp}_{frame_count:06d}_bbox.png"
                    image_path = os.path.join(images_dir, image_filename)
                    cv2.imwrite(image_path, display_left)
                    print(f"[帧 {frame_count}] 保存带边界框的图像: {image_path}")
        finally:
            if recording:
                self.zed.disable_recording()
            self.close()
            # Printed in finally so stats also appear after Ctrl+C, the only
            # way the live-camera loop ends.
            print(f"总共处理了 {frame_count}")
            print(f"检测到鱼类的帧数: {fish_detection_count}")
            if self.camera_mode:
                elapsed_time = time.time() - start_time
                if elapsed_time > 0:
                    print(f"平均帧率: {frame_count/elapsed_time:.1f} FPS")

    def close(self):
        """Close the camera / SVO file."""
        self.zed.close()
def main():
    """Command-line entry point: parse arguments, open the source and process frames."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--svo', type=str, help='SVO文件路径', default=None)
    parser.add_argument('--save_path', type=str, help='保存图像的路径', default='./zed_images')
    parser.add_argument('--record', action='store_true', help='是否录制新的SVO文件')
    parser.add_argument('--save_all', action='store_true', help='是否保存所有帧')
    parser.add_argument('--camera', action='store_true', help='使用实时相机模式')
    parser.add_argument('--debug', action='store_true', help='启用debug模式检测到鱼时自动保存点云')
    parser.add_argument('--max_depth', type=float, help='最大深度值mm默认1000', default=1000)
    parser.add_argument('--max_y', type=float, help='最大Y坐标值mm默认100', default=100)
    # Help text now matches the actual default (500).
    parser.add_argument('--min_points', type=int, help='检测到鱼所需的最少点数默认500', default=500)
    parser.add_argument('--use_yolo', action='store_true', help='使用YOLO检测器检测鱼只有YOLO检测到边界框时才保存点云')
    parser.add_argument('--yolo_model', type=str, help='YOLO模型路径.pt文件', default=None)
    # store_true with default=True could never be switched off; a paired
    # negative flag writing to the same destination makes it configurable.
    parser.add_argument('--show_yolo_bbox', dest='show_yolo_bbox', action='store_true', help='显示YOLO检测的边界框默认启用')
    parser.add_argument('--no_show_yolo_bbox', dest='show_yolo_bbox', action='store_false', help='不显示YOLO检测的边界框')
    parser.set_defaults(show_yolo_bbox=True)
    args = parser.parse_args()

    if args.camera and args.svo:
        print("错误不能同时指定相机模式和SVO文件")
        return

    zed_reader = ZEDReader(args.svo, camera_mode=args.camera,
                           use_yolo_detector=args.use_yolo,
                           yolo_model_path=args.yolo_model)
    # Open the camera or file, then process frames
    if zed_reader.open():
        zed_reader.process_frames(args.save_path, args.record, args.save_all, args.debug,
                                  args.max_depth, args.max_y, args.min_points, args.show_yolo_bbox)


if __name__ == "__main__":
    main()