6.3 KiB
6.3 KiB
Point Cloud Quality Classifier
基于深度学习的鱼类点云质量分类模块,用于区分高质量和低质量的鱼类3D点云。
问题背景
在水下环境中,即使 YOLO 检测效果很好,深度图质量可能很差,导致生成的点云包含大量异常点,影响后续对齐和估计精度。通过深度学习分类可以自动识别和过滤低质量点云,提高整体系统鲁棒性。
工作流程
1. 标注鱼类点云(Label Fish Point Clouds)
目标:为训练数据收集和标注高质量/低质量点云样本
步骤:
- 从
fish_video_weight_evaluation.py生成的点云中收集样本 - 手动检查每个点云的质量:
- Good(高质量):点云完整、形状清晰、无明显异常点、适合对齐和估计
- Bad(低质量):点云稀疏、形状不完整、包含大量异常点、深度误差明显
- 将点云文件分类到对应文件夹:
dataset/ ├── train/ │ ├── good/ # 高质量点云样本 │ └── bad/ # 低质量点云样本 ├── val/ │ ├── good/ │ └── bad/ └── test/ ├── good/ └── bad/
标注工具(待实现):
- 可视化工具:使用 Open3D 或 CloudCompare 查看点云
- 批量标注脚本:自动组织点云文件到对应类别文件夹
- 标注验证:确保数据集平衡和质量
数据要求:
- 每个类别至少需要 500+ 样本(建议 1000+)
- 训练/验证/测试集比例:70% / 15% / 15%
- 点云格式:PLY 文件(包含 XYZ 坐标和 RGB 颜色)
2. 训练 PointTransformer 模型(Train PointTransformer)
模型选择:
- PointTransformer:基于 Transformer 的点云处理模型,对点云质量分类效果好
- PointNet++:备选方案,经典的点云分类模型
训练步骤:
-
数据预处理:
- 点云归一化(中心化、缩放)
- 点云采样到固定点数(如 1024 或 2048 点)
- 数据增强(旋转、缩放、噪声添加)
-
模型训练:
python train_pointcloud_classifier.py \ --data dataset/ \ --model pointtransformer \ --epochs 100 \ --batch_size 32 \ --num_points 1024 \ --lr 0.001 -
训练参数:
- 输入点数:1024 或 2048(根据点云密度调整)
- 学习率:0.001(可调整)
- 批次大小:32(根据 GPU 内存调整)
- 训练轮数:100+(根据过拟合情况调整)
-
模型保存:
- 保存最佳模型:
checkpoints/best_pointcloud_classifier.pt - 保存训练日志和指标
- 保存最佳模型:
3. 测试和评估(Test and Evaluate)
测试步骤:
-
模型推理:
python test_pointcloud_classifier.py \ --model checkpoints/best_pointcloud_classifier.pt \ --data dataset/test/ \ --output results/ -
评估指标:
- 准确率(Accuracy)
- 精确率(Precision)
- 召回率(Recall)
- F1 分数
- 混淆矩阵
-
可视化结果:
- 分类结果统计
- 错误分类案例分析
- 点云质量分布可视化
集成到主流程:
- 在
fish_video_weight_evaluation.py中集成分类器 - 自动过滤被分类为 "bad" 的点云
- 仅保留 "good" 点云用于后续对齐和估计
文件结构
pointcloud_classifier/
├── README.md # 本文档
├── label_pointclouds.py # 点云标注工具(待实现)
├── train_pointcloud_classifier.py # 训练脚本(待实现)
├── test_pointcloud_classifier.py # 测试脚本(待实现)
├── models/
│ ├── pointtransformer.py # PointTransformer 模型定义(待实现)
│ └── pointnet2.py # PointNet++ 模型定义(备选,待实现)
├── dataset/
│ ├── train/
│ │ ├── good/
│ │ └── bad/
│ ├── val/
│ │ ├── good/
│ │ └── bad/
│ └── test/
│ ├── good/
│ └── bad/
└── checkpoints/ # 保存训练好的模型
└── best_pointcloud_classifier.pt
依赖库
pip install torch torchvision
pip install torch-geometric # PointTransformer 需要
pip install open3d # 点云可视化
pip install numpy
pip install scikit-learn # 评估指标
使用示例
标注点云
# 使用标注工具查看和分类点云
python label_pointclouds.py --input output_preview/*/cloud/ --output dataset/
训练模型
# 训练 PointTransformer 分类器
python train_pointcloud_classifier.py \
--data dataset/ \
--model pointtransformer \
--epochs 100 \
--batch_size 32
测试模型
# 在测试集上评估
python test_pointcloud_classifier.py \
--model checkpoints/best_pointcloud_classifier.pt \
--data dataset/test/
集成到主流程
# 在 fish_video_weight_evaluation.py 中使用
from pointcloud_classifier import PointCloudClassifier
classifier = PointCloudClassifier("checkpoints/best_pointcloud_classifier.pt")
quality_score = classifier.predict(points, colors)
if quality_score > 0.5: # "good" threshold
# 保存点云用于后续处理
save_pointcloud(points, colors)
注意事项
- 数据平衡:确保 good/bad 样本数量大致平衡,避免类别不平衡问题
- 点云预处理:统一点云格式和点数,确保模型输入一致性
- 模型选择:PointTransformer 通常比 PointNet++ 效果更好,但计算量更大
- 阈值调整:根据实际应用场景调整分类阈值(默认 0.5)
- 持续改进:收集错误分类样本,扩充训练数据,迭代改进模型
待实现功能
- 点云标注工具(可视化 + 批量分类)
- PointTransformer 模型实现
- 训练脚本(数据加载、训练循环、模型保存)
- 测试脚本(推理、评估、可视化)
- 集成到
fish_video_weight_evaluation.py - 模型部署优化(量化、加速)
参考资料
- PointTransformer: https://github.com/POSTECH-CVLab/point-transformer
- PointNet++: https://github.com/charlesq34/pointnet2
- PyTorch Geometric: https://pytorch-geometric.readthedocs.io/