Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
264
FishMeasure/dataset/dataset.py
Executable file
264
FishMeasure/dataset/dataset.py
Executable file
@@ -0,0 +1,264 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Dataset preparation script for YOLO training
|
||||
- Reads train/xxx/images/ and val/xxx/images/ from source directory
|
||||
- Converts labelme JSON format to YOLO format
|
||||
- Generates YOLO-compatible dataset structure
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare YOLO dataset from labelme JSON files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source_dir",
|
||||
type=str,
|
||||
default="/home/ubuntu/data/fish/detection/svo_batch/data",
|
||||
help="Source directory containing train/ and val/ subdirectories",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="./yolo_dataset",
|
||||
help="Output directory for YOLO dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_name",
|
||||
type=str,
|
||||
default="fish",
|
||||
help="Class name for YOLO labels (default: fish)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def find_matching_pairs(images_dir: Path) -> List[Tuple[Path, Path]]:
|
||||
"""
|
||||
Find all JSON files and their corresponding PNG files.
|
||||
Returns list of (json_path, png_path) tuples.
|
||||
"""
|
||||
pairs = []
|
||||
json_files = list(images_dir.glob("*.json"))
|
||||
|
||||
for json_file in json_files:
|
||||
# Find corresponding PNG file (same base name)
|
||||
png_file = json_file.with_suffix(".png")
|
||||
if png_file.exists():
|
||||
pairs.append((json_file, png_file))
|
||||
else:
|
||||
print(f"[WARNING] No matching PNG for {json_file.name}, skipping")
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def labelme_rectangle_to_yolo(points: List[List[float]], img_width: int, img_height: int) -> Optional[Tuple[float, float, float, float]]:
|
||||
"""
|
||||
Convert labelme rectangle format to YOLO format.
|
||||
Labelme rectangle: [[x1, y1], [x2, y2]] (top-left and bottom-right)
|
||||
YOLO format: (x_center, y_center, width, height) normalized [0, 1]
|
||||
"""
|
||||
if len(points) != 2:
|
||||
return None
|
||||
|
||||
x1, y1 = points[0]
|
||||
x2, y2 = points[1]
|
||||
|
||||
# Ensure x1 < x2 and y1 < y2
|
||||
x_min = min(x1, x2)
|
||||
x_max = max(x1, x2)
|
||||
y_min = min(y1, y2)
|
||||
y_max = max(y1, y2)
|
||||
|
||||
# Calculate center and dimensions
|
||||
width = x_max - x_min
|
||||
height = y_max - y_min
|
||||
x_center = x_min + width / 2.0
|
||||
y_center = y_min + height / 2.0
|
||||
|
||||
# Normalize to [0, 1]
|
||||
x_center_norm = x_center / img_width
|
||||
y_center_norm = y_center / img_height
|
||||
width_norm = width / img_width
|
||||
height_norm = height / img_height
|
||||
|
||||
# Validate bounds
|
||||
if not (0 <= x_center_norm <= 1 and 0 <= y_center_norm <= 1 and
|
||||
0 < width_norm <= 1 and 0 < height_norm <= 1):
|
||||
return None
|
||||
|
||||
return (x_center_norm, y_center_norm, width_norm, height_norm)
|
||||
|
||||
|
||||
def convert_labelme_to_yolo(json_path: Path, img_path: Path, class_id: int) -> List[str]:
|
||||
"""
|
||||
Convert labelme JSON to YOLO format label lines.
|
||||
Returns list of YOLO label strings: "class_id x_center y_center width height"
|
||||
"""
|
||||
try:
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
labelme_data = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to read {json_path}: {e}")
|
||||
return []
|
||||
|
||||
# Get image dimensions
|
||||
try:
|
||||
with Image.open(img_path) as img:
|
||||
img_width, img_height = img.size
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to read image {img_path}: {e}")
|
||||
return []
|
||||
|
||||
yolo_lines = []
|
||||
shapes = labelme_data.get('shapes', [])
|
||||
|
||||
for shape in shapes:
|
||||
shape_type = shape.get('shape_type', '')
|
||||
label = shape.get('label', '')
|
||||
points = shape.get('points', [])
|
||||
|
||||
# Only process rectangles for now
|
||||
if shape_type == 'rectangle' and len(points) == 2:
|
||||
yolo_bbox = labelme_rectangle_to_yolo(points, img_width, img_height)
|
||||
if yolo_bbox is not None:
|
||||
x_center, y_center, width, height = yolo_bbox
|
||||
yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
||||
elif shape_type == 'polygon':
|
||||
# For polygons, convert to bounding box
|
||||
if len(points) >= 3:
|
||||
xs = [p[0] for p in points]
|
||||
ys = [p[1] for p in points]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
# Convert to rectangle format
|
||||
rect_points = [[x_min, y_min], [x_max, y_max]]
|
||||
yolo_bbox = labelme_rectangle_to_yolo(rect_points, img_width, img_height)
|
||||
if yolo_bbox is not None:
|
||||
x_center, y_center, width, height = yolo_bbox
|
||||
yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
||||
|
||||
return yolo_lines
|
||||
|
||||
|
||||
def process_split(source_dir: Path, split: str, output_dir: Path, class_id: int):
|
||||
"""
|
||||
Process train or val split.
|
||||
"""
|
||||
split_dir = source_dir / split
|
||||
if not split_dir.exists():
|
||||
print(f"[WARNING] {split_dir} does not exist, skipping")
|
||||
return 0
|
||||
|
||||
# Create output directories
|
||||
output_images_dir = output_dir / "images" / split
|
||||
output_labels_dir = output_dir / "labels" / split
|
||||
output_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_labels_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
processed_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
# Iterate through all subdirectories
|
||||
for subfolder in sorted(split_dir.iterdir()):
|
||||
if not subfolder.is_dir():
|
||||
continue
|
||||
|
||||
images_dir = subfolder / "images"
|
||||
if not images_dir.exists():
|
||||
continue
|
||||
|
||||
# Find all JSON-PNG pairs
|
||||
pairs = find_matching_pairs(images_dir)
|
||||
|
||||
if not pairs:
|
||||
print(f"[SKIP] No JSON-PNG pairs found in {subfolder.name}")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Process each pair
|
||||
for json_path, img_path in pairs:
|
||||
# Convert labelme to YOLO
|
||||
yolo_lines = convert_labelme_to_yolo(json_path, img_path, class_id)
|
||||
|
||||
if not yolo_lines:
|
||||
print(f"[SKIP] No valid labels in {json_path.name}")
|
||||
continue
|
||||
|
||||
# Copy image to output directory
|
||||
dst_img = output_images_dir / img_path.name
|
||||
shutil.copy2(img_path, dst_img)
|
||||
|
||||
# Write YOLO label file
|
||||
label_name = img_path.stem + ".txt"
|
||||
dst_label = output_labels_dir / label_name
|
||||
with open(dst_label, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(yolo_lines) + '\n')
|
||||
|
||||
processed_count += 1
|
||||
|
||||
print(f"[{split.upper()}] Processed {processed_count} images, skipped {skipped_count} subfolders")
|
||||
return processed_count
|
||||
|
||||
|
||||
def generate_dataset_yaml(output_dir: Path, class_name: str):
|
||||
"""
|
||||
Generate dataset.yaml file for YOLO training.
|
||||
"""
|
||||
yaml_path = output_dir / "dataset.yaml"
|
||||
content = f"""path: {output_dir.resolve()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
names: [{class_name}]
|
||||
"""
|
||||
with open(yaml_path, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
print(f"[OK] Generated dataset.yaml: {yaml_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
source_dir = Path(args.source_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
|
||||
if not source_dir.exists():
|
||||
print(f"[ERROR] Source directory does not exist: {source_dir}")
|
||||
return
|
||||
|
||||
# Create output directory
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Source directory: {source_dir}")
|
||||
print(f"Output directory: {output_dir}")
|
||||
print(f"Class name: {args.class_name}")
|
||||
print("-" * 60)
|
||||
|
||||
# Process train and val splits
|
||||
class_id = 0 # YOLO uses 0-indexed class IDs
|
||||
train_count = process_split(source_dir, "train", output_dir, class_id)
|
||||
val_count = process_split(source_dir, "val", output_dir, class_id)
|
||||
|
||||
# Generate dataset.yaml
|
||||
generate_dataset_yaml(output_dir, args.class_name)
|
||||
|
||||
print("-" * 60)
|
||||
print(f"[SUMMARY]")
|
||||
print(f" Train images: {train_count}")
|
||||
print(f" Val images: {val_count}")
|
||||
print(f" Total images: {train_count + val_count}")
|
||||
print(f" Dataset ready at: {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
802
FishMeasure/dataset/zed_reader.py
Executable file
802
FishMeasure/dataset/zed_reader.py
Executable file
@@ -0,0 +1,802 @@
|
||||
import pyzed.sl as sl
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
from scipy.spatial import cKDTree
|
||||
|
||||
# Try to import fish detector (optional)
|
||||
try:
|
||||
from detector import FishDetector
|
||||
DETECTOR_AVAILABLE = True
|
||||
except ImportError:
|
||||
DETECTOR_AVAILABLE = False
|
||||
FishDetector = None
|
||||
|
||||
class ZEDReader:
|
||||
def __init__(self, svo_path=None, camera_mode=False, use_yolo_detector=False, yolo_model_path=None):
|
||||
# 创建ZED相机对象
|
||||
self.zed = sl.Camera()
|
||||
self.camera_mode = camera_mode
|
||||
self.use_yolo_detector = use_yolo_detector
|
||||
self.detector = None
|
||||
|
||||
# Initialize YOLO detector if requested
|
||||
if use_yolo_detector:
|
||||
if not DETECTOR_AVAILABLE:
|
||||
print("警告: YOLO检测器不可用。请安装detector.py和相关依赖")
|
||||
self.use_yolo_detector = False
|
||||
else:
|
||||
try:
|
||||
self.detector = FishDetector(model_path=yolo_model_path)
|
||||
print("✓ YOLO鱼类检测器已初始化")
|
||||
except Exception as e:
|
||||
print(f"警告: 初始化YOLO检测器失败: {e}")
|
||||
self.use_yolo_detector = False
|
||||
|
||||
# 设置初始化参数
|
||||
self.init_params = sl.InitParameters()
|
||||
if svo_path and not camera_mode:
|
||||
# 如果提供SVO文件路径,则从文件读取
|
||||
print(f"从SVO文件读取: {svo_path}")
|
||||
self.init_params.set_from_svo_file(svo_path)
|
||||
self.init_params.svo_real_time_mode = False
|
||||
else:
|
||||
# 实时相机模式设置
|
||||
print("使用实时相机模式")
|
||||
self.init_params.camera_resolution = sl.RESOLUTION.HD720
|
||||
self.init_params.depth_mode = sl.DEPTH_MODE.ULTRA
|
||||
self.init_params.coordinate_units = sl.UNIT.MILLIMETER
|
||||
self.init_params.camera_fps = 30 # 设置帧率
|
||||
|
||||
def open(self):
|
||||
# 打开相机或SVO文件
|
||||
err = self.zed.open(self.init_params)
|
||||
if err != sl.ERROR_CODE.SUCCESS:
|
||||
print(f"错误: {err}")
|
||||
return False
|
||||
|
||||
# 获取并打印标定参数
|
||||
self.print_calibration_params()
|
||||
return True
|
||||
|
||||
def print_calibration_params(self):
|
||||
"""打印相机标定参数"""
|
||||
calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
|
||||
|
||||
print("\n=== 相机标定参数 ===")
|
||||
|
||||
# 左眼相机参数
|
||||
print("\n左眼相机参数:")
|
||||
print(f"焦距: fx={calibration_params.left_cam.fx}, fy={calibration_params.left_cam.fy}")
|
||||
print(f"主点: cx={calibration_params.left_cam.cx}, cy={calibration_params.left_cam.cy}")
|
||||
print("畸变参数:")
|
||||
print(f"k1={calibration_params.left_cam.disto[0]}")
|
||||
print(f"k2={calibration_params.left_cam.disto[1]}")
|
||||
print(f"k3={calibration_params.left_cam.disto[2]}")
|
||||
print(f"k4={calibration_params.left_cam.disto[3]}")
|
||||
print(f"k5={calibration_params.left_cam.disto[4]}")
|
||||
|
||||
# 右眼相机参数
|
||||
print("\n右眼相机参数:")
|
||||
print(f"焦距: fx={calibration_params.right_cam.fx}, fy={calibration_params.right_cam.fy}")
|
||||
print(f"主点: cx={calibration_params.right_cam.cx}, cy={calibration_params.right_cam.cy}")
|
||||
print("畸变参数:")
|
||||
print(f"k1={calibration_params.right_cam.disto[0]}")
|
||||
print(f"k2={calibration_params.right_cam.disto[1]}")
|
||||
print(f"k3={calibration_params.right_cam.disto[2]}")
|
||||
print(f"k4={calibration_params.right_cam.disto[3]}")
|
||||
print(f"k5={calibration_params.right_cam.disto[4]}")
|
||||
|
||||
#import pdb; pdb.set_trace()
|
||||
print("\n立体参数:")
|
||||
# 获取旋转矩阵
|
||||
rotation = calibration_params.stereo_transform.m
|
||||
|
||||
print("旋转矩阵 R:")
|
||||
print(f"[{rotation[0][0]}, {rotation[0][1]}, {rotation[0][2]}]")
|
||||
print(f"[{rotation[1][0]}, {rotation[1][1]}, {rotation[1][2]}]")
|
||||
print(f"[{rotation[2][0]}, {rotation[2][1]}, {rotation[2][2]}]")
|
||||
|
||||
print("\n平移向量 T:")
|
||||
print(f"[{rotation[0][3]}, {rotation[1][3]}, {rotation[2][3]}]")
|
||||
|
||||
# 保存标定参数到文件
|
||||
if hasattr(self, 'save_path') and self.save_path:
|
||||
self.save_calibration_params(calibration_params)
|
||||
|
||||
def save_calibration_params(self, calibration_params):
|
||||
"""将标定参数保存到文件"""
|
||||
if not hasattr(self, 'save_path') or not self.save_path:
|
||||
return
|
||||
|
||||
calib_file = os.path.join(self.save_path, 'calibration_params.txt')
|
||||
with open(calib_file, 'w') as f:
|
||||
f.write("=== 相机标定参数 ===\n")
|
||||
|
||||
# 左眼相机
|
||||
f.write("\n左眼相机参数:\n")
|
||||
f.write(f"焦距: fx={calibration_params.left_cam.fx}, fy={calibration_params.left_cam.fy}\n")
|
||||
f.write(f"主点: cx={calibration_params.left_cam.cx}, cy={calibration_params.left_cam.cy}\n")
|
||||
f.write("畸变参数:\n")
|
||||
for i, d in enumerate(calibration_params.left_cam.disto):
|
||||
f.write(f"k{i+1}={d}\n")
|
||||
|
||||
# 右眼相机
|
||||
f.write("\n右眼相机参数:\n")
|
||||
f.write(f"焦距: fx={calibration_params.right_cam.fx}, fy={calibration_params.right_cam.fy}\n")
|
||||
f.write(f"主点: cx={calibration_params.right_cam.cx}, cy={calibration_params.right_cam.cy}\n")
|
||||
f.write("畸变参数:\n")
|
||||
for i, d in enumerate(calibration_params.right_cam.disto):
|
||||
f.write(f"k{i+1}={d}\n")
|
||||
|
||||
# 立体参数
|
||||
f.write("\n立体参数:\n")
|
||||
rotation = calibration_params.stereo_transform.m
|
||||
|
||||
f.write("旋转矩阵 R:\n")
|
||||
f.write(f"[{rotation[0][0]}, {rotation[0][1]}, {rotation[0][2]}]\n")
|
||||
f.write(f"[{rotation[1][0]}, {rotation[1][1]}, {rotation[1][2]}]\n")
|
||||
f.write(f"[{rotation[2][0]}, {rotation[2][1]}, {rotation[2][2]}]\n")
|
||||
|
||||
f.write("\n平移向量 T:\n")
|
||||
f.write(f"[{rotation[0][3]}, {rotation[1][3]}, {rotation[2][3]}]\n")
|
||||
|
||||
def start_recording(self, filename=None):
|
||||
"""开始录制"""
|
||||
if filename is None:
|
||||
# 使用时间戳作为文件名
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"recording_{timestamp}.svo"
|
||||
|
||||
recording_param = sl.RecordingParameters(filename, sl.SVO_COMPRESSION_MODE.H264)
|
||||
err = self.zed.enable_recording(recording_param)
|
||||
if err != sl.ERROR_CODE.SUCCESS:
|
||||
print(f"无法开始录制: {err}")
|
||||
return False
|
||||
print(f"开始录制到文件: {filename}")
|
||||
return True
|
||||
|
||||
def ensure_save_path(self, save_path):
|
||||
"""确保保存路径存在,如果不存在则创建"""
|
||||
if save_path:
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
print(f"创建保存目录: {save_path}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def detect_fish(self, depth_np, calib_params, min_points=100, max_depth=1200, max_y=150):
|
||||
"""
|
||||
检测深度图中是否有鱼类
|
||||
检测条件:
|
||||
- z+ 方向(深度):0-max_depth mm
|
||||
- y- 方向(相机坐标系y轴负方向,y < max_y)
|
||||
- 满足条件的点数 >= min_points
|
||||
|
||||
Args:
|
||||
depth_np: 深度图的numpy数组 (H, W) 单位为mm
|
||||
calib_params: 相机标定参数
|
||||
min_points: 检测到鱼所需的最少点数,默认100
|
||||
max_depth: 最大深度值(mm),默认1200
|
||||
max_y: 最大Y坐标值(mm),默认150
|
||||
|
||||
Returns:
|
||||
tuple: (是否检测到鱼(bool), 满足条件的点数(int))
|
||||
"""
|
||||
height, width = depth_np.shape
|
||||
fx = calib_params.left_cam.fx
|
||||
fy = calib_params.left_cam.fy
|
||||
cx = calib_params.left_cam.cx
|
||||
cy = calib_params.left_cam.cy
|
||||
|
||||
# 创建网格坐标
|
||||
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
|
||||
|
||||
# 计算3D点
|
||||
Z = depth_np # 深度值(mm)
|
||||
X = (x_grid - cx) * Z / fx
|
||||
Y = (y_grid - cy) * Z / fy
|
||||
|
||||
# 检测条件:
|
||||
# 1. z+ 方向:深度值在 0-max_depth mm 范围内
|
||||
# 2. y- 方向:Y坐标小于max_y(相机坐标系中y轴向上为正)
|
||||
# 3. 深度值有效(> 0)
|
||||
valid_depth = (Z > 0) & (Z <= max_depth)
|
||||
valid_y = Y < max_y
|
||||
|
||||
# 同时满足所有条件的点
|
||||
fish_mask = valid_depth & valid_y
|
||||
|
||||
# 统计满足条件的点数
|
||||
fish_point_count = np.sum(fish_mask)
|
||||
|
||||
# 如果点数 >= min_points,认为检测到鱼
|
||||
fish_detected = fish_point_count >= min_points
|
||||
|
||||
return fish_detected, fish_point_count
|
||||
|
||||
def downsample_point_cloud(self, points, colors, voxel_size=5.0):
|
||||
"""
|
||||
使用体素网格对点云进行均匀下采样,去除稀疏点
|
||||
|
||||
Args:
|
||||
points: 点云坐标数组 (N, 3)
|
||||
colors: 颜色数组 (N, 3)
|
||||
voxel_size: 体素大小(mm),默认5.0mm
|
||||
|
||||
Returns:
|
||||
tuple: (下采样后的点云, 下采样后的颜色)
|
||||
"""
|
||||
if len(points) == 0:
|
||||
return points, colors
|
||||
|
||||
# 计算每个点所属的体素索引
|
||||
voxel_indices = np.floor(points / voxel_size).astype(np.int32)
|
||||
|
||||
# 使用字典来存储每个体素中的点(保留第一个点)
|
||||
voxel_dict = {}
|
||||
for i in range(len(points)):
|
||||
voxel_key = tuple(voxel_indices[i])
|
||||
if voxel_key not in voxel_dict:
|
||||
voxel_dict[voxel_key] = i
|
||||
|
||||
# 获取保留的点的索引
|
||||
kept_indices = list(voxel_dict.values())
|
||||
|
||||
return points[kept_indices], colors[kept_indices]
|
||||
|
||||
def filter_point_cloud_kdtree(self, points, colors, radius=5.0, min_neighbors=50):
|
||||
"""
|
||||
使用KDTree进行半径搜索,过滤掉稀疏点
|
||||
如果点在指定半径内的邻居数量少于min_neighbors,则移除该点
|
||||
|
||||
Args:
|
||||
points: 点云坐标数组 (N, 3)
|
||||
colors: 颜色数组 (N, 3)
|
||||
radius: 搜索半径(mm),默认5.0mm
|
||||
min_neighbors: 最小邻居数量,默认50
|
||||
|
||||
Returns:
|
||||
tuple: (过滤后的点云, 过滤后的颜色)
|
||||
"""
|
||||
if len(points) == 0:
|
||||
return points, colors
|
||||
|
||||
# 构建KDTree
|
||||
tree = cKDTree(points)
|
||||
|
||||
# 对每个点进行半径搜索(包括点本身)
|
||||
# 使用query_ball_point找到半径内的所有点
|
||||
neighbor_counts = np.zeros(len(points), dtype=np.int32)
|
||||
|
||||
# 批量查询以提高效率
|
||||
for i in range(len(points)):
|
||||
neighbors = tree.query_ball_point(points[i], radius)
|
||||
neighbor_counts[i] = len(neighbors)
|
||||
|
||||
# 保留邻居数量 >= min_neighbors 的点
|
||||
valid_mask = neighbor_counts >= min_neighbors
|
||||
filtered_points = points[valid_mask]
|
||||
filtered_colors = colors[valid_mask]
|
||||
|
||||
return filtered_points, filtered_colors
|
||||
|
||||
def save_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name):
|
||||
"""
|
||||
从左图像和深度图生成点云
|
||||
left_np: 左图像的numpy数组 (H, W, 3) BGR格式
|
||||
depth_np: 深度图的numpy数组 (H, W) 单位为mm
|
||||
calib_params: 相机标定参数
|
||||
"""
|
||||
height, width = depth_np.shape
|
||||
fx = calib_params.left_cam.fx
|
||||
fy = calib_params.left_cam.fy
|
||||
cx = calib_params.left_cam.cx
|
||||
cy = calib_params.left_cam.cy
|
||||
|
||||
# 创建网格坐标
|
||||
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
|
||||
|
||||
# 计算3D点
|
||||
Z = depth_np # 深度值(mm)
|
||||
X = (x_grid - cx) * Z / fx
|
||||
Y = (y_grid - cy) * Z / fy
|
||||
|
||||
# 创建点云
|
||||
valid_points = Z > 0 # 只使用有效的深度值
|
||||
points = np.stack([X[valid_points], Y[valid_points], Z[valid_points]], axis=1)
|
||||
|
||||
# 获取对应的颜色(转换为RGB)
|
||||
colors = left_np[valid_points]
|
||||
colors = colors[:, [2,1,0]] # BGR to RGB
|
||||
|
||||
# 对点云进行均匀下采样以去除稀疏点
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
|
||||
#print(f"下采样后点数: {len(points)} (原始点数: {np.sum(valid_points)})")
|
||||
|
||||
# 使用KDTree过滤稀疏点(半径5mm内少于10个邻居的点将被移除)
|
||||
points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
|
||||
#print(f"KDTree过滤后点数: {len(points)} (原始点数: {np.sum(valid_points)})")
|
||||
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
|
||||
print(f"下采样后点数: {len(points)} (原始点数: {np.sum(valid_points)})")
|
||||
|
||||
# 保存为PLY文件
|
||||
ply_file = os.path.join(save_path, f"{base_name}_pointcloud.ply")
|
||||
|
||||
with open(ply_file, 'w') as f:
|
||||
# 写入PLY头部
|
||||
f.write("ply\n")
|
||||
f.write("format ascii 1.0\n")
|
||||
f.write(f"element vertex {len(points)}\n")
|
||||
f.write("property float x\n")
|
||||
f.write("property float y\n")
|
||||
f.write("property float z\n")
|
||||
f.write("property uchar red\n")
|
||||
f.write("property uchar green\n")
|
||||
f.write("property uchar blue\n")
|
||||
f.write("end_header\n")
|
||||
|
||||
# 写入点云数据
|
||||
for i in range(len(points)):
|
||||
f.write(f"{points[i,0]:.3f} {points[i,1]:.3f} {points[i,2]:.3f} "
|
||||
f"{colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
||||
|
||||
print(f"保存点云文件: {ply_file}")
|
||||
|
||||
def save_images(self, save_path, left_np, right_np, depth_np, frame_count):
|
||||
"""保存图像,使用帧计数作为文件名"""
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
base_name = f"{timestamp}_{frame_count:06d}"
|
||||
|
||||
# 保存RGB图像
|
||||
cv2.imwrite(os.path.join(save_path, f"{base_name}_left.png"), left_np)
|
||||
cv2.imwrite(os.path.join(save_path, f"{base_name}_right.png"), right_np)
|
||||
|
||||
# 保存原始深度数据(float32格式)
|
||||
depth_file = os.path.join(save_path, f"{base_name}_depth.npy")
|
||||
np.save(depth_file, depth_np)
|
||||
|
||||
# 获取相机标定参数并生成点云
|
||||
calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
|
||||
self.save_point_cloud(left_np, depth_np, calibration_params, save_path, base_name)
|
||||
|
||||
print(f"保存图像、深度数据和点云: {base_name}")
|
||||
|
||||
def save_fish_point_cloud(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
|
||||
"""
|
||||
保存检测到的鱼类点云(只保存满足检测条件的点)
|
||||
left_np: 左图像的numpy数组 (H, W, 3) BGR格式
|
||||
depth_np: 深度图的numpy数组 (H, W) 单位为mm
|
||||
calib_params: 相机标定参数
|
||||
save_path: 保存路径
|
||||
base_name: 文件名基础
|
||||
max_depth: 最大深度值(mm),默认1000
|
||||
max_y: 最大Y坐标值(mm),默认100
|
||||
min_points: 过滤后所需的最少点数,如果少于这个数则不保存,默认500
|
||||
"""
|
||||
height, width = depth_np.shape
|
||||
fx = calib_params.left_cam.fx
|
||||
fy = calib_params.left_cam.fy
|
||||
cx = calib_params.left_cam.cx
|
||||
cy = calib_params.left_cam.cy
|
||||
|
||||
# 创建网格坐标
|
||||
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
|
||||
|
||||
# 计算3D点
|
||||
Z = depth_np # 深度值(mm)
|
||||
X = (x_grid - cx) * Z / fx
|
||||
Y = (y_grid - cy) * Z / fy
|
||||
|
||||
# 应用与detect_fish相同的过滤条件
|
||||
valid_depth = (Z > 0) & (Z <= max_depth)
|
||||
valid_y = Y < max_y
|
||||
fish_mask = valid_depth & valid_y
|
||||
|
||||
# 只保存满足条件的点
|
||||
points = np.stack([X[fish_mask], Y[fish_mask], Z[fish_mask]], axis=1)
|
||||
raw_points = points.copy()
|
||||
|
||||
# 获取对应的颜色(转换为RGB)
|
||||
colors = left_np[fish_mask]
|
||||
colors = colors[:, [2,1,0]] # BGR to RGB
|
||||
raw_colors = colors.copy()
|
||||
|
||||
# 对点云进行均匀下采样以去除稀疏点
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
|
||||
|
||||
# 使用KDTree过滤稀疏点(半径5mm内少于10个邻居的点将被移除)
|
||||
points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
|
||||
print(f"KDTree过滤后点数: {len(points)} (原始点数: {np.sum(fish_mask)})")
|
||||
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
|
||||
final_point_count = len(points)
|
||||
print(f"下采样后点数: {final_point_count} (原始点数: {np.sum(fish_mask)})")
|
||||
|
||||
# 检查过滤后的点数是否满足最小点数要求
|
||||
if final_point_count < min_points:
|
||||
print(f"跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
|
||||
return False
|
||||
|
||||
# 保存为PLY文件到pointcloud目录
|
||||
pointcloud_dir = os.path.join(save_path, "pointcloud")
|
||||
self.ensure_save_path(pointcloud_dir)
|
||||
ply_file = os.path.join(pointcloud_dir, f"{base_name}_pointcloud.ply")
|
||||
|
||||
with open(ply_file, 'w') as f:
|
||||
# 写入PLY头部
|
||||
f.write("ply\n")
|
||||
f.write("format ascii 1.0\n")
|
||||
f.write(f"element vertex {len(points)}\n")
|
||||
f.write("property float x\n")
|
||||
f.write("property float y\n")
|
||||
f.write("property float z\n")
|
||||
f.write("property uchar red\n")
|
||||
f.write("property uchar green\n")
|
||||
f.write("property uchar blue\n")
|
||||
f.write("end_header\n")
|
||||
|
||||
# 写入点云数据
|
||||
for i in range(len(points)):
|
||||
f.write(f"{points[i,0]:.3f} {points[i,1]:.3f} {points[i,2]:.3f} "
|
||||
f"{colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
||||
|
||||
print(f"保存点云文件: {ply_file} (点数: {len(points)})")
|
||||
return True
|
||||
|
||||
def save_fish_point_cloud_debug(self, left_np, depth_np, calib_params, save_path, base_name, max_depth=1000, max_y=100, min_points=500):
|
||||
"""
|
||||
在debug模式下保存检测到的鱼类点云(只保存满足检测条件的点)
|
||||
left_np: 左图像的numpy数组 (H, W, 3) BGR格式
|
||||
depth_np: 深度图的numpy数组 (H, W) 单位为mm
|
||||
calib_params: 相机标定参数
|
||||
save_path: 保存路径
|
||||
base_name: 文件名基础
|
||||
max_depth: 最大深度值(mm),默认1000
|
||||
max_y: 最大Y坐标值(mm),默认100
|
||||
min_points: 过滤后所需的最少点数,如果少于这个数则不保存,默认500
|
||||
"""
|
||||
height, width = depth_np.shape
|
||||
fx = calib_params.left_cam.fx
|
||||
fy = calib_params.left_cam.fy
|
||||
cx = calib_params.left_cam.cx
|
||||
cy = calib_params.left_cam.cy
|
||||
|
||||
# 创建网格坐标
|
||||
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
|
||||
|
||||
# 计算3D点
|
||||
Z = depth_np # 深度值(mm)
|
||||
X = (x_grid - cx) * Z / fx
|
||||
Y = (y_grid - cy) * Z / fy
|
||||
|
||||
# 应用与detect_fish相同的过滤条件
|
||||
valid_depth = (Z > 0) & (Z <= max_depth)
|
||||
valid_y = Y < max_y
|
||||
fish_mask = valid_depth & valid_y
|
||||
|
||||
# 只保存满足条件的点
|
||||
points = np.stack([X[fish_mask], Y[fish_mask], Z[fish_mask]], axis=1)
|
||||
colors = left_np[fish_mask][:, [2,1,0]] # BGR to RGB
|
||||
|
||||
if len(points) == 0:
|
||||
print("[DEBUG] 跳过保存:满足过滤条件的原始点数为 0")
|
||||
return False
|
||||
|
||||
# 对点云进行均匀下采样以去除稀疏点
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=3.0)
|
||||
#print(f"[DEBUG] 下采样后点数: {len(points)} (原始点数: {np.sum(fish_mask)})")
|
||||
|
||||
# 使用KDTree过滤稀疏点(半径5mm内少于50个邻居的点将被移除)
|
||||
points, colors = self.filter_point_cloud_kdtree(points, colors, radius=5.0, min_neighbors=10)
|
||||
print(f"[DEBUG] KDTree过滤后点数: {len(points)} (原始点数: {np.sum(fish_mask)})")
|
||||
|
||||
points, colors = self.downsample_point_cloud(points, colors, voxel_size=5.0)
|
||||
final_point_count = len(points)
|
||||
print(f"[DEBUG] 下采样后点数: {final_point_count} (原始点数: {np.sum(fish_mask)})")
|
||||
|
||||
# 保存为PLY文件到pointcloud目录
|
||||
pointcloud_dir = os.path.join(save_path, "pointcloud")
|
||||
self.ensure_save_path(pointcloud_dir)
|
||||
|
||||
# 检查过滤后的点数是否满足最小点数要求
|
||||
if final_point_count < min_points:
|
||||
print(f"[DEBUG] 跳过保存:过滤后点数 ({final_point_count}) 少于最小要求 ({min_points})")
|
||||
return False
|
||||
|
||||
# 保存为PLY文件
|
||||
ply_file = os.path.join(pointcloud_dir, f"{base_name}_fish_detected.ply")
|
||||
|
||||
with open(ply_file, 'w') as f:
|
||||
# 写入PLY头部
|
||||
f.write("ply\n")
|
||||
f.write("format ascii 1.0\n")
|
||||
f.write(f"element vertex {len(points)}\n")
|
||||
f.write("property float x\n")
|
||||
f.write("property float y\n")
|
||||
f.write("property float z\n")
|
||||
f.write("property uchar red\n")
|
||||
f.write("property uchar green\n")
|
||||
f.write("property uchar blue\n")
|
||||
f.write("end_header\n")
|
||||
|
||||
# 写入点云数据
|
||||
for i in range(len(points)):
|
||||
f.write(f"{points[i,0]:.3f} {points[i,1]:.3f} {points[i,2]:.3f} "
|
||||
f"{colors[i,0]} {colors[i,1]} {colors[i,2]}\n")
|
||||
|
||||
print(f"[DEBUG] 保存检测到的鱼类点云: {ply_file} (点数: {len(points)})")
|
||||
return True
|
||||
|
||||
def process_frames(self, save_path=None, record=False, save_all=False, debug=False, max_depth=1000, max_y=100, min_points=50, show_yolo_bbox=True):
|
||||
"""
|
||||
处理帧
|
||||
save_path: 保存路径
|
||||
record: 是否录制SVO文件
|
||||
save_all: 是否保存所有帧
|
||||
debug: 是否启用debug模式(检测到鱼时保存点云)
|
||||
max_depth: 最大深度值(mm),默认1000
|
||||
max_y: 最大Y坐标值(mm),默认100
|
||||
min_points: 检测到鱼所需的最少点数,默认50
|
||||
show_yolo_bbox: 是否显示YOLO检测的边界框,默认True
|
||||
"""
|
||||
self.save_path = save_path # 保存路径作为类属性
|
||||
|
||||
# 确保保存路径存在
|
||||
if save_path:
|
||||
self.ensure_save_path(save_path)
|
||||
|
||||
# 如果是录制模式,开始录制
|
||||
if record and self.camera_mode:
|
||||
if save_path:
|
||||
svo_file = os.path.join(save_path, f"recording_{time.strftime('%Y%m%d_%H%M%S')}.svo")
|
||||
else:
|
||||
svo_file = f"recording_{time.strftime('%Y%m%d_%H%M%S')}.svo"
|
||||
self.start_recording(svo_file)
|
||||
|
||||
# 获取相机标定参数(用于鱼类检测)
|
||||
calibration_params = self.zed.get_camera_information().camera_configuration.calibration_parameters
|
||||
|
||||
# 创建运行时参数
|
||||
runtime_parameters = sl.RuntimeParameters()
|
||||
|
||||
# 创建图像容器
|
||||
left_image = sl.Mat()
|
||||
right_image = sl.Mat()
|
||||
depth_map = sl.Mat()
|
||||
|
||||
# 创建输出目录
|
||||
if save_path:
|
||||
images_dir = os.path.join(save_path, "images")
|
||||
pointcloud_dir = os.path.join(save_path, "pointcloud")
|
||||
self.ensure_save_path(images_dir)
|
||||
self.ensure_save_path(pointcloud_dir)
|
||||
print(f"图像将保存到: {images_dir}")
|
||||
print(f"点云将保存到: {pointcloud_dir}")
|
||||
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
fish_detection_count = 0
|
||||
frame_skip_interval = 5 # Process every 5th frame
|
||||
|
||||
# 用于非debug模式的鱼类跟踪
|
||||
previous_fish_detected = False
|
||||
active_fish_tracking = [] # 列表,每个元素是 (first_detection_frame, fish_number)
|
||||
fish_counter = 0 # 用于生成fish1, fish2等编号
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 抓取新的帧
|
||||
if self.zed.grab(runtime_parameters) == sl.ERROR_CODE.SUCCESS:
|
||||
frame_count += 1
|
||||
|
||||
# Skip frames: only process every 5th frame
|
||||
if frame_count % frame_skip_interval != 0:
|
||||
# 跳过帧不处理,直接继续
|
||||
continue # Skip processing for this frame
|
||||
|
||||
# Process this frame (every 5th frame)
|
||||
# 获取图像
|
||||
self.zed.retrieve_image(left_image, sl.VIEW.LEFT)
|
||||
self.zed.retrieve_image(right_image, sl.VIEW.RIGHT)
|
||||
self.zed.retrieve_measure(depth_map, sl.MEASURE.DEPTH)
|
||||
|
||||
# 转换为numpy数组
|
||||
left_np = left_image.get_data()
|
||||
right_np = right_image.get_data()
|
||||
depth_np = depth_map.get_data()
|
||||
|
||||
# YOLO鱼类检测(如果启用)
|
||||
yolo_detected = False
|
||||
yolo_bboxes = []
|
||||
if self.use_yolo_detector:
|
||||
try:
|
||||
yolo_bboxes = self.detector.detect_fish(left_np)
|
||||
yolo_detected = len(yolo_bboxes) > 0
|
||||
if yolo_detected:
|
||||
print(f"[帧 {frame_count}] YOLO检测到 {len(yolo_bboxes)} 条鱼")
|
||||
except Exception as e:
|
||||
print(f"[帧 {frame_count}] YOLO检测错误: {e}")
|
||||
|
||||
# 深度鱼类检测
|
||||
fish_detected, fish_point_count = self.detect_fish(depth_np, calibration_params, min_points=min_points, max_depth=max_depth, max_y=max_y)
|
||||
|
||||
# 非debug模式:跟踪鱼类检测并保存点云
|
||||
if not debug and save_path:
|
||||
# 检测新的鱼类出现(从False到True的转换)
|
||||
# 如果启用了YOLO,需要同时满足YOLO和深度检测;否则只需深度检测
|
||||
detection_triggered = (self.use_yolo_detector and yolo_detected and fish_detected) or (not self.use_yolo_detector and fish_detected)
|
||||
|
||||
if detection_triggered and not previous_fish_detected:
|
||||
fish_counter += 1
|
||||
first_detection_frame = frame_count
|
||||
active_fish_tracking.append((first_detection_frame, fish_counter))
|
||||
detection_method = "YOLO+深度" if self.use_yolo_detector else "深度"
|
||||
print(f"[帧 {frame_count}] 检测到新鱼类 (fish{fish_counter})!{detection_method}检测,满足条件的点数: {fish_point_count}")
|
||||
|
||||
# 检查是否有需要保存的点云(first_frame + 12)
|
||||
fish_to_remove = []
|
||||
for first_frame, fish_num in active_fish_tracking:
|
||||
if frame_count == first_frame + 12:
|
||||
# 如果启用了YOLO,再次检查YOLO是否检测到鱼
|
||||
should_save = True
|
||||
if self.use_yolo_detector:
|
||||
current_yolo_detected = False
|
||||
try:
|
||||
current_yolo_bboxes = self.detector.detect_fish(left_np)
|
||||
current_yolo_detected = len(current_yolo_bboxes) > 0
|
||||
except:
|
||||
pass
|
||||
|
||||
if not current_yolo_detected:
|
||||
should_save = False
|
||||
print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (YOLO未检测到鱼)")
|
||||
|
||||
if should_save:
|
||||
# 保存点云
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
base_name = f"{timestamp}_{frame_count:06d}_fish{fish_num}"
|
||||
saved = self.save_fish_point_cloud(left_np, depth_np, calibration_params, save_path, base_name, max_depth=max_depth, max_y=max_y, min_points=min_points)
|
||||
if saved:
|
||||
pointcloud_saved_this_frame = True
|
||||
print(f"[帧 {frame_count}] 保存鱼类点云: fish{fish_num} (首次检测于帧 {first_frame})")
|
||||
else:
|
||||
print(f"[帧 {frame_count}] 跳过保存鱼类点云: fish{fish_num} (过滤后点数不足)")
|
||||
|
||||
fish_to_remove.append((first_frame, fish_num))
|
||||
|
||||
# 移除已保存的鱼类跟踪
|
||||
for item in fish_to_remove:
|
||||
active_fish_tracking.remove(item)
|
||||
|
||||
# 如果鱼类消失(从True到False),但保留跟踪直到保存完成
|
||||
# 这样即使鱼类在保存前消失,我们仍会在first_frame + 12时保存
|
||||
if not detection_triggered and previous_fish_detected:
|
||||
print(f"[帧 {frame_count}] 鱼类消失,但保留跟踪直到保存完成")
|
||||
|
||||
# 跟踪当前帧是否成功保存了点云(用于决定是否保存图像)
|
||||
pointcloud_saved_this_frame = False
|
||||
|
||||
if fish_detected:
|
||||
fish_detection_count += 1
|
||||
if debug:
|
||||
print(f"[帧 {frame_count}] 检测到鱼类!满足条件的点数: {fish_point_count}")
|
||||
|
||||
# 如果启用debug模式,保存检测到的鱼类点云
|
||||
if debug and save_path:
|
||||
# 如果启用了YOLO,只有当YOLO检测到鱼时才保存
|
||||
should_save_debug = True
|
||||
if self.use_yolo_detector and not yolo_detected:
|
||||
should_save_debug = False
|
||||
print(f"[帧 {frame_count}] [DEBUG] 跳过保存:YOLO未检测到鱼")
|
||||
|
||||
if should_save_debug:
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
base_name = f"{timestamp}_{frame_count:06d}_fish"
|
||||
saved = self.save_fish_point_cloud_debug(left_np, depth_np, calibration_params, save_path, base_name, max_depth=max_depth, max_y=max_y, min_points=min_points)
|
||||
if saved:
|
||||
pointcloud_saved_this_frame = True
|
||||
else:
|
||||
print(f"[帧 {frame_count}] [DEBUG] 跳过保存:过滤后点数不足")
|
||||
|
||||
# 更新前一次的检测状态
|
||||
# 如果启用了YOLO,需要同时满足YOLO和深度检测
|
||||
if self.use_yolo_detector:
|
||||
previous_fish_detected = yolo_detected and fish_detected
|
||||
else:
|
||||
previous_fish_detected = fish_detected
|
||||
|
||||
# 如果设置保存所有帧
|
||||
if save_all and save_path:
|
||||
self.save_images(save_path, left_np, right_np, depth_np, frame_count)
|
||||
|
||||
# 复制左图用于绘制边界框(避免修改原始图像)
|
||||
display_left = left_np.copy()
|
||||
|
||||
# YOLO检测并绘制边界框(如果启用)
|
||||
if self.use_yolo_detector and show_yolo_bbox and yolo_bboxes:
|
||||
try:
|
||||
display_left = self.detector.draw_bboxes(display_left, yolo_bboxes)
|
||||
except Exception as e:
|
||||
print(f"[帧 {frame_count}] 绘制边界框错误: {e}")
|
||||
|
||||
# 在图像上显示检测状态和信息
|
||||
if self.camera_mode:
|
||||
elapsed_time = time.time() - start_time
|
||||
fps = frame_count / elapsed_time if elapsed_time > 0 else 0
|
||||
status = f"Frame: {frame_count} (processed), FPS: {fps:.1f}"
|
||||
cv2.putText(display_left, status, (10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
|
||||
|
||||
# 在图像上显示检测状态
|
||||
if self.use_yolo_detector:
|
||||
detection_status = f"YOLO: {'YES' if yolo_detected else 'NO'} | Depth: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
|
||||
status_color = (0, 255, 0) if (yolo_detected and fish_detected) else (0, 165, 255) if (yolo_detected or fish_detected) else (0, 0, 255)
|
||||
else:
|
||||
detection_status = f"Fish: {'YES' if fish_detected else 'NO'} ({fish_point_count} pts)"
|
||||
status_color = (0, 255, 0) if fish_detected else (0, 0, 255)
|
||||
|
||||
cv2.putText(display_left, detection_status, (10, 60),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
|
||||
|
||||
# 保存带边界框的图像(只有当点云成功保存时才保存图像)
|
||||
if save_path and pointcloud_saved_this_frame:
|
||||
images_dir = os.path.join(save_path, "images")
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
image_filename = f"{timestamp}_{frame_count:06d}_bbox.png"
|
||||
image_path = os.path.join(images_dir, image_filename)
|
||||
cv2.imwrite(image_path, display_left)
|
||||
print(f"[帧 {frame_count}] 保存带边界框的图像: {image_path}")
|
||||
elif not self.camera_mode: # 如果是SVO模式且到达文件末尾
|
||||
break
|
||||
|
||||
finally:
|
||||
if record:
|
||||
self.zed.disable_recording()
|
||||
self.close()
|
||||
print(f"总共处理了 {frame_count} 帧")
|
||||
print(f"检测到鱼类的帧数: {fish_detection_count}")
|
||||
if self.camera_mode:
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"平均帧率: {frame_count/elapsed_time:.1f} FPS")
|
||||
|
||||
def close(self):
|
||||
# 关闭相机
|
||||
self.zed.close()
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--svo', type=str, help='SVO文件路径', default=None)
|
||||
parser.add_argument('--save_path', type=str, help='保存图像的路径', default='./zed_images')
|
||||
parser.add_argument('--record', action='store_true', help='是否录制新的SVO文件')
|
||||
parser.add_argument('--save_all', action='store_true', help='是否保存所有帧')
|
||||
parser.add_argument('--camera', action='store_true', help='使用实时相机模式')
|
||||
parser.add_argument('--debug', action='store_true', help='启用debug模式:检测到鱼时自动保存点云')
|
||||
parser.add_argument('--max_depth', type=float, help='最大深度值(mm),默认1000', default=1000)
|
||||
parser.add_argument('--max_y', type=float, help='最大Y坐标值(mm),默认100', default=100)
|
||||
parser.add_argument('--min_points', type=int, help='检测到鱼所需的最少点数,默认50', default=500)
|
||||
parser.add_argument('--use_yolo', action='store_true', help='使用YOLO检测器检测鱼(只有YOLO检测到边界框时才保存点云)')
|
||||
parser.add_argument('--yolo_model', type=str, help='YOLO模型路径(.pt文件)', default=None)
|
||||
parser.add_argument('--show_yolo_bbox', action='store_true', default=True, help='显示YOLO检测的边界框(默认启用)')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.camera and args.svo:
|
||||
print("错误:不能同时指定相机模式和SVO文件")
|
||||
return
|
||||
|
||||
# 创建ZED读取器实例
|
||||
zed_reader = ZEDReader(args.svo, camera_mode=args.camera,
|
||||
use_yolo_detector=args.use_yolo,
|
||||
yolo_model_path=args.yolo_model)
|
||||
|
||||
# 打开相机或文件
|
||||
if zed_reader.open():
|
||||
# 处理帧
|
||||
zed_reader.process_frames(args.save_path, args.record, args.save_all, args.debug,
|
||||
args.max_depth, args.max_y, args.min_points, args.show_yolo_bbox)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user