Files

802 lines
37 KiB
Python
Raw Permalink Normal View History

import pyzed.sl as sl
import cv2
import numpy as np
import os
import time
from scipy.spatial import cKDTree
# The YOLO fish detector is an optional dependency. Callers must check
# DETECTOR_AVAILABLE before using FishDetector, which is None when the
# import fails.
try:
    from detector import FishDetector
except ImportError:
    FishDetector = None
    DETECTOR_AVAILABLE = False
else:
    DETECTOR_AVAILABLE = True
class ZEDReader:
def __init__(self, svo_path=None, camera_mode=False, use_yolo_detector=False, yolo_model_path=None):
    """Create the ZED camera object and prepare its initialisation parameters.

    Nothing is opened here; call ``open()`` afterwards.

    Args:
        svo_path: Path of an SVO recording to read. Only used when
            ``camera_mode`` is false.
        camera_mode: When true, configure a live camera even if
            ``svo_path`` is given.
        use_yolo_detector: Request the optional YOLO ``FishDetector``. The
            flag is reset to ``False`` (with a printed warning) when the
            detector module is unavailable or its construction fails.
        yolo_model_path: Forwarded to ``FishDetector(model_path=...)``.
    """
    # Create the ZED camera object.
    self.zed = sl.Camera()
    self.camera_mode = camera_mode
    self.use_yolo_detector = use_yolo_detector
    self.detector = None
    # Always defined so other methods can test it without hasattr();
    # process_frames() replaces it with the real output directory.
    self.save_path = None

    # Initialise the YOLO detector if requested.
    if use_yolo_detector:
        if not DETECTOR_AVAILABLE:
            print("警告: YOLO检测器不可用。请安装detector.py和相关依赖")
            self.use_yolo_detector = False
        else:
            try:
                self.detector = FishDetector(model_path=yolo_model_path)
                print("✓ YOLO鱼类检测器已初始化")
            except Exception as e:
                # Best effort: a broken detector downgrades to depth-only
                # detection instead of aborting the whole reader.
                print(f"警告: 初始化YOLO检测器失败: {e}")
                self.use_yolo_detector = False

    # Set up the initialisation parameters.
    self.init_params = sl.InitParameters()
    if svo_path and not camera_mode:
        # Read from the SVO file, decoding as fast as possible rather than
        # at the recorded frame rate.
        print(f"从SVO文件读取: {svo_path}")
        self.init_params.set_from_svo_file(svo_path)
        self.init_params.svo_real_time_mode = False
    else:
        # Live camera settings.
        print("使用实时相机模式")
        self.init_params.camera_resolution = sl.RESOLUTION.HD1200

    # Applied to both sources: every depth threshold in this class is
    # expressed in millimetres, so playback must use the same units and
    # depth mode as live capture.
    self.init_params.depth_mode = sl.DEPTH_MODE.NEURAL
    self.init_params.coordinate_units = sl.UNIT.MILLIMETER
    self.init_params.camera_fps = 30  # requested frame rate
def open(self):
    """Open the camera or SVO file described by ``self.init_params``.

    Returns:
        bool: ``True`` when the source opened (the calibration parameters
        are printed as a side effect); ``False`` after printing the error
        code returned by the SDK.
    """
    status = self.zed.open(self.init_params)
    if status == sl.ERROR_CODE.SUCCESS:
        # Fetch and print the calibration parameters once the source is open.
        self.print_calibration_params()
        return True
    print(f"错误: {status}")
    return False
def print_calibration_params(self):
    """Print the stereo calibration of the opened camera to stdout.

    Prints focal lengths, principal point and the first five distortion
    coefficients of each eye, followed by the 3x3 rotation block and the
    translation column of the stereo transform. When ``self.save_path`` is
    set, the parameters are also written to disk via
    ``save_calibration_params``.
    """
    calib = self.zed.get_camera_information().camera_configuration.calibration_parameters
    print("\n=== 相机标定参数 ===")

    # Left eye first, then right eye; both use the same layout.
    # NOTE(review): the ZED SDK reference describes ``disto`` as
    # [k1, k2, p1, p2, k3, ...]; if so, the k3..k5 labels printed here are
    # misleading - confirm before relying on them.
    for heading, eye in (("\n左眼相机参数:", "left_cam"), ("\n右眼相机参数:", "right_cam")):
        cam = getattr(calib, eye)
        print(heading)
        print(f"焦距: fx={cam.fx}, fy={cam.fy}")
        print(f"主点: cx={cam.cx}, cy={cam.cy}")
        print("畸变参数:")
        for idx in range(5):
            print(f"k{idx + 1}={cam.disto[idx]}")

    print("\n立体参数:")
    # Stereo transform matrix: rotation in the upper-left 3x3 block,
    # translation in the fourth column.
    transform = calib.stereo_transform.m
    print("旋转矩阵 R:")
    for row in range(3):
        print(f"[{transform[row][0]}, {transform[row][1]}, {transform[row][2]}]")
    print("\n平移向量 T:")
    print(f"[{transform[0][3]}, {transform[1][3]}, {transform[2][3]}]")

    # Persist the parameters as well when an output directory is known.
    if getattr(self, 'save_path', None):
        self.save_calibration_params(calib)
