Files
FishServer/FishMeasure/pointcloud_classifier/Pointnet_Pointnet2_pytorch/README_FISH.md

4.3 KiB
Raw Blame History

Fish Point Cloud Quality Classification

使用 PointNet/PointNet++ 进行鱼类点云质量分类good/bad的训练脚本。

更新内容

  1. 新的数据加载器 (data_utils/FishPointCloudDataLoader.py):

    • 从 PLY 文件加载点云数据
    • 支持 good/bad 二分类
    • 支持数据预处理和缓存
  2. 更新的训练脚本 (train_classification.py):

    • 支持鱼类点云数据集
    • 自动适配 2 类分类good/bad
    • 兼容原有的 ModelNet 数据集

使用方法

1. 准备数据集

确保数据集结构如下:

dataset/
├── train/
│   ├── good/     # 高质量点云 PLY 文件
│   └── bad/      # 低质量点云 PLY 文件
├── val/
│   ├── good/
│   └── bad/
└── test/
    ├── good/
    └── bad/

2. 训练模型

使用 PointNet 训练:

cd pointcloud_classifier/Pointnet_Pointnet2_pytorch

python train_classification.py \
    --use_fish_dataset \
    --data_path ../../dataset/ \
    --model pointnet_cls \
    --num_category 2 \
    --num_point 1024 \
    --batch_size 24 \
    --epoch 200 \
    --learning_rate 0.001 \
    --gpu 0 \
    --log_dir fish_pointnet

使用 PointNet++ 训练:

python train_classification.py \
    --use_fish_dataset \
    --data_path ../../dataset/ \
    --model pointnet2_cls_ssg \
    --num_category 2 \
    --num_point 1024 \
    --batch_size 24 \
    --epoch 200 \
    --learning_rate 0.001 \
    --gpu 0 \
    --log_dir fish_pointnet2

3. 使用预训练检查点

如果已有预训练模型(如 ModelNet40 上训练的),可以使用 --pretrained 参数加载:

# 从指定路径加载预训练模型
python train_classification.py \
    --use_fish_dataset \
    --data_path ../../dataset/ \
    --model pointnet2_cls_ssg \
    --num_category 2 \
    --pretrained log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth \
    --log_dir fish_pointnet2_finetune

# 或者从其他预训练模型加载(如 pointnet2_ssg_wo_normals
python train_classification.py \
    --use_fish_dataset \
    --data_path ../../dataset/ \
    --model pointnet2_cls_ssg \
    --num_category 2 \
    --pretrained log/classification/pointnet2_ssg_wo_normals/checkpoints/best_model.pth \
    --log_dir fish_pointnet2_finetune

注意

  • 如果预训练模型的分类头类别数(如 40与当前任务2不同脚本会自动跳过分类头只加载特征提取器部分
  • 这样可以实现迁移学习,利用预训练的特征提取能力

4. 参数说明

  • --use_fish_dataset: 使用鱼类点云数据集(必须)
  • --data_path: 数据集根目录路径
  • --model: 模型名称 (pointnet_clspointnet2_cls_ssg)
  • --num_category: 类别数量2 表示 good/bad
  • --num_point: 每个点云的点数(默认 1024
  • --batch_size: 批次大小
  • --epoch: 训练轮数
  • --learning_rate: 学习率
  • --gpu: GPU 设备 ID
  • --log_dir: 实验日志目录名称
  • --pretrained: 预训练模型路径(可选,用于迁移学习)
  • --process_data: 预处理并缓存数据(首次运行建议使用)
  • --use_uniform_sample: 使用 FPS 均匀采样

5. 输出

训练过程会保存:

  • 检查点:log/classification/{log_dir}/checkpoints/best_model.pth
  • 训练日志:log/classification/{log_dir}/logs/{model_name}.txt

注意事项

  1. 数据格式:确保 PLY 文件格式正确,包含 XYZ 坐标
  2. 类别标签good=1, bad=0
  3. 预训练模型:如果使用预训练模型,需要确保模型架构兼容(可能需要调整分类头)
  4. 内存:如果点云文件很大,建议使用 --process_data 预处理数据

示例

完整训练示例:

# 1. 准备数据集(使用 dataset.py 脚本)
cd ../../
python pointcloud_classifier/dataset.py \
    --source output_preview4/ \
    --output dataset/ \
    --train_ratio 0.7 \
    --val_ratio 0.15 \
    --test_ratio 0.15

# 2. 训练模型
cd pointcloud_classifier/Pointnet_Pointnet2_pytorch
python train_classification.py \
    --use_fish_dataset \
    --data_path ../../dataset/ \
    --model pointnet2_cls_ssg \
    --num_category 2 \
    --num_point 1024 \
    --batch_size 32 \
    --epoch 200 \
    --learning_rate 0.001 \
    --process_data \
    --gpu 0 \
    --log_dir fish_quality_classifier