4.3 KiB
4.3 KiB
Fish Point Cloud Quality Classification
使用 PointNet/PointNet++ 进行鱼类点云质量分类(good/bad)的训练脚本。
更新内容
-
新的数据加载器 (
data_utils/FishPointCloudDataLoader.py):- 从 PLY 文件加载点云数据
- 支持 good/bad 二分类
- 支持数据预处理和缓存
-
更新的训练脚本 (
train_classification.py):- 支持鱼类点云数据集
- 自动适配 2 类分类(good/bad)
- 兼容原有的 ModelNet 数据集
使用方法
1. 准备数据集
确保数据集结构如下:
dataset/
├── train/
│ ├── good/ # 高质量点云 PLY 文件
│ └── bad/ # 低质量点云 PLY 文件
├── val/
│ ├── good/
│ └── bad/
└── test/
├── good/
└── bad/
2. 训练模型
使用 PointNet 训练:
cd pointcloud_classifier/Pointnet_Pointnet2_pytorch
python train_classification.py \
--use_fish_dataset \
--data_path ../../dataset/ \
--model pointnet_cls \
--num_category 2 \
--num_point 1024 \
--batch_size 24 \
--epoch 200 \
--learning_rate 0.001 \
--gpu 0 \
--log_dir fish_pointnet
使用 PointNet++ 训练:
python train_classification.py \
--use_fish_dataset \
--data_path ../../dataset/ \
--model pointnet2_cls_ssg \
--num_category 2 \
--num_point 1024 \
--batch_size 24 \
--epoch 200 \
--learning_rate 0.001 \
--gpu 0 \
--log_dir fish_pointnet2
3. 使用预训练检查点
如果已有预训练模型(如 ModelNet40 上训练的),可以使用 --pretrained 参数加载:
# 从指定路径加载预训练模型
python train_classification.py \
--use_fish_dataset \
--data_path ../../dataset/ \
--model pointnet2_cls_ssg \
--num_category 2 \
--pretrained log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth \
--log_dir fish_pointnet2_finetune
# 或者从其他预训练模型加载(如 pointnet2_ssg_wo_normals)
python train_classification.py \
--use_fish_dataset \
--data_path ../../dataset/ \
--model pointnet2_cls_ssg \
--num_category 2 \
--pretrained log/classification/pointnet2_ssg_wo_normals/checkpoints/best_model.pth \
--log_dir fish_pointnet2_finetune
注意:
- 如果预训练模型的分类头类别数(如 40)与当前任务(2)不同,脚本会自动跳过分类头,只加载特征提取器部分
- 这样可以实现迁移学习,利用预训练的特征提取能力
4. 参数说明
--use_fish_dataset: 使用鱼类点云数据集(必须)--data_path: 数据集根目录路径--model: 模型名称 (pointnet_cls或pointnet2_cls_ssg)--num_category: 类别数量(2 表示 good/bad)--num_point: 每个点云的点数(默认 1024)--batch_size: 批次大小--epoch: 训练轮数--learning_rate: 学习率--gpu: GPU 设备 ID--log_dir: 实验日志目录名称--pretrained: 预训练模型路径(可选,用于迁移学习)--process_data: 预处理并缓存数据(首次运行建议使用)--use_uniform_sample: 使用 FPS 均匀采样
5. 输出
训练过程会保存:
- 检查点:
log/classification/{log_dir}/checkpoints/best_model.pth - 训练日志:
log/classification/{log_dir}/logs/{model_name}.txt
注意事项
- 数据格式:确保 PLY 文件格式正确,包含 XYZ 坐标
- 类别标签:good=1, bad=0
- 预训练模型:如果使用预训练模型,需要确保模型架构兼容(可能需要调整分类头)
- 内存:如果点云文件很大,建议使用
--process_data预处理数据
示例
完整训练示例:
# 1. 准备数据集(使用 dataset.py 脚本)
cd ../../
python pointcloud_classifier/dataset.py \
--source output_preview4/ \
--output dataset/ \
--train_ratio 0.7 \
--val_ratio 0.15 \
--test_ratio 0.15
# 2. 训练模型
cd pointcloud_classifier/Pointnet_Pointnet2_pytorch
python train_classification.py \
--use_fish_dataset \
--data_path ../../dataset/ \
--model pointnet2_cls_ssg \
--num_category 2 \
--num_point 1024 \
--batch_size 32 \
--epoch 200 \
--learning_rate 0.001 \
--process_data \
--gpu 0 \
--log_dir fish_quality_classifier