def save_calibration_params(self, calibration_params):
    """Write the calibration parameters to ``calibration_params.txt``.

    The file is created inside ``self.save_path``; nothing happens when
    that attribute is missing or empty. The text is written as UTF-8
    explicitly because it contains non-ASCII labels, which the platform
    default encoding cannot always represent.

    Args:
        calibration_params: Object exposing ``left_cam`` / ``right_cam``
            (each with ``fx``, ``fy``, ``cx``, ``cy``, ``disto``) and
            ``stereo_transform.m``.
    """
    save_dir = getattr(self, 'save_path', None)
    if not save_dir:
        return

    lines = ["=== 相机标定参数 ===\n"]
    # Left eye, then right eye; every distortion coefficient is listed.
    for heading, cam in (("\n左眼相机参数:\n", calibration_params.left_cam),
                         ("\n右眼相机参数:\n", calibration_params.right_cam)):
        lines.append(heading)
        lines.append(f"焦距: fx={cam.fx}, fy={cam.fy}\n")
        lines.append(f"主点: cx={cam.cx}, cy={cam.cy}\n")
        lines.append("畸变参数:\n")
        lines.extend(f"k{i + 1}={d}\n" for i, d in enumerate(cam.disto))

    # Stereo parameters: rotation is the upper-left 3x3 block of the
    # transform matrix, translation is its fourth column.
    lines.append("\n立体参数:\n")
    transform = calibration_params.stereo_transform.m
    lines.append("旋转矩阵 R:\n")
    lines.extend(
        f"[{transform[r][0]}, {transform[r][1]}, {transform[r][2]}]\n" for r in range(3)
    )
    lines.append("\n平移向量 T:\n")
    lines.append(f"[{transform[0][3]}, {transform[1][3]}, {transform[2][3]}]\n")

    calib_file = os.path.join(save_dir, 'calibration_params.txt')
    with open(calib_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)
def start_recording(self, filename=None):
    """Start recording the camera stream to an H.264-compressed SVO file.

    Args:
        filename: Output file name. Defaults to
            ``recording_<YYYYmmdd_HHMMSS>.svo`` built from the current time.

    Returns:
        bool: ``True`` when the SDK accepted the request, ``False`` (after
        printing the error code) otherwise.
    """
    if filename is None:
        # Fall back to a timestamp-based file name.
        filename = "recording_{}.svo".format(time.strftime("%Y%m%d_%H%M%S"))

    params = sl.RecordingParameters(filename, sl.SVO_COMPRESSION_MODE.H264)
    status = self.zed.enable_recording(params)
    if status == sl.ERROR_CODE.SUCCESS:
        print(f"开始录制到文件: {filename}")
        return True
    print(f"无法开始录制: {status}")
    return False
def ensure_save_path(self, save_path):
    """Make sure the output directory exists, creating it when necessary.

    Args:
        save_path: Directory to create (intermediate directories included).

    Returns:
        bool: ``False`` when ``save_path`` is empty or ``None``, ``True``
        otherwise.
    """
    if not save_path:
        return False
    if not os.path.exists(save_path):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(save_path, exist_ok=True)
        print(f"创建保存目录: {save_path}")
    return True
