fix fish weight calculation by using correct parameters. output video
This commit is contained in:
172
FishMeasure/measure/README.md
Normal file
172
FishMeasure/measure/README.md
Normal file
@@ -0,0 +1,172 @@
|
||||
# Fish Weight Prediction using PointNet++
|
||||
|
||||
This module uses PointNet++ to predict fish weight from partial point cloud data.
|
||||
|
||||
## Overview
|
||||
|
||||
We provide two approaches for fish weight prediction:
|
||||
|
||||
### Approach 1: Direct Regression (Original)
|
||||
Directly predict the absolute weight from a single point cloud.
|
||||
|
||||
### Approach 2: Comparison-Based (Recommended)
|
||||
Predict the weight **difference** between two point clouds. For a new fish:
|
||||
1. Find the closest reference fish by length (from known dataset)
|
||||
2. Predict weight difference using the trained model
|
||||
3. Calculate: `new_weight = reference_weight + predicted_difference`
|
||||
|
||||
**Why comparison-based?**
|
||||
- Relative comparisons are often more accurate than absolute predictions
|
||||
- Leverages known reference data
|
||||
- More robust to incomplete/partial point clouds
|
||||
- Better performance with small datasets (~100 samples)
|
||||
|
||||
## Workflow
|
||||
|
||||
### Approach 1: Direct Regression
|
||||
|
||||
1. **Dataset Preparation** (`dataset.py`): Prepare training data from multiple point clouds
|
||||
2. **Training** (`train_weight_regression.py`): Train PointNet++ model for weight regression
|
||||
3. **Testing/Inference** (`test_weight_regression.py`): Test the trained model on new point clouds
|
||||
|
||||
### Approach 2: Comparison-Based (Recommended)
|
||||
|
||||
1. **Dataset Preparation** (`dataset.py`): Same as Approach 1
|
||||
2. **Training** (`train_weight_comparison.py`): Train PointNet++ model to predict weight differences
|
||||
3. **Testing/Inference** (`test_weight_comparison.py`): Test using reference dataset
|
||||
|
||||
## Dataset Preparation
|
||||
|
||||
The `dataset.py` script processes point clouds from an input folder:
|
||||
|
||||
- **Input**: Folder containing multiple point cloud subfolders (e.g., `output_preview/xxxx/cloud/`)
|
||||
- **Process**:
|
||||
1. For each subfolder, find all PLY files
|
||||
2. Select the point cloud with the largest length (max x - min x)
|
||||
3. Normalize the point cloud by moving it to the center of origin (centroid = 0)
|
||||
4. Save the normalized PLY file and corresponding weight label to the output folder
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
python3 measure/dataset.py --input /path/to/pointclouds --labels /path/to/labels.csv --output /path/to/dataset
|
||||
```
|
||||
|
||||
**Label CSV Format**:
|
||||
```csv
|
||||
subfolder_name,weight
|
||||
HD1080_SN43186771_16-41-37,0.5
|
||||
HD1080_SN43186771_16-41-40,0.6
|
||||
...
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
The `train_weight_regression.py` script trains a PointNet++ model for weight regression:
|
||||
|
||||
- **Model**: PointNet++ (SSG - Single Scale Grouping) adapted for regression
|
||||
- **Data Augmentation** (for 150 samples):
|
||||
- Random point sampling (different number of points)
|
||||
- Random rotation around z-axis
|
||||
- Random scaling (small variations)
|
||||
- Random jitter (noise)
|
||||
- Random point dropout
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
python3 measure/train_weight_regression.py \
|
||||
--data_path /path/to/dataset \
|
||||
--batch_size 16 \
|
||||
--num_point 1024 \
|
||||
--epoch 200 \
|
||||
--learning_rate 0.001
|
||||
```
|
||||
|
||||
## Testing/Inference
|
||||
|
||||
The `test_weight_regression.py` script performs inference on new point clouds:
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Test on a single point cloud
|
||||
python3 measure/test_weight_regression.py \
|
||||
--model /path/to/checkpoint.pth \
|
||||
--ply /path/to/pointcloud.ply
|
||||
|
||||
# Test on a folder of point clouds
|
||||
python3 measure/test_weight_regression.py \
|
||||
--model /path/to/checkpoint.pth \
|
||||
--folder /path/to/pointclouds \
|
||||
--output results.json
|
||||
```
|
||||
|
||||
## Comparison-Based Approach (Recommended)
|
||||
|
||||
### Training
|
||||
|
||||
Train a model to predict weight differences between point cloud pairs:
|
||||
|
||||
```bash
|
||||
python3 measure/train_weight_comparison.py \
|
||||
--data_path /path/to/dataset \
|
||||
--reference_folder /path/to/reference/dataset \
|
||||
--batch_size 8 \
|
||||
--num_point 1024 \
|
||||
--epoch 200 \
|
||||
--learning_rate 0.001 \
|
||||
--pair_strategy random # or 'length_based'
|
||||
```
|
||||
|
||||
**Pair Strategies:**
|
||||
- `random`: Random pairs (more diverse training)
|
||||
- `length_based`: Pair based on similar lengths (more realistic comparisons)
|
||||
|
||||
### Testing
|
||||
|
||||
For a new fish point cloud, find the closest reference and predict weight difference:
|
||||
|
||||
```bash
|
||||
# Test on a single point cloud
|
||||
python3 measure/test_weight_comparison.py \
|
||||
--model /path/to/checkpoint.pth \
|
||||
--reference_folder /path/to/reference/dataset \
|
||||
--ply /path/to/new_fish.ply
|
||||
|
||||
# Test on a folder of point clouds
|
||||
python3 measure/test_weight_comparison.py \
|
||||
--model /path/to/checkpoint.pth \
|
||||
--reference_folder /path/to/reference/dataset \
|
||||
--folder /path/to/new_fishes \
|
||||
--output results.json
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. Loads all reference point clouds (with known weights and lengths)
|
||||
2. For each new fish, finds the closest reference by length
|
||||
3. Predicts weight difference: `predicted_diff = model(reference_pc, new_pc)`
|
||||
4. Calculates weight: `new_weight = reference_weight + predicted_diff`
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
measure/
|
||||
├── README.md # This file
|
||||
├── dataset.py # Dataset preparation script
|
||||
├── train_weight_regression.py # Direct regression training
|
||||
├── test_weight_regression.py # Direct regression inference
|
||||
├── train_weight_comparison.py # Comparison-based training
|
||||
├── test_weight_comparison.py # Comparison-based inference
|
||||
├── pointnet2_regression.py # Direct regression model
|
||||
├── pointnet2_comparison.py # Comparison model
|
||||
├── data_loader.py # Direct regression data loader
|
||||
├── data_loader_comparison.py # Comparison data loader
|
||||
└── data/ # Data folder (for OCR results, etc.)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- **Direct Regression**: Predicts absolute weight from a single point cloud
|
||||
- **Comparison-Based**: Predicts weight difference between two point clouds (recommended for small datasets)
|
||||
- Point clouds are normalized to the origin before training/inference
|
||||
- Data augmentation is crucial given the small dataset size (~100-150 samples)
|
||||
- The comparison model uses shared PointNet++ encoder for both point clouds, then concatenates features
|
||||
|
||||
Reference in New Issue
Block a user