Fish Weight Prediction using PointNet++
This module uses PointNet++ to predict fish weight from partial point cloud data.
Overview
We provide two approaches for fish weight prediction:
Approach 1: Direct Regression (Original)
Directly predict the absolute weight from a single point cloud.
Approach 2: Comparison-Based (Recommended)
Predict the weight difference between two point clouds. For a new fish:
- Find the closest reference fish by length (from known dataset)
- Predict weight difference using the trained model
- Calculate: new_weight = reference_weight + predicted_difference
Why comparison-based?
- Relative comparisons are often more accurate than absolute predictions
- Leverages known reference data
- More robust to incomplete/partial point clouds
- Better performance with small datasets (~100-150 samples)
Workflow
Approach 1: Direct Regression
- Dataset Preparation (dataset.py): Prepare training data from multiple point clouds
- Training (train_weight_regression.py): Train PointNet++ model for weight regression
- Testing/Inference (test_weight_regression.py): Test the trained model on new point clouds
Approach 2: Comparison-Based (Recommended)
- Dataset Preparation (dataset.py): Same as Approach 1
- Training (train_weight_comparison.py): Train PointNet++ model to predict weight differences
- Testing/Inference (test_weight_comparison.py): Test using reference dataset
Dataset Preparation
The dataset.py script processes point clouds from an input folder:
- Input: Folder containing multiple point cloud subfolders (e.g., output_preview/xxxx/cloud/)
- Process:
  - For each subfolder, find all PLY files
  - Select the point cloud with the largest length (max x - min x)
  - Normalize the point cloud by translating it so that its centroid lies at the origin
  - Save the normalized PLY file and corresponding weight label to the output folder
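The selection and normalization steps above can be sketched with NumPy as follows. This is a minimal illustration, not the code in dataset.py: the function names are hypothetical, the clouds are assumed to be already loaded as (N, 3) arrays, and PLY reading and writing are left out.

```python
import numpy as np


def x_extent(points: np.ndarray) -> float:
    """Length of a cloud along the x axis: max x - min x."""
    return float(points[:, 0].max() - points[:, 0].min())


def select_longest(clouds: list) -> np.ndarray:
    """Pick the cloud with the largest x extent from one subfolder."""
    return max(clouds, key=x_extent)


def center_on_centroid(points: np.ndarray) -> np.ndarray:
    """Translate the cloud so that its centroid sits at the origin."""
    return points - points.mean(axis=0, keepdims=True)


# Two toy (N, 3) clouds standing in for the PLY files of one subfolder.
short = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
long_ = np.array([[2.0, 1.0, 0.0], [5.0, 1.0, 0.0], [3.5, 4.0, 0.0]])

best = select_longest([short, long_])
normalized = center_on_centroid(best)
```

Centering is a pure translation, so the x extent used for selection is the same before and after normalization.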
Usage:
python3 measure/dataset.py --input /path/to/pointclouds --labels /path/to/labels.csv --output /path/to/dataset
Label CSV Format:
subfolder_name,weight
HD1080_SN43186771_16-41-37,0.5
HD1080_SN43186771_16-41-40,0.6
...
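A label file in this format could be read with the standard csv module as in the sketch below. It assumes the first row (subfolder_name,weight) is a header; load_labels is a hypothetical helper and not necessarily how dataset.py parses the file.

```python
import csv
import io


def load_labels(csv_file) -> dict:
    """Map each subfolder name to its weight, reading a headered CSV."""
    reader = csv.DictReader(csv_file)
    return {row["subfolder_name"]: float(row["weight"]) for row in reader}


# In-memory stand-in for labels.csv, using the two rows shown above.
sample = io.StringIO(
    "subfolder_name,weight\n"
    "HD1080_SN43186771_16-41-37,0.5\n"
    "HD1080_SN43186771_16-41-40,0.6\n"
)
labels = load_labels(sample)
```

For a file on disk, pass an open file object instead, for example `open(path, newline="")`.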
Training
The train_weight_regression.py script trains a PointNet++ model for weight regression:
- Model: PointNet++ (SSG - Single Scale Grouping) adapted for regression
- Data Augmentation (to compensate for the small dataset of about 150 samples):
  - Random point sampling (different number of points)
  - Random rotation around z-axis
  - Random scaling (small variations)
  - Random jitter (noise)
  - Random point dropout
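The augmentations listed above might look roughly like the NumPy sketch below. All numeric ranges (dropout ratio, scale range, jitter sigma and clip) are illustrative assumptions rather than the values used in train_weight_regression.py, and the sketch resamples every cloud to a fixed num_point so that batches have a uniform shape.

```python
import numpy as np


def augment(points, num_point=1024, rng=None):
    """Apply dropout, resampling, z-rotation, scaling and jitter to an (N, 3) cloud."""
    rng = np.random.default_rng() if rng is None else rng

    # Random point dropout: discard up to 20% of the points (assumed range).
    keep = rng.random(len(points)) > rng.uniform(0.0, 0.2)
    if keep.any():
        points = points[keep]

    # Random point sampling: draw num_point points, with replacement
    # only when the cloud has fewer points than requested.
    idx = rng.choice(len(points), size=num_point, replace=len(points) < num_point)
    points = points[idx]

    # Random rotation around the z-axis.
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot_z = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    points = points @ rot_z.T

    # Random scaling with small variations (assumed +/- 5%).
    points = points * rng.uniform(0.95, 1.05)

    # Random jitter: clipped Gaussian noise added to every coordinate.
    noise = np.clip(rng.normal(0.0, 0.005, points.shape), -0.02, 0.02)
    return points + noise


cloud = np.random.default_rng(0).normal(size=(500, 3))
augmented = augment(cloud, num_point=1024, rng=np.random.default_rng(1))
```

Passing a seeded generator makes an augmentation reproducible, which helps when debugging a training run.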
Usage:
python3 measure/train_weight_regression.py \
--data_path /path/to/dataset \
--batch_size 16 \
--num_point 1024 \
--epoch 200 \
--learning_rate 0.001
Testing/Inference
The test_weight_regression.py script performs inference on new point clouds:
Usage:
# Test on a single point cloud
python3 measure/test_weight_regression.py \
--model /path/to/checkpoint.pth \
--ply /path/to/pointcloud.ply
# Test on a folder of point clouds
python3 measure/test_weight_regression.py \
--model /path/to/checkpoint.pth \
--folder /path/to/pointclouds \
--output results.json
Comparison-Based Approach (Recommended)
Training
Train a model to predict weight differences between point cloud pairs:
python3 measure/train_weight_comparison.py \
--data_path /path/to/dataset \
--reference_folder /path/to/reference/dataset \
--batch_size 8 \
--num_point 1024 \
--epoch 200 \
--learning_rate 0.001 \
--pair_strategy random # or 'length_based'
Pair Strategies:
- random: Random pairs (more diverse training)
- length_based: Pairs of fish with similar lengths (more realistic comparisons)
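The two strategies could be sketched as below. This is an illustration under assumptions, not the pairing code in train_weight_comparison.py: make_pairs and the (name, length, weight) tuple layout are hypothetical, and the target difference is taken as target weight minus reference weight so that it matches new_weight = reference_weight + predicted_difference.

```python
import random


def make_pairs(samples, strategy="random", seed=0):
    """Return (reference_name, target_name, weight_difference) pairs.

    samples: list of (name, length, weight) tuples (hypothetical layout).
    """
    rng = random.Random(seed)
    pairs = []
    for target in samples:
        others = [s for s in samples if s is not target]
        if strategy == "random":
            # Any other fish may serve as the reference.
            reference = rng.choice(others)
        elif strategy == "length_based":
            # Use the other fish whose length is closest to the target's.
            reference = min(others, key=lambda s: abs(s[1] - target[1]))
        else:
            raise ValueError(f"unknown pair strategy: {strategy}")
        pairs.append((reference[0], target[0], target[2] - reference[2]))
    return pairs


samples = [("a", 30.0, 0.50), ("b", 31.0, 0.55), ("c", 40.0, 0.90)]
length_pairs = make_pairs(samples, strategy="length_based")
random_pairs = make_pairs(samples, strategy="random", seed=42)
```

With length_based pairing, the training pairs resemble those seen at test time, where the reference is also chosen by closest length.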
Testing
For a new fish point cloud, find the closest reference and predict weight difference:
# Test on a single point cloud
python3 measure/test_weight_comparison.py \
--model /path/to/checkpoint.pth \
--reference_folder /path/to/reference/dataset \
--ply /path/to/new_fish.ply
# Test on a folder of point clouds
python3 measure/test_weight_comparison.py \
--model /path/to/checkpoint.pth \
--reference_folder /path/to/reference/dataset \
--folder /path/to/new_fishes \
--output results.json
How it works:
- Loads all reference point clouds (with known weights and lengths)
- For each new fish, finds the closest reference by length
- Predicts weight difference: predicted_diff = model(reference_pc, new_pc)
- Calculates weight: new_weight = reference_weight + predicted_diff
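The inference steps above can be sketched in a few lines. The dictionary layout of a reference entry, the predict_weight helper, and stub_model are all hypothetical; the stub returns a fixed difference in place of the trained network so that the example runs on its own.

```python
def predict_weight(new_length, new_pc, references, model):
    """Estimate a weight from the closest reference by length.

    references: list of dicts with 'length', 'weight' and 'pc' keys
    (a hypothetical layout for the loaded reference dataset).
    """
    # Find the reference fish whose length is closest to the new fish.
    reference = min(references, key=lambda r: abs(r["length"] - new_length))
    # Predict the weight difference between the reference and the new fish.
    predicted_diff = model(reference["pc"], new_pc)
    # Add the predicted difference to the known reference weight.
    return reference["weight"] + predicted_diff


def stub_model(reference_pc, new_pc):
    """Placeholder for the trained comparison network."""
    return 0.04


references = [
    {"length": 30.0, "weight": 0.5, "pc": None},
    {"length": 36.0, "weight": 0.8, "pc": None},
]
new_weight = predict_weight(31.0, None, references, stub_model)
```

Here the new fish (length 31.0) is matched to the reference of length 30.0, so the estimate is 0.5 plus the stubbed difference of 0.04.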
File Structure
measure/
├── README.md # This file
├── dataset.py # Dataset preparation script
├── train_weight_regression.py # Direct regression training
├── test_weight_regression.py # Direct regression inference
├── train_weight_comparison.py # Comparison-based training
├── test_weight_comparison.py # Comparison-based inference
├── pointnet2_regression.py # Direct regression model
├── pointnet2_comparison.py # Comparison model
├── data_loader.py # Direct regression data loader
├── data_loader_comparison.py # Comparison data loader
└── data/ # Data folder (for OCR results, etc.)
Notes
- Direct Regression: Predicts absolute weight from a single point cloud
- Comparison-Based: Predicts weight difference between two point clouds (recommended for small datasets)
- Point clouds are normalized to the origin before training/inference
- Data augmentation is crucial given the small dataset size (~100-150 samples)
- The comparison model uses a shared PointNet++ encoder for both point clouds, then concatenates the two feature vectors
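The shared-encoder idea in the last note can be illustrated with a toy NumPy model. This is deliberately not PointNet++: the encoder is a single per-point linear layer with max pooling, the weights are random, and the class name is hypothetical. It only shows that both clouds pass through the same weights before their features are concatenated.

```python
import numpy as np


class ToyComparisonModel:
    """Toy stand-in for the comparison model (NOT PointNet++)."""

    def __init__(self, feat_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        self.w_enc = rng.normal(size=(3, feat_dim))     # shared encoder weights
        self.w_head = rng.normal(size=(2 * feat_dim,))  # regression head weights

    def encode(self, points):
        # Per-point features, ReLU, then a symmetric max pool over the points.
        return np.maximum(points @ self.w_enc, 0.0).max(axis=0)

    def __call__(self, reference_pc, new_pc):
        # The same encoder (same weights) processes both clouds;
        # the two global feature vectors are then concatenated.
        features = np.concatenate([self.encode(reference_pc), self.encode(new_pc)])
        return float(features @ self.w_head)


rng = np.random.default_rng(1)
reference_pc = rng.normal(size=(1024, 3))
new_pc = rng.normal(size=(1024, 3))

model = ToyComparisonModel()
predicted_diff = model(reference_pc, new_pc)
```

Because the pooling is a max over points, the encoded feature does not depend on the order of the points in either cloud.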