def detect_fish(self, depth_np, calib_params, min_points=100, max_depth=1200, max_y=150):
    """Decide from the depth map alone whether a fish is in view.

    A pixel counts towards the detection when
      - its depth is valid and within ``(0, max_depth]`` mm, and
      - its back-projected Y coordinate is below ``max_y`` mm.
    Y is computed as ``(row - cy) * depth / fy``, so it grows with the
    image row index.

    Args:
        depth_np: 2-D depth map (H, W) in millimetres. NaN / infinite
            entries never satisfy the depth test and are ignored.
        calib_params: Calibration object; only ``left_cam.fy`` and
            ``left_cam.cy`` are used.
        min_points: Minimum number of qualifying pixels for a detection.
        max_depth: Maximum depth in mm.
        max_y: Maximum Y coordinate in mm.

    Returns:
        tuple: ``(detected, count)`` as a plain ``bool`` and ``int``.
    """
    depth = np.asarray(depth_np)
    height, _width = depth.shape
    fy = calib_params.left_cam.fy
    cy = calib_params.left_cam.cy

    # Only Y is needed for the test; a column vector of row offsets
    # broadcasts against the depth map without building a full meshgrid.
    row_offsets = (np.arange(height) - cy)[:, np.newaxis]
    with np.errstate(invalid='ignore', over='ignore'):
        # Invalid depths (NaN / inf) may produce NaN here; they are
        # rejected by the depth test below.
        y_mm = row_offsets * depth / fy
        fish_mask = (depth > 0) & (depth <= max_depth) & (y_mm < max_y)

    fish_point_count = int(np.count_nonzero(fish_mask))
    return fish_point_count >= min_points, fish_point_count
def downsample_point_cloud(self, points, colors, voxel_size=5.0):
    """Thin a point cloud with a voxel grid, keeping one point per voxel.

    For every occupied voxel the first point (in input order) is kept, and
    the surviving points stay in their original relative order.

    Args:
        points: Point coordinates, array of shape (N, 3).
        colors: Per-point colours, array with N rows.
        voxel_size: Voxel edge length, in the same unit as ``points``
            (millimetres elsewhere in this file).

    Returns:
        tuple: ``(points, colors)`` restricted to the kept points.
    """
    if len(points) == 0:
        return points, colors

    # Integer voxel coordinates of every point.
    voxel_indices = np.floor(points / voxel_size).astype(np.int32)

    # return_index yields the first occurrence of each distinct voxel;
    # sorting those indices restores input order. This replaces a
    # per-point Python loop with a single vectorised call.
    _, first_seen = np.unique(voxel_indices, axis=0, return_index=True)
    kept_indices = np.sort(first_seen)

    return points[kept_indices], colors[kept_indices]
def filter_point_cloud_kdtree(self, points, colors, radius=5.0, min_neighbors=50):
    """Drop sparse points using a KD-tree radius search.

    A point is removed when fewer than ``min_neighbors`` points (the point
    itself included) lie within ``radius`` of it.

    Args:
        points: Point coordinates, array of shape (N, 3).
        colors: Per-point colours, array with N rows.
        radius: Search radius, in the same unit as ``points``.
        min_neighbors: Minimum neighbour count required to keep a point.

    Returns:
        tuple: ``(points, colors)`` restricted to the dense points.
    """
    if len(points) == 0:
        return points, colors

    tree = cKDTree(points)
    # One batched query for all points; return_length makes SciPy return
    # only the neighbour counts instead of building a list per point.
    neighbor_counts = tree.query_ball_point(points, radius, return_length=True)

    valid_mask = neighbor_counts >= min_neighbors
    return points[valid_mask], colors[valid_mask]
