155 lines
4.3 KiB
Markdown
155 lines
4.3 KiB
Markdown
|
|
# Fish Point Cloud Quality Classification
|
|||
|
|
|
|||
|
|
使用 PointNet/PointNet++ 进行鱼类点云质量分类(good/bad)的训练脚本。
|
|||
|
|
|
|||
|
|
## 更新内容
|
|||
|
|
|
|||
|
|
1. **新的数据加载器** (`data_utils/FishPointCloudDataLoader.py`):
|
|||
|
|
- 从 PLY 文件加载点云数据
|
|||
|
|
- 支持 good/bad 二分类
|
|||
|
|
- 支持数据预处理和缓存
|
|||
|
|
|
|||
|
|
2. **更新的训练脚本** (`train_classification.py`):
|
|||
|
|
- 支持鱼类点云数据集
|
|||
|
|
- 自动适配 2 类分类(good/bad)
|
|||
|
|
- 兼容原有的 ModelNet 数据集
|
|||
|
|
|
|||
|
|
## 使用方法
|
|||
|
|
|
|||
|
|
### 1. 准备数据集
|
|||
|
|
|
|||
|
|
确保数据集结构如下:
|
|||
|
|
```
|
|||
|
|
dataset/
|
|||
|
|
├── train/
|
|||
|
|
│ ├── good/ # 高质量点云 PLY 文件
|
|||
|
|
│ └── bad/ # 低质量点云 PLY 文件
|
|||
|
|
├── val/
|
|||
|
|
│ ├── good/
|
|||
|
|
│ └── bad/
|
|||
|
|
└── test/
|
|||
|
|
├── good/
|
|||
|
|
└── bad/
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 2. 训练模型
|
|||
|
|
|
|||
|
|
#### 使用 PointNet 训练:
|
|||
|
|
```bash
|
|||
|
|
cd pointcloud_classifier/Pointnet_Pointnet2_pytorch
|
|||
|
|
|
|||
|
|
python train_classification.py \
|
|||
|
|
--use_fish_dataset \
|
|||
|
|
--data_path ../../dataset/ \
|
|||
|
|
--model pointnet_cls \
|
|||
|
|
--num_category 2 \
|
|||
|
|
--num_point 1024 \
|
|||
|
|
--batch_size 24 \
|
|||
|
|
--epoch 200 \
|
|||
|
|
--learning_rate 0.001 \
|
|||
|
|
--gpu 0 \
|
|||
|
|
--log_dir fish_pointnet
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
#### 使用 PointNet++ 训练:
|
|||
|
|
```bash
|
|||
|
|
python train_classification.py \
|
|||
|
|
--use_fish_dataset \
|
|||
|
|
--data_path ../../dataset/ \
|
|||
|
|
--model pointnet2_cls_ssg \
|
|||
|
|
--num_category 2 \
|
|||
|
|
--num_point 1024 \
|
|||
|
|
--batch_size 24 \
|
|||
|
|
--epoch 200 \
|
|||
|
|
--learning_rate 0.001 \
|
|||
|
|
--gpu 0 \
|
|||
|
|
--log_dir fish_pointnet2
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 3. 使用预训练检查点
|
|||
|
|
|
|||
|
|
如果已有预训练模型(如 ModelNet40 上训练的),可以使用 `--pretrained` 参数加载:
|
|||
|
|
|
|||
|
|
```bash
|
|||
|
|
# 从指定路径加载预训练模型
|
|||
|
|
python train_classification.py \
|
|||
|
|
--use_fish_dataset \
|
|||
|
|
--data_path ../../dataset/ \
|
|||
|
|
--model pointnet2_cls_ssg \
|
|||
|
|
--num_category 2 \
|
|||
|
|
--pretrained log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth \
|
|||
|
|
--log_dir fish_pointnet2_finetune
|
|||
|
|
|
|||
|
|
# 或者从其他预训练模型加载(如 pointnet2_ssg_wo_normals)
|
|||
|
|
python train_classification.py \
|
|||
|
|
--use_fish_dataset \
|
|||
|
|
--data_path ../../dataset/ \
|
|||
|
|
--model pointnet2_cls_ssg \
|
|||
|
|
--num_category 2 \
|
|||
|
|
--pretrained log/classification/pointnet2_ssg_wo_normals/checkpoints/best_model.pth \
|
|||
|
|
--log_dir fish_pointnet2_finetune
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
**注意**:
|
|||
|
|
- 如果预训练模型的分类头类别数(如 40)与当前任务(2)不同,脚本会自动跳过分类头,只加载特征提取器部分
|
|||
|
|
- 这样可以实现迁移学习,利用预训练的特征提取能力
|
|||
|
|
|
|||
|
|
### 4. 参数说明
|
|||
|
|
|
|||
|
|
- `--use_fish_dataset`: 使用鱼类点云数据集(必须)
|
|||
|
|
- `--data_path`: 数据集根目录路径
|
|||
|
|
- `--model`: 模型名称 (`pointnet_cls` 或 `pointnet2_cls_ssg`)
|
|||
|
|
- `--num_category`: 类别数量(2 表示 good/bad)
|
|||
|
|
- `--num_point`: 每个点云的点数(默认 1024)
|
|||
|
|
- `--batch_size`: 批次大小
|
|||
|
|
- `--epoch`: 训练轮数
|
|||
|
|
- `--learning_rate`: 学习率
|
|||
|
|
- `--gpu`: GPU 设备 ID
|
|||
|
|
- `--log_dir`: 实验日志目录名称
|
|||
|
|
- `--pretrained`: 预训练模型路径(可选,用于迁移学习)
|
|||
|
|
- `--process_data`: 预处理并缓存数据(首次运行建议使用)
|
|||
|
|
- `--use_uniform_sample`: 使用 FPS 均匀采样
|
|||
|
|
|
|||
|
|
### 5. 输出
|
|||
|
|
|
|||
|
|
训练过程会保存:
|
|||
|
|
- 检查点:`log/classification/{log_dir}/checkpoints/best_model.pth`
|
|||
|
|
- 训练日志:`log/classification/{log_dir}/logs/{model_name}.txt`
|
|||
|
|
|
|||
|
|
## 注意事项
|
|||
|
|
|
|||
|
|
1. **数据格式**:确保 PLY 文件格式正确,包含 XYZ 坐标
|
|||
|
|
2. **类别标签**:good=1, bad=0
|
|||
|
|
3. **预训练模型**:如果使用预训练模型,需要确保模型架构兼容(可能需要调整分类头)
|
|||
|
|
4. **内存**:如果点云文件很大,建议使用 `--process_data` 预处理数据
|
|||
|
|
|
|||
|
|
## 示例
|
|||
|
|
|
|||
|
|
完整训练示例:
|
|||
|
|
```bash
|
|||
|
|
# 1. 准备数据集(使用 dataset.py 脚本)
|
|||
|
|
cd ../../
|
|||
|
|
python pointcloud_classifier/dataset.py \
|
|||
|
|
--source output_preview4/ \
|
|||
|
|
--output dataset/ \
|
|||
|
|
--train_ratio 0.7 \
|
|||
|
|
--val_ratio 0.15 \
|
|||
|
|
--test_ratio 0.15
|
|||
|
|
|
|||
|
|
# 2. 训练模型
|
|||
|
|
cd pointcloud_classifier/Pointnet_Pointnet2_pytorch
|
|||
|
|
python train_classification.py \
|
|||
|
|
--use_fish_dataset \
|
|||
|
|
--data_path ../../dataset/ \
|
|||
|
|
--model pointnet2_cls_ssg \
|
|||
|
|
--num_category 2 \
|
|||
|
|
--num_point 1024 \
|
|||
|
|
--batch_size 32 \
|
|||
|
|
--epoch 200 \
|
|||
|
|
--learning_rate 0.001 \
|
|||
|
|
--process_data \
|
|||
|
|
--gpu 0 \
|
|||
|
|
--log_dir fish_quality_classifier
|
|||
|
|
```
|
|||
|
|
|