# Fish Point Cloud Quality Classification 使用 PointNet/PointNet++ 进行鱼类点云质量分类(good/bad)的训练脚本。 ## 更新内容 1. **新的数据加载器** (`data_utils/FishPointCloudDataLoader.py`): - 从 PLY 文件加载点云数据 - 支持 good/bad 二分类 - 支持数据预处理和缓存 2. **更新的训练脚本** (`train_classification.py`): - 支持鱼类点云数据集 - 自动适配 2 类分类(good/bad) - 兼容原有的 ModelNet 数据集 ## 使用方法 ### 1. 准备数据集 确保数据集结构如下: ``` dataset/ ├── train/ │ ├── good/ # 高质量点云 PLY 文件 │ └── bad/ # 低质量点云 PLY 文件 ├── val/ │ ├── good/ │ └── bad/ └── test/ ├── good/ └── bad/ ``` ### 2. 训练模型 #### 使用 PointNet 训练: ```bash cd pointcloud_classifier/Pointnet_Pointnet2_pytorch python train_classification.py \ --use_fish_dataset \ --data_path ../../dataset/ \ --model pointnet_cls \ --num_category 2 \ --num_point 1024 \ --batch_size 24 \ --epoch 200 \ --learning_rate 0.001 \ --gpu 0 \ --log_dir fish_pointnet ``` #### 使用 PointNet++ 训练: ```bash python train_classification.py \ --use_fish_dataset \ --data_path ../../dataset/ \ --model pointnet2_cls_ssg \ --num_category 2 \ --num_point 1024 \ --batch_size 24 \ --epoch 200 \ --learning_rate 0.001 \ --gpu 0 \ --log_dir fish_pointnet2 ``` ### 3. 使用预训练检查点 如果已有预训练模型(如 ModelNet40 上训练的),可以使用 `--pretrained` 参数加载: ```bash # 从指定路径加载预训练模型 python train_classification.py \ --use_fish_dataset \ --data_path ../../dataset/ \ --model pointnet2_cls_ssg \ --num_category 2 \ --pretrained log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth \ --log_dir fish_pointnet2_finetune # 或者从其他预训练模型加载(如 pointnet2_ssg_wo_normals) python train_classification.py \ --use_fish_dataset \ --data_path ../../dataset/ \ --model pointnet2_cls_ssg \ --num_category 2 \ --pretrained log/classification/pointnet2_ssg_wo_normals/checkpoints/best_model.pth \ --log_dir fish_pointnet2_finetune ``` **注意**: - 如果预训练模型的分类头类别数(如 40)与当前任务(2)不同,脚本会自动跳过分类头,只加载特征提取器部分 - 这样可以实现迁移学习,利用预训练的特征提取能力 ### 4. 参数说明 - `--use_fish_dataset`: 使用鱼类点云数据集(必须) - `--data_path`: 数据集根目录路径 - `--model`: 模型名称 (`pointnet_cls` 或 `pointnet2_cls_ssg`) - `--num_category`: 类别数量(2 表示 good/bad) - `--num_point`: 每个点云的点数(默认 1024) - `--batch_size`: 批次大小 - `--epoch`: 训练轮数 - `--learning_rate`: 学习率 - `--gpu`: GPU 设备 ID - `--log_dir`: 实验日志目录名称 - `--pretrained`: 预训练模型路径(可选,用于迁移学习) - `--process_data`: 预处理并缓存数据(首次运行建议使用) - `--use_uniform_sample`: 使用 FPS 均匀采样 ### 5. 输出 训练过程会保存: - 检查点:`log/classification/{log_dir}/checkpoints/best_model.pth` - 训练日志:`log/classification/{log_dir}/logs/{model_name}.txt` ## 注意事项 1. **数据格式**:确保 PLY 文件格式正确,包含 XYZ 坐标 2. **类别标签**:good=1, bad=0 3. **预训练模型**:如果使用预训练模型,需要确保模型架构兼容(可能需要调整分类头) 4. **内存**:如果点云文件很大,建议使用 `--process_data` 预处理数据 ## 示例 完整训练示例: ```bash # 1. 准备数据集(使用 dataset.py 脚本) cd ../../ python pointcloud_classifier/dataset.py \ --source output_preview4/ \ --output dataset/ \ --train_ratio 0.7 \ --val_ratio 0.15 \ --test_ratio 0.15 # 2. 训练模型 cd pointcloud_classifier/Pointnet_Pointnet2_pytorch python train_classification.py \ --use_fish_dataset \ --data_path ../../dataset/ \ --model pointnet2_cls_ssg \ --num_category 2 \ --num_point 1024 \ --batch_size 32 \ --epoch 200 \ --learning_rate 0.001 \ --process_data \ --gpu 0 \ --log_dir fish_quality_classifier ```