def save_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name):
    """Build a coloured point cloud from the left image and depth map and
    write it to ``<save_path>/<base_name>_pointcloud.ply`` (ASCII PLY).

    The cloud is voxel-downsampled (3 mm), filtered for sparse points
    (fewer than 10 neighbours within 5 mm) and downsampled again (5 mm)
    before being written.

    Args:
        left_np: Left image array (H, W, C) with channels in B, G, R
            order first; any further channels are ignored.
        depth_np: Depth map (H, W) in millimetres.
        calib_params: Calibration object; ``left_cam.fx/fy/cx/cy`` are used.
        save_path: Existing directory receiving the PLY file.
        base_name: File name prefix.
    """
    height, width = depth_np.shape
    fx = calib_params.left_cam.fx
    fy = calib_params.left_cam.fy
    cx = calib_params.left_cam.cx
    cy = calib_params.left_cam.cy

    # Pixel coordinate grids.
    x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))

    # Keep only usable depths. ``Z > 0`` alone would let +inf through,
    # which would put ``inf`` vertices in the file and overflow the voxel
    # indices during downsampling.
    valid_points = np.isfinite(depth_np) & (depth_np > 0)
    raw_point_count = int(np.count_nonzero(valid_points))

    # Pinhole back-projection, evaluated on the valid pixels only.
    z = depth_np[valid_points]
    x = (x_grid[valid_points] - cx) * z / fx
    y = (y_grid[valid_points] - cy) * z / fy
    points = np.stack([x, y, z], axis=1)

    # Matching colours, converted from BGR to RGB.
    colors = left_np[valid_points][:, [2, 1, 0]]

    # Downsample -> remove sparse points -> downsample again.
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
    points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
    print(f"下采样后点数: {len(points)} (原始点数: {raw_point_count})")

    # Write the ASCII PLY file: header first, then one vertex per line.
    ply_file = os.path.join(save_path, f"{base_name}_pointcloud.ply")
    header = [
        "ply\n",
        "format ascii 1.0\n",
        f"element vertex {len(points)}\n",
        "property float x\n",
        "property float y\n",
        "property float z\n",
        "property uchar red\n",
        "property uchar green\n",
        "property uchar blue\n",
        "end_header\n",
    ]
    with open(ply_file, 'w') as f:
        f.writelines(header)
        f.writelines(
            f"{p[0]:.3f} {p[1]:.3f} {p[2]:.3f} {c[0]} {c[1]} {c[2]}\n"
            for p, c in zip(points, colors)
        )
    print(f"保存点云文件: {ply_file}")
def save_images(self, save_path, left_np, right_np, depth_np, frame_count):
    """Save one frame: both images, the raw depth array and a point cloud.

    All files share the prefix ``<YYYYmmdd_HHMMSS>_<frame_count:06d>`` and
    are written directly into ``save_path``.

    Args:
        save_path: Existing output directory.
        left_np: Left image array.
        right_np: Right image array.
        depth_np: Depth map array, stored unchanged with ``np.save``.
        frame_count: Frame number used in the file names.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S")
    base_name = f"{stamp}_{frame_count:06d}"

    def target(suffix):
        # Full path of one output file for this frame.
        return os.path.join(save_path, f"{base_name}_{suffix}")

    # Colour images.
    cv2.imwrite(target("left.png"), left_np)
    cv2.imwrite(target("right.png"), right_np)

    # Raw depth data.
    np.save(target("depth.npy"), depth_np)

    # The point cloud needs the current calibration parameters.
    calib = self.zed.get_camera_information().camera_configuration.calibration_parameters
    self.save_point_cloud(left_np, depth_np, calib, save_path, base_name)
    print(f"保存图像、深度数据和点云: {base_name}")
def save_fish_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
    """Save the point cloud of a detected fish.

    Only pixels passing the same test as ``detect_fish`` are kept (valid
    depth within ``(0, max_depth]`` and Y below ``max_y``). The cloud is
    voxel-downsampled (3 mm), filtered for sparse points (fewer than 10
    neighbours within 5 mm) and downsampled again (5 mm). The result is
    written to ``<save_path>/pointcloud/<base_name>_pointcloud.ply``.

    Args:
        left_np: Left image array (H, W, C) with channels in B, G, R
            order first; any further channels are ignored.
        depth_np: Depth map (H, W) in millimetres.
        calib_params: Calibration object; ``left_cam.fx/fy/cx/cy`` are used.
        save_path: Output root directory.
        base_name: File name prefix.
        max_depth: Maximum depth in mm.
        max_y: Maximum Y coordinate in mm.
        min_points: Minimum number of points that must remain after
            filtering; with fewer, nothing is written.

    Returns:
        bool: ``True`` when a PLY file was written, ``False`` when the
        cloud was too small.
    """
    height, width = depth_np.shape
    fx = calib_params.left_cam.fx
    fy = calib_params.left_cam.fy
    cx = calib_params.left_cam.cx
    cy = calib_params.left_cam.cy

    # Pixel coordinate grids and pinhole back-projection.
    x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
    Z = depth_np  # depth in mm
    X = (x_grid - cx) * Z / fx
    Y = (y_grid - cy) * Z / fy

    # Same selection as detect_fish.
    fish_mask = (Z > 0) & (Z <= max_depth) & (Y < max_y)
    raw_point_count = np.sum(fish_mask)

    # Selected points and their colours (BGR -> RGB). No extra copies of
    # the raw cloud are kept: nothing downstream uses them.
    points = np.stack([X[fish_mask], Y[fish_mask], Z[fish_mask]], axis=1)
    colors = left_np[fish_mask][:, [2, 1, 0]]

    # Downsample -> remove sparse points -> downsample again.
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
    points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
    print(f"KDTree过滤后点数: {len(points)} (原始点数: {raw_point_count})")
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
    final_point_count = len(points)
    print(f"下采样后点数: {final_point_count} (原始点数: {raw_point_count})")

    # Refuse to save clouds that ended up too small.
    if final_point_count < min_points:
        print(f"跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
        return False

    # Write the ASCII PLY file into the pointcloud sub-directory.
    pointcloud_dir = os.path.join(save_path, "pointcloud")
    self.ensure_save_path(pointcloud_dir)
    ply_file = os.path.join(pointcloud_dir, f"{base_name}_pointcloud.ply")
    header = [
        "ply\n",
        "format ascii 1.0\n",
        f"element vertex {final_point_count}\n",
        "property float x\n",
        "property float y\n",
        "property float z\n",
        "property uchar red\n",
        "property uchar green\n",
        "property uchar blue\n",
        "end_header\n",
    ]
    with open(ply_file, 'w') as f:
        f.writelines(header)
        f.writelines(
            f"{p[0]:.3f} {p[1]:.3f} {p[2]:.3f} {c[0]} {c[1]} {c[2]}\n"
            for p, c in zip(points, colors)
        )
    print(f"保存点云文件: {ply_file} (点数: {final_point_count})")
    return True
def save_fish_point_cloud_debug(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
    """Debug-mode variant of the fish point-cloud export.

    Keeps the pixels with depth in ``(0, max_depth]`` and Y below
    ``max_y``, runs the downsample / sparse-point filter / downsample
    chain and writes ``<save_path>/pointcloud/<base_name>_fish_detected.ply``.
    The ``pointcloud`` directory is created before the size check, i.e.
    even when the cloud is then rejected for being too small.

    Args:
        left_np: Left image array (H, W, C) with channels in B, G, R
            order first; any further channels are ignored.
        depth_np: Depth map (H, W) in millimetres.
        calib_params: Calibration object; ``left_cam.fx/fy/cx/cy`` are used.
        save_path: Output root directory.
        base_name: File name prefix.
        max_depth: Maximum depth in mm.
        max_y: Maximum Y coordinate in mm.
        min_points: Minimum number of points that must remain after
            filtering; with fewer, nothing is written.

    Returns:
        bool: ``True`` when a PLY file was written, ``False`` otherwise.
    """
    rows, cols = depth_np.shape
    left_cam = calib_params.left_cam

    # Back-project every pixel with the pinhole model.
    u, v = np.meshgrid(np.arange(cols), np.arange(rows))
    depth = depth_np  # depth in mm
    x_mm = (u - left_cam.cx) * depth / left_cam.fx
    y_mm = (v - left_cam.cy) * depth / left_cam.fy

    # Same selection as detect_fish.
    in_range = (depth > 0) & (depth <= max_depth)
    selected = in_range & (y_mm < max_y)

    points = np.stack([x_mm[selected], y_mm[selected], depth[selected]], axis=1)
    colors = left_np[selected][:, [2, 1, 0]]  # BGR to RGB
    if len(points) == 0:
        print("[DEBUG] 跳过保存:满足过滤条件的原始点数为 0")
        return False

    # First voxel pass, then drop points with fewer than 10 neighbours
    # within 5 mm, then a coarser voxel pass.
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
    points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
    print(f"[DEBUG] KDTree过滤后点数: {len(points)} (原始点数: {np.sum(selected)})")
    points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
    final_point_count = len(points)
    print(f"[DEBUG] 下采样后点数: {final_point_count} (原始点数: {np.sum(selected)})")

    # The output directory is prepared before the size check.
    out_dir = os.path.join(save_path, "pointcloud")
    self.ensure_save_path(out_dir)

    if final_point_count < min_points:
        print(f"[DEBUG] 跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
        return False

    # ASCII PLY: header followed by one "x y z r g b" line per vertex.
    ply_file = os.path.join(out_dir, f"{base_name}_fish_detected.ply")
    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {len(points)}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    )
    with open(ply_file, 'w') as f:
        f.write(header)
        for xyz, rgb in zip(points, colors):
            f.write(f"{xyz[0]:.3f} {xyz[1]:.3f} {xyz[2]:.3f} "
                    f"{rgb[0]} {rgb[1]} {rgb[2]}\n")
    print(f"[DEBUG] 保存检测到的鱼类点云: {ply_file} (点数: {len(points)})")
    return True
def process_frames(self, save_path=None, record=False, save_all=False, debug=False, max_depth=1000, max_y=100, min_points=50, show_yolo_bbox=True):
    """Grab frames, run fish detection and save point clouds / annotated images.

    Only every 5th grabbed frame is processed. In SVO mode the loop ends
    at the first unsuccessful grab; in camera mode it runs until an
    exception (e.g. KeyboardInterrupt) interrupts it. The camera is always
    closed and a summary printed on exit.

    Non-debug mode (with ``save_path``): a fish that newly appears is
    tracked, and its point cloud is saved on the first processed frame at
    least 12 frames after the first detection. Because only every 5th
    frame is processed, an exact ``first + 12`` match can never occur, so
    "at least" is required for anything to be saved.

    Debug mode (with ``save_path``): a point cloud is saved for every
    processed frame on which the depth test detects a fish.

    When a point cloud was saved for a frame, the annotated left image is
    saved to ``<save_path>/images`` as well.

    Args:
        save_path: Output root directory; nothing is written when empty.
        record: Record an SVO file (camera mode only).
        save_all: Save images, depth and point cloud of every processed frame.
        debug: Enable debug mode (see above).
        max_depth: Maximum depth in mm for the depth-based detection.
        max_y: Maximum Y coordinate in mm for the depth-based detection.
        min_points: Minimum number of qualifying points, used both for the
            detection and as the minimum size of a saved cloud.
        show_yolo_bbox: Draw YOLO bounding boxes on the annotated image.
    """
    # Only every Nth grabbed frame is processed.
    frame_skip_interval = 5
    # A newly detected fish is saved once this many frames have passed
    # since it first appeared.
    save_delay_frames = 12

    self.save_path = save_path  # kept as an attribute for other methods

    # Make sure the output root exists.
    if save_path:
        self.ensure_save_path(save_path)

    # Start recording when requested (live camera only).
    recording_started = False
    if record and self.camera_mode:
        svo_file = f"recording_{time.strftime('%Y%m%d_%H%M%S')}.svo"
        if save_path:
            svo_file = os.path.join(save_path, svo_file)
        recording_started = self.start_recording(svo_file)

    # Calibration parameters are needed for the depth-based detection.
    calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters

    runtime_parameters = sl.RuntimeParameters()

    # Reusable image containers.
    left_image = sl.Mat()
    right_image = sl.Mat()
    depth_map = sl.Mat()

    # Output sub-directories.
    images_dir = None
    if save_path:
        images_dir = os.path.join(save_path, "images")
        pointcloud_dir = os.path.join(save_path, "pointcloud")
        self.ensure_save_path(images_dir)
        self.ensure_save_path(pointcloud_dir)
        print(f"图像将保存到: {images_dir}")
        print(f"点云将保存到: {pointcloud_dir}")

    frame_count = 0
    start_time = time.time()
    fish_detection_count = 0

    # Fish tracking state for non-debug mode.
    previous_fish_detected = False
    active_fish_tracking = []  # list of (first_detection_frame, fish_number)
    fish_counter = 0  # numbering for fish1, fish2, ...

    try:
        while True:
            # Grab a new frame.
            if self.zed.grab(runtime_parameters) != sl.ERROR_CODE.SUCCESS:
                if not self.camera_mode:
                    # SVO mode: treat a failed grab as end of file.
                    break
                continue

            frame_count += 1
            if frame_count % frame_skip_interval != 0:
                continue  # skipped frame, nothing to do

            # Retrieve the images and the depth map as numpy arrays.
            self.zed.retrieve_image(left_image, sl.VIEW.LEFT)
            self.zed.retrieve_image(right_image, sl.VIEW.RIGHT)
            self.zed.retrieve_measure(depth_map, sl.MEASURE.DEPTH)
            left_np = left_image.get_data()
            right_np = right_image.get_data()
            depth_np = depth_map.get_data()

            # YOLO fish detection (if enabled).
            yolo_detected = False
            yolo_bboxes = []
            if self.use_yolo_detector:
                try:
                    yolo_bboxes = self.detector.detect_fish(left_np)
                    yolo_detected = len(yolo_bboxes) > 0
                    if yolo_detected:
                        print(f"[帧 {frame_count}] YOLO检测到 {len(yolo_bboxes)} 条鱼")
                except Exception as e:
                    # Best effort: a detector failure counts as "no fish".
                    print(f"[帧 {frame_count}] YOLO检测错误: {e}")

            # Depth-based fish detection.
            fish_detected, fish_point_count = self.detect_fish(depth_np, calibration_params, min_points=min_points, max_depth=max_depth, max_y=max_y)

            # With YOLO enabled both detectors must agree; otherwise the
            # depth test alone decides.
            detection_triggered = fish_detected and (yolo_detected or not self.use_yolo_detector)

            # Whether a point cloud was saved for this frame (decides
            # whether the annotated image is saved). It must be reset
            # before either save branch runs, otherwise a successful
            # non-debug save would be forgotten.
            pointcloud_saved_this_frame = False

            # Non-debug mode: track fish and save their point clouds.
            if not debug and save_path:
                # A new fish appears on a False -> True transition.
                if detection_triggered and not previous_fish_detected:
                    fish_counter += 1
                    active_fish_tracking.append((frame_count, fish_counter))
                    detection_method = "YOLO+深度" if self.use_yolo_detector else "深度"
                    print(f"[帧 {frame_count}] 检测到新鱼类 (fish{fish_counter}){detection_method}检测,满足条件的点数: {fish_point_count}")

                # Save every tracked fish whose delay has elapsed; the
                # others stay in the list.
                still_tracking = []
                for first_frame, fish_num in active_fish_tracking:
                    if frame_count < first_frame + save_delay_frames:
                        still_tracking.append((first_frame, fish_num))
                        continue
                    # With YOLO enabled the fish must still be visible to
                    # YOLO on this frame. The result computed above is
                    # reused: it was obtained on this very image.
                    if self.use_yolo_detector and not yolo_detected:
                        print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (YOLO未检测到鱼)")
                        continue
                    timestamp = time.strftime("%Y%m%d_%H%M%S")
                    base_name = f"{timestamp}_{frame_count:06d}_fish{fish_num}"
                    saved = self.save_fish_point_cloud(left_np, depth_np, calibration_params, save_path, base_name, max_depth=max_depth, max_y=max_y, min_points=min_points)
                    if saved:
                        pointcloud_saved_this_frame = True
                        print(f"[帧 {frame_count}] 保存鱼类点云: fish{fish_num} (首次检测于帧 {first_frame})")
                    else:
                        print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (过滤后点数不足)")
                active_fish_tracking = still_tracking

                # A fish that disappears (True -> False) stays tracked
                # until its save point has been reached.
                if not detection_triggered and previous_fish_detected:
                    print(f"[帧 {frame_count}] 鱼类消失,但保留跟踪直到保存完成")

            if fish_detected:
                fish_detection_count += 1
                if debug:
                    print(f"[帧 {frame_count}] 检测到鱼类!满足条件的点数: {fish_point_count}")
                # Debug mode: save the detected fish point cloud.
                if debug and save_path:
                    if self.use_yolo_detector and not yolo_detected:
                        # With YOLO enabled, only save when YOLO agrees.
                        print(f"[帧 {frame_count}] [DEBUG] 跳过保存YOLO未检测到鱼")
                    else:
                        timestamp = time.strftime("%Y%m%d_%H%M%S")
                        base_name = f"{timestamp}_{frame_count:06d}_fish"
                        saved = self.save_fish_point_cloud_debug(left_np, depth_np, calibration_params, save_path, base_name, max_depth=max_depth, max_y=max_y, min_points=min_points)
                        if saved:
                            pointcloud_saved_this_frame = True
                        else:
                            print(f"[帧 {frame_count}] [DEBUG] 跳过保存:过滤后点数不足")

            # Remember the detection state for the next processed frame.
            previous_fish_detected = detection_triggered

            # Optionally save every processed frame.
            if save_all and save_path:
                self.save_images(save_path, left_np, right_np, depth_np, frame_count)

            # Draw on a copy so the original image stays untouched.
            display_left = left_np.copy()

            # Draw the YOLO bounding boxes (if enabled).
            if self.use_yolo_detector and show_yolo_bbox and yolo_bboxes:
                try:
                    display_left = self.detector.draw_bboxes(display_left, yolo_bboxes)
                except Exception as e:
                    print(f"[帧 {frame_count}] 绘制边界框错误: {e}")

            # Frame counter / frame rate overlay (live camera only).
            if self.camera_mode:
                elapsed_time = time.time() - start_time
                fps = frame_count / elapsed_time if elapsed_time > 0 else 0
                status = f"Frame: {frame_count} (processed), FPS: {fps:.1f}"
                cv2.putText(display_left, status, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

            # Detection status overlay: green = detected, orange = only
            # one of the two detectors fired, red = nothing detected.
            if self.use_yolo_detector:
                detection_status = f"YOLO: {'YES' if yolo_detected else 'NO'} | Depth: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
                if yolo_detected and fish_detected:
                    status_color = (0, 255, 0)
                elif yolo_detected or fish_detected:
                    status_color = (0, 165, 255)
                else:
                    status_color = (0, 0, 255)
            else:
                detection_status = f"Fish: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
                status_color = (0, 255, 0) if fish_detected else (0, 0, 255)
            cv2.putText(display_left, detection_status, (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)

            # Save the annotated image only when a point cloud was saved.
            if save_path and pointcloud_saved_this_frame:
                timestamp = time.strftime("%Y%m%d_%H%M%S")
                image_filename = f"{timestamp}_{frame_count:06d}_bbox.png"
                image_path = os.path.join(images_dir, image_filename)
                cv2.imwrite(image_path, display_left)
                print(f"[帧 {frame_count}] 保存带边界框的图像: {image_path}")
    finally:
        # Only stop a recording that was actually started.
        if recording_started:
            self.zed.disable_recording()
        self.close()
        print(f"总共处理了 {frame_count}")
        print(f"检测到鱼类的帧数: {fish_detection_count}")
        if self.camera_mode:
            elapsed_time = time.time() - start_time
            average_fps = frame_count / elapsed_time if elapsed_time > 0 else 0
            print(f"平均帧率: {average_fps:.1f} FPS")
def close(self):
    """Close the underlying ZED camera (or SVO file) handle."""
    camera = self.zed
    camera.close()
def main():
    """Command-line entry point: parse the options, open the source and process it.

    ``--camera`` and ``--svo`` are mutually exclusive; giving both prints
    an error and returns without opening anything. YOLO bounding boxes
    are drawn by default and can be switched off with
    ``--no_show_yolo_bbox``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--svo', type=str, help='SVO文件路径', default=None)
    parser.add_argument('--save_path', type=str, help='保存图像的路径', default='./zed_images')
    parser.add_argument('--record', action='store_true', help='是否录制新的SVO文件')
    parser.add_argument('--save_all', action='store_true', help='是否保存所有帧')
    parser.add_argument('--camera', action='store_true', help='使用实时相机模式')
    parser.add_argument('--debug', action='store_true', help='启用debug模式检测到鱼时自动保存点云')
    parser.add_argument('--max_depth', type=float, help='最大深度值mm默认1000', default=1000)
    parser.add_argument('--max_y', type=float, help='最大Y坐标值mm默认100', default=100)
    # The help text now states the real default (500); it used to say 50.
    parser.add_argument('--min_points', type=int, help='检测到鱼所需的最少点数默认500', default=500)
    parser.add_argument('--use_yolo', action='store_true', help='使用YOLO检测器检测鱼只有YOLO检测到边界框时才保存点云')
    parser.add_argument('--yolo_model', type=str, help='YOLO模型路径.pt文件', default=None)
    # store_true with default=True can never be turned off, so the flag is
    # paired with a negative form writing to the same destination.
    parser.add_argument('--show_yolo_bbox', dest='show_yolo_bbox', action='store_true', default=True, help='显示YOLO检测的边界框默认启用')
    parser.add_argument('--no_show_yolo_bbox', dest='show_yolo_bbox', action='store_false', help='不显示YOLO检测的边界框')
    args = parser.parse_args()

    if args.camera and args.svo:
        print("错误不能同时指定相机模式和SVO文件")
        return

    # Create the ZED reader.
    zed_reader = ZEDReader(args.svo, camera_mode=args.camera,
                           use_yolo_detector=args.use_yolo,
                           yolo_model_path=args.yolo_model)

    # Open the camera or file, then process its frames.
    if zed_reader.open():
        zed_reader.process_frames(
            args.save_path,
            record=args.record,
            save_all=args.save_all,
            debug=args.debug,
            max_depth=args.max_depth,
            max_y=args.max_y,
            min_points=args.min_points,
            show_yolo_bbox=args.show_yolo_bbox,
        )


if __name__ == "__main__":
    main()