Files
FishServer/FishMeasure/weight_estimator/dgcnn_weight_model.py

153 lines
5.0 KiB
Python

#!/usr/bin/env python3
"""
DGCNN (Dynamic Graph CNN) model for fish weight regression from 3D point clouds.
Uses EdgeConv layers with k-NN dynamic graph construction. Often works well
with limited data due to simpler architecture and fewer parameters than
Point Transformer, while capturing local geometry better than vanilla PointNet.
Architecture (adapted from Wang et al. "Dynamic Graph CNN for Learning on Point Clouds"):
- 4 EdgeConv layers with k-NN (k=20)
- Concat features from all layers -> 1x1 conv -> emb_dims
- Global max + avg pooling -> MLP -> 1 scalar (kg)
Input: (B, N, 3) point cloud (same as Point Transformer).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
def knn(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Batch k-nearest-neighbour search in feature space.

    Args:
        x: (B, C, N) features; neighbours are searched among the N points of
            each batch item independently.
        k: Requested neighbours per point. Clamped to N so that point clouds
            with fewer than ``k`` points do not make ``topk`` fail.

    Returns:
        idx: (B, N, min(k, N)) indices into the N axis, nearest first. A point's
        own index is normally included (its distance to itself is zero).
    """
    n_points = x.size(-1)
    # topk raises if k > size of the searched dim; use at most N neighbours.
    k = min(k, n_points)
    # -||x_i - x_j||^2 = 2<x_i, x_j> - ||x_i||^2 - ||x_j||^2, so the largest
    # entries of this matrix correspond to the nearest neighbours.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)  # (B, N, N)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)  # (B, 1, N)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)  # (B, N, N)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (B, N, k)
    return idx
def get_graph_feature(x: torch.Tensor, k: int = 20, idx: torch.Tensor | None = None) -> torch.Tensor:
    """
    Build edge features for EdgeConv.

    Args:
        x: (B, C, N) per-point features.
        k: Neighbours per point; only used when ``idx`` is None (passed to ``knn``).
        idx: Optional precomputed (B, N, k') neighbour indices into the N axis.
            The neighbour count is always taken from ``idx.size(-1)``, so a
            precomputed index (or a clamped k-NN result) cannot desync the
            reshape below.

    Returns:
        (B, 2*C, N, k') tensor holding [x_j - x_i, x_i] for every point i and
        each of its neighbours j.
    """
    B, C, N = x.shape
    if idx is None:
        idx = knn(x, k=k)  # (B, N, k)
    k = idx.size(-1)
    # Offset each batch item's indices so they address rows of the flattened
    # (B*N, C) feature matrix.
    idx_base = torch.arange(0, B, device=x.device).view(-1, 1, 1) * N
    flat_idx = (idx + idx_base).reshape(-1)  # reshape: idx may be non-contiguous
    x = x.transpose(2, 1).contiguous()  # (B, N, C)
    neighbors = x.view(B * N, C)[flat_idx, :].view(B, N, k, C)  # x_j
    centers = x.view(B, N, 1, C).expand(-1, -1, k, -1)  # x_i, broadcast without copying
    feature = torch.cat((neighbors - centers, centers), dim=3)
    return feature.permute(0, 3, 1, 2).contiguous()  # (B, 2*C, N, k)
class DGCNNWeightRegressor(nn.Module):
    """DGCNN for scalar weight regression from 3D point clouds.

    Args:
        k: Number of k-NN neighbors (default: 20). Clamped per forward pass to
            the number of input points, so clouds with fewer than ``k`` points
            still run.
        emb_dims: Embedding dimension before head (default: 1024)
        dropout: Dropout rate (default: 0.5)

    Note:
        The head uses BatchNorm1d on (B, features) tensors, which in training
        mode requires a batch size greater than 1.
    """

    def __init__(self, k: int = 20, emb_dims: int = 1024, dropout: float = 0.5):
        super().__init__()
        self.k = k
        self.emb_dims = emb_dims
        # bn1..bn5 are registered both as attributes here and inside the conv
        # Sequentials below. Keep attribute names and creation order unchanged:
        # state_dict keys and seeded parameter initialisation depend on them.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)
        # EdgeConv stages: input channels are 2*C because get_graph_feature
        # stacks [x_j - x_i, x_i].
        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
            self.bn4,
            nn.LeakyReLU(negative_slope=0.2),
        )
        # 512 = 64 + 64 + 128 + 256, the concatenated EdgeConv outputs.
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )
        # Regression head on [global max, global avg] features.
        self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, 3) point cloud batch.
        Returns:
            (B,) predicted weight in kg.
        """
        # DGCNN expects (B, C, N)
        x = x.transpose(1, 2).contiguous()  # (B, 3, N)
        B, _, N = x.shape
        # k-NN via topk fails when k > N; never ask for more neighbours than points.
        k = min(self.k, N)

        # Four EdgeConv stages: each rebuilds the k-NN graph on the previous
        # stage's features, applies the shared 1x1 conv, and max-pools over
        # the k neighbours.
        stage_outputs = []
        h = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = conv(get_graph_feature(h, k=k)).max(dim=-1, keepdim=False)[0]
            stage_outputs.append(h)  # (B, 64 | 64 | 128 | 256, N)

        x = torch.cat(stage_outputs, dim=1)  # (B, 512, N)
        x = self.conv5(x)  # (B, emb_dims, N)
        global_max = F.adaptive_max_pool1d(x, 1).view(B, -1)
        global_avg = F.adaptive_avg_pool1d(x, 1).view(B, -1)
        x = torch.cat((global_max, global_avg), 1)  # (B, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x).squeeze(-1)
        return x