Files

6.3 KiB
Raw Permalink Blame History

Point Cloud Quality Classifier

基于深度学习的鱼类点云质量分类模块用于区分高质量和低质量的鱼类3D点云。

问题背景

在水下环境中,即使 YOLO 检测效果很好,深度图质量可能很差,导致生成的点云包含大量异常点,影响后续对齐和估计精度。通过深度学习分类可以自动识别和过滤低质量点云,提高整体系统鲁棒性。

工作流程

1. 标注鱼类点云Label Fish Point Clouds

目标:为训练数据收集和标注高质量/低质量点云样本

步骤

  1. fish_video_weight_evaluation.py 生成的点云中收集样本
  2. 手动检查每个点云的质量:
    • Good高质量:点云完整、形状清晰、无明显异常点、适合对齐和估计
    • Bad低质量:点云稀疏、形状不完整、包含大量异常点、深度误差明显
  3. 将点云文件分类到对应文件夹:
    dataset/
    ├── train/
    │   ├── good/     # 高质量点云样本
    │   └── bad/      # 低质量点云样本
    ├── val/
    │   ├── good/
    │   └── bad/
    └── test/
        ├── good/
        └── bad/
    

标注工具(待实现):

  • 可视化工具:使用 Open3D 或 CloudCompare 查看点云
  • 批量标注脚本:自动组织点云文件到对应类别文件夹
  • 标注验证:确保数据集平衡和质量

数据要求

  • 每个类别至少需要 500+ 样本(建议 1000+
  • 训练/验证/测试集比例70% / 15% / 15%
  • 点云格式PLY 文件(包含 XYZ 坐标和 RGB 颜色)

2. 训练 PointTransformer 模型Train PointTransformer

模型选择

  • PointTransformer:基于 Transformer 的点云处理模型,对点云质量分类效果好
  • PointNet++:备选方案,经典的点云分类模型

训练步骤

  1. 数据预处理

    • 点云归一化(中心化、缩放)
    • 点云采样到固定点数(如 1024 或 2048 点)
    • 数据增强(旋转、缩放、噪声添加)
  2. 模型训练

    python train_pointcloud_classifier.py \
        --data dataset/ \
        --model pointtransformer \
        --epochs 100 \
        --batch_size 32 \
        --num_points 1024 \
        --lr 0.001
    
  3. 训练参数

    • 输入点数1024 或 2048根据点云密度调整
    • 学习率0.001(可调整)
    • 批次大小32根据 GPU 内存调整)
    • 训练轮数100+(根据过拟合情况调整)
  4. 模型保存

    • 保存最佳模型:checkpoints/best_pointcloud_classifier.pt
    • 保存训练日志和指标

3. 测试和评估Test and Evaluate

测试步骤

  1. 模型推理

    python test_pointcloud_classifier.py \
        --model checkpoints/best_pointcloud_classifier.pt \
        --data dataset/test/ \
        --output results/
    
  2. 评估指标

    • 准确率Accuracy
    • 精确率Precision
    • 召回率Recall
    • F1 分数
    • 混淆矩阵
  3. 可视化结果

    • 分类结果统计
    • 错误分类案例分析
    • 点云质量分布可视化

集成到主流程

  • fish_video_weight_evaluation.py 中集成分类器
  • 自动过滤被分类为 "bad" 的点云
  • 仅保留 "good" 点云用于后续对齐和估计

文件结构

pointcloud_classifier/
├── README.md                    # 本文档
├── label_pointclouds.py         # 点云标注工具(待实现)
├── train_pointcloud_classifier.py  # 训练脚本(待实现)
├── test_pointcloud_classifier.py   # 测试脚本(待实现)
├── models/
│   ├── pointtransformer.py     # PointTransformer 模型定义(待实现)
│   └── pointnet2.py            # PointNet++ 模型定义(备选,待实现)
├── dataset/
│   ├── train/
│   │   ├── good/
│   │   └── bad/
│   ├── val/
│   │   ├── good/
│   │   └── bad/
│   └── test/
│       ├── good/
│       └── bad/
└── checkpoints/                 # 保存训练好的模型
    └── best_pointcloud_classifier.pt

依赖库

pip install torch torchvision
pip install torch-geometric  # PointTransformer 需要
pip install open3d           # 点云可视化
pip install numpy
pip install scikit-learn     # 评估指标

使用示例

标注点云

# 使用标注工具查看和分类点云
python label_pointclouds.py --input output_preview/*/cloud/ --output dataset/

训练模型

# 训练 PointTransformer 分类器
python train_pointcloud_classifier.py \
    --data dataset/ \
    --model pointtransformer \
    --epochs 100 \
    --batch_size 32

测试模型

# 在测试集上评估
python test_pointcloud_classifier.py \
    --model checkpoints/best_pointcloud_classifier.pt \
    --data dataset/test/

集成到主流程

# 在 fish_video_weight_evaluation.py 中使用
from pointcloud_classifier import PointCloudClassifier

classifier = PointCloudClassifier("checkpoints/best_pointcloud_classifier.pt")
quality_score = classifier.predict(points, colors)
if quality_score > 0.5:  # "good" threshold
    # 保存点云用于后续处理
    save_pointcloud(points, colors)

注意事项

  1. 数据平衡:确保 good/bad 样本数量大致平衡,避免类别不平衡问题
  2. 点云预处理:统一点云格式和点数,确保模型输入一致性
  3. 模型选择PointTransformer 通常比 PointNet++ 效果更好,但计算量更大
  4. 阈值调整:根据实际应用场景调整分类阈值(默认 0.5
  5. 持续改进:收集错误分类样本,扩充训练数据,迭代改进模型

待实现功能

  • 点云标注工具(可视化 + 批量分类)
  • PointTransformer 模型实现
  • 训练脚本(数据加载、训练循环、模型保存)
  • 测试脚本(推理、评估、可视化)
  • 集成到 fish_video_weight_evaluation.py
  • 模型部署优化(量化、加速)

参考资料