#!/usr/bin/env python3
"""
DGCNN (Dynamic Graph CNN) model for fish weight regression from 3D point clouds.

Uses EdgeConv layers with k-NN dynamic graph construction. Often works well with
limited data due to simpler architecture and fewer parameters than Point
Transformer, while capturing local geometry better than vanilla PointNet.

Architecture (adapted from Wang et al. "Dynamic Graph CNN for Learning on Point Clouds"):
- 4 EdgeConv layers with k-NN (k=20)
- Concat features from all layers -> 1x1 conv -> emb_dims
- Global max + avg pooling -> MLP -> 1 scalar (kg)

Input: (B, N, 3) point cloud (same as Point Transformer).
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


def knn(x: torch.Tensor, k: int) -> torch.Tensor:
    """Return the indices of the k nearest neighbours of every point.

    Distances are squared Euclidean distances computed in the C-dimensional
    feature space of ``x``. A point is at distance zero from itself, so it
    normally appears in its own neighbour list.

    Args:
        x: Features of shape (B, C, N).
        k: Requested number of neighbours. Clamped to N so that clouds with
            fewer than ``k`` points do not make ``topk`` fail.

    Returns:
        Integer tensor of shape (B, N, min(k, N)) holding, for each point,
        the indices (in ``[0, N)``) of its nearest neighbours, closest first.
    """
    # topk raises if k exceeds the size of the dimension it selects from.
    k = min(k, x.size(-1))

    # Negative squared distance: -(|x_i|^2 - 2 x_i.x_j + |x_j|^2), so that the
    # *largest* values picked by topk correspond to the *closest* points.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)  # (B, N, N)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)  # (B, 1, N)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)  # (B, N, N)

    return pairwise_distance.topk(k=k, dim=-1)[1]  # (B, N, k)


def get_graph_feature(
    x: torch.Tensor, k: int = 20, idx: torch.Tensor | None = None
) -> torch.Tensor:
    """Build the edge features consumed by an EdgeConv layer.

    For every point ``x_i`` and each of its neighbours ``x_j`` the edge
    feature is the concatenation ``[x_j - x_i, x_i]``.

    Args:
        x: Features of shape (B, C, N).
        k: Number of neighbours to use when ``idx`` is not given. Clamped to
            N by :func:`knn`.
        idx: Optional precomputed neighbour indices of shape (B, N, k').
            When given, ``k`` is ignored and k' is taken from this tensor.

    Returns:
        Tensor of shape (B, 2*C, N, k_eff), where ``k_eff`` is the number of
        neighbours actually used (``min(k, N)`` or ``idx.size(-1)``).
    """
    B, C, N = x.shape
    if idx is None:
        idx = knn(x, k=k)  # (B, N, k)

    # Trust the index tensor for the neighbour count: it may be smaller than
    # the requested k (clamped) or come from the caller with its own k.
    k = idx.size(-1)

    # Offset per-batch indices so they address rows of the flattened
    # (B*N, C) point table.
    idx_base = torch.arange(0, B, device=x.device).view(-1, 1, 1) * N
    flat_idx = (idx + idx_base).reshape(-1)

    points = x.transpose(2, 1).contiguous()  # (B, N, C)
    neighbors = points.view(B * N, C)[flat_idx, :].view(B, N, k, C)

    # expand() broadcasts the centre point over the k neighbours without
    # copying; torch.cat below materialises the result anyway.
    center = points.view(B, N, 1, C).expand(-1, -1, k, -1)

    feature = torch.cat((neighbors - center, center), dim=3)  # (B, N, k, 2*C)
    return feature.permute(0, 3, 1, 2).contiguous()  # (B, 2*C, N, k)


class DGCNNWeightRegressor(nn.Module):
    """DGCNN for scalar weight regression from 3D point clouds.

    Args:
        k: Number of k-NN neighbors (default: 20)
        emb_dims: Embedding dimension before head (default: 1024)
        dropout: Dropout rate (default: 0.5)
    """

    def __init__(self, k: int = 20, emb_dims: int = 1024, dropout: float = 0.5):
        super().__init__()
        self.k = k
        self.emb_dims = emb_dims

        # The BatchNorm layers are registered both as direct attributes and
        # inside the Sequential blocks below, so their entries show up in
        # state_dict under both names. Keep this layout: changing it would
        # break loading of existing checkpoints.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        # EdgeConv stages. Input channels are doubled because each edge
        # feature is the concatenation [x_j - x_i, x_i].
        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
            self.bn4,
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Fuses the concatenated per-point features (64 + 64 + 128 + 256 = 512).
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Regression head; input is the max-pooled and avg-pooled embedding
        # concatenated, hence emb_dims * 2.
        self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one weight per point cloud.

        Clouds with fewer than ``self.k`` points are supported: the
        neighbourhood size is clamped to the number of points.

        Note: in training mode the BatchNorm1d layers of the head need more
        than one sample per batch.

        Args:
            x: (B, N, 3) point cloud batch.

        Returns:
            (B,) predicted weight in kg.
        """
        # DGCNN expects (B, C, N)
        x = x.transpose(1, 2).contiguous()  # (B, 3, N)
        batch_size = x.size(0)

        # Each stage rebuilds the k-NN graph in the feature space of the
        # previous stage's output, then max-reduces over the neighbours.
        layer_outputs = []
        for edge_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            edge_features = get_graph_feature(x, k=self.k)  # (B, 2*C, N, k)
            x = edge_conv(edge_features).max(dim=-1, keepdim=False)[0]  # (B, C', N)
            layer_outputs.append(x)

        x = torch.cat(layer_outputs, dim=1)  # (B, 512, N)
        x = self.conv5(x)  # (B, emb_dims, N)

        # Global descriptor: max and average over all points.
        pooled_max = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((pooled_max, pooled_avg), 1)  # (B, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        return self.linear3(x).squeeze(-1)  # (B,)