153 lines
5.0 KiB
Python
153 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
DGCNN (Dynamic Graph CNN) model for fish weight regression from 3D point clouds.

Uses EdgeConv layers with k-NN dynamic graph construction. Often works well
with limited data because it has a simpler architecture and fewer parameters
than Point Transformer, while capturing local geometry better than vanilla
PointNet.

Architecture (adapted from Wang et al., "Dynamic Graph CNN for Learning on Point Clouds"):
- 4 EdgeConv layers with k-NN (k=20)
- Concat features from all layers -> 1x1 conv -> emb_dims
- Global max + avg pooling -> MLP -> 1 scalar (kg)

Input: (B, N, 3) point cloud (same as Point Transformer).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def knn(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Batched k-nearest-neighbour lookup in feature space.

    x: (B, C, N) - one C-dimensional feature column per point.
    Returns idx: (B, N, k) - for every point, the indices of its k closest
    points (by squared Euclidean distance) within the same batch element.
    Each point is a candidate for its own neighbour list.
    """
    # Squared distance expands to |x_i|^2 - 2 <x_i, x_j> + |x_j|^2; we build its
    # negation so that the *largest* entries correspond to the *closest* points.
    gram = torch.matmul(x.transpose(2, 1), x)  # (B, N, N) pairwise dot products
    sq_norm = x.pow(2).sum(dim=1, keepdim=True)  # (B, 1, N)
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)  # (B, N, N)

    nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest.indices  # (B, N, k)
|
|
|
|
|
|
def get_graph_feature(x: torch.Tensor, k: int = 20, idx: torch.Tensor | None = None) -> torch.Tensor:
    """
    Build edge features for EdgeConv.

    Args:
        x: (B, C, N) per-point features.
        k: Number of neighbours to look up when ``idx`` is not supplied. It is
            capped at N, because a point cannot have more than N candidates.
        idx: Optional precomputed neighbour indices, (B, N, k') with values in
            [0, N). When given, the neighbour count is taken from its last
            dimension and ``k`` is ignored.

    Returns:
        (B, 2*C, N, k) tensor holding [x_j - x_i, x_i] per edge, where k is the
        effective neighbour count described above.
    """
    B, C, N = x.shape
    if idx is None:
        # topk raises when asked for more entries than exist along the
        # dimension, so clamp for point clouds smaller than k.
        idx = knn(x, k=min(k, N))  # (B, N, k)

    # Offset per-sample indices so they address rows of the flattened
    # (B * N, C) point table built below.
    idx_base = torch.arange(0, B, device=x.device).view(-1, 1, 1) * N
    idx = idx + idx_base
    # Trust the index tensor for the neighbour count rather than the `k`
    # argument: the two may disagree when the caller precomputed `idx`.
    k = idx.shape[-1]
    flat_idx = idx.reshape(-1)  # reshape also accepts non-contiguous idx

    points = x.transpose(2, 1).contiguous()  # (B, N, C)
    neighbours = points.view(B * N, -1)[flat_idx, :].view(B, N, k, C)
    # Broadcast view of each centre point over its k edges; torch.cat
    # materialises it, so no separate repeated copy is needed.
    centres = points.view(B, N, 1, C).expand(-1, -1, k, -1)

    feature = torch.cat((neighbours - centres, centres), dim=3)
    return feature.permute(0, 3, 1, 2).contiguous()  # (B, 2*C, N, k)
|
|
|
|
|
|
class DGCNNWeightRegressor(nn.Module):
    """DGCNN that regresses a single scalar weight from a 3D point cloud.

    Four EdgeConv stages (each rebuilding its k-NN graph on the previous
    stage's features) are concatenated, projected to ``emb_dims`` channels,
    pooled globally (max and average), and fed to a three-layer MLP head.

    Args:
        k: Number of k-NN neighbors (default: 20)
        emb_dims: Embedding dimension before head (default: 1024)
        dropout: Dropout rate (default: 0.5)
    """

    def __init__(self, k: int = 20, emb_dims: int = 1024, dropout: float = 0.5):
        super().__init__()
        self.k = k
        self.emb_dims = emb_dims

        def stage(conv: nn.Module, norm: nn.Module) -> nn.Sequential:
            # conv -> batch norm -> LeakyReLU(0.2), shared by every stage
            return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=0.2))

        # Norm layers are registered under their own names first and then
        # reused inside the stages, so each is reachable both as `bnX` and as
        # `convX.1`.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        # EdgeConv stages: input channels are doubled because each edge
        # feature is the concatenation [x_j - x_i, x_i].
        self.conv1 = stage(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1)
        self.conv2 = stage(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False), self.bn2)
        self.conv3 = stage(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False), self.bn3)
        self.conv4 = stage(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False), self.bn4)
        # 64 + 64 + 128 + 256 = 512 concatenated per-point channels
        self.conv5 = stage(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False), self.bn5)

        # Regression head; its input is the max-pooled and avg-pooled
        # embeddings side by side, hence emb_dims * 2.
        self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, 3) point cloud batch.
        Returns:
            (B,) predicted weight in kg.
        """
        # The graph utilities and convolutions work channels-first: (B, 3, N)
        point_feats = x.transpose(1, 2).contiguous()
        batch_size = point_feats.size(0)

        # Run the four EdgeConv stages, keeping every stage's output:
        # (B, 64, N), (B, 64, N), (B, 128, N), (B, 256, N)
        stage_outputs = []
        for edge_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            edges = get_graph_feature(point_feats, k=self.k)
            # Max over the neighbour axis collapses (B, C, N, k) to (B, C, N)
            point_feats = edge_conv(edges).max(dim=-1, keepdim=False)[0]
            stage_outputs.append(point_feats)

        fused = torch.cat(stage_outputs, dim=1)  # (B, 512, N)
        embedding = self.conv5(fused)  # (B, emb_dims, N)

        pooled_max = F.adaptive_max_pool1d(embedding, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(embedding, 1).view(batch_size, -1)
        hidden = torch.cat((pooled_max, pooled_avg), 1)  # (B, emb_dims*2)

        hidden = self.dp1(F.leaky_relu(self.bn6(self.linear1(hidden)), negative_slope=0.2))
        hidden = self.dp2(F.leaky_relu(self.bn7(self.linear2(hidden)), negative_slope=0.2))
        return self.linear3(hidden).squeeze(-1)
|