#!/usr/bin/env python3
"""
Point Transformer model for fish weight regression from 3D point clouds.

Architecture adapted from the Landmark3D Point Transformer (teeth 3D landmark
prediction), tailored for scalar regression:
  - Encoder: hierarchical Set Abstraction + PointTransformerBlock backbone
  - Feature refinement: Conv1d + BN layers on coarsest-level features
  - Global pooling: max + avg concatenation
  - Regression head: MLP -> 1 scalar (weight in kg)

Reuses building blocks (SetAbstraction, PointTransformerBlock) from
landmark3d_native.py which provides a pure-Python, GPU-trainable
implementation of FPS, ball query, and vector attention — no custom CUDA
kernels required.
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# Make sibling modules (landmark3d_native) importable regardless of the
# current working directory.
_dir = Path(__file__).resolve().parent
if str(_dir) not in sys.path:
    sys.path.insert(0, str(_dir))

from landmark3d_native import SetAbstraction, PointTransformerBlock  # noqa: E402

# Channel width after the input MLP (index 0) and after each of the 4
# Set Abstraction levels (indices 1..4). The last entry is the backbone's
# output width, which the regressor's refinement layers consume.
_FEAT_DIMS: Tuple[int, ...] = (32, 64, 128, 256, 512)

# Default number of points kept after each downsampling level.
_DEFAULT_NPOINTS: Tuple[int, ...] = (384, 128, 32, 8)


class PointTransformerBackbone(nn.Module):
    """Hierarchical Point Transformer encoder.

    4-level Set Abstraction + PointTransformerBlock, producing increasingly
    abstract features at decreasing spatial resolution.

    Args:
        in_dim: Per-point input feature dimension (3 for XYZ, 6 for
            XYZ+normals). The first 3 channels are used as coordinates.
        npoints: Downsampling schedule, one entry per level; must contain
            exactly 4 entries. Default [384, 128, 32, 8].
        radius: Ball query radius (shared across levels).
        nsample: Max neighbors per ball query.

    Raises:
        ValueError: If ``in_dim < 3`` or ``npoints`` does not have exactly
            one entry per level.
    """

    def __init__(
        self,
        in_dim: int = 3,
        npoints: Optional[List[int]] = None,
        radius: float = 0.2,
        nsample: int = 32,
    ):
        super().__init__()
        n_levels = len(_FEAT_DIMS) - 1
        if npoints is None:
            npoints = list(_DEFAULT_NPOINTS)
        if len(npoints) != n_levels:
            # feat_dims is fixed, so the level count is fixed too; extra
            # entries would otherwise be silently ignored.
            raise ValueError(
                f"npoints must have exactly {n_levels} entries, "
                f"got {len(npoints)}: {list(npoints)}"
            )
        if in_dim < 3:
            raise ValueError(
                f"in_dim must be >= 3 (first 3 channels are XYZ), got {in_dim}"
            )

        # Copy so later mutation of the caller's list cannot desync this
        # attribute from the layers built below.
        self.npoints = list(npoints)
        self.radius = radius
        self.nsample = nsample

        feat_dims = _FEAT_DIMS
        self.fc1 = nn.Sequential(
            nn.Linear(in_dim, feat_dims[0]),
            nn.ReLU(),
            nn.Linear(feat_dims[0], feat_dims[0]),
        )
        self.transformer1 = PointTransformerBlock(feat_dims[0], 256)

        self.transition_downs = nn.ModuleList([
            SetAbstraction(
                self.npoints[i], radius, nsample,
                3 + feat_dims[i],
                [feat_dims[i + 1], feat_dims[i + 1]],
            )
            for i in range(n_levels)
        ])
        self.transformers = nn.ModuleList([
            PointTransformerBlock(feat_dims[i + 1], 256)
            for i in range(n_levels)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, N, in_dim) input point features.

        Returns:
            feat: (B, S, 512) coarsest-level features (S = npoints[-1]).
            xyz: (B, S, 3) coarsest-level coordinates.
        """
        xyz = x[:, :, :3]
        feat = self.fc1(x)
        if torch.is_grad_enabled():
            # Checkpoint the full-resolution transformer to save activation
            # memory; this only pays off when a backward pass will follow.
            feat = checkpoint(self.transformer1, xyz, feat, use_reentrant=False)
        else:
            feat = self.transformer1(xyz, feat)

        for td, tr in zip(self.transition_downs, self.transformers):
            xyz_b, feat_b = td(xyz, feat)
            # NOTE(review): assumes SetAbstraction returns channel-first
            # (B, C, S) tensors; permute back to (B, S, C) for the next block.
            xyz = xyz_b.permute(0, 2, 1)
            feat = feat_b.permute(0, 2, 1)
            feat = tr(xyz, feat)
        return feat, xyz


class PointTransformerWeightRegressor(nn.Module):
    """Point Transformer for scalar weight regression from 3D point clouds.

    Architecture:
        PointTransformerBackbone (encoder)
        -> Feature refinement (Conv1d + BN + ReLU)
        -> Global max + avg pooling -> 512-dim
        -> MLP regression head -> 1 scalar

    The head uses ``nn.BatchNorm1d`` on (B, C) tensors, so in training mode
    it needs a batch size greater than 1.

    Args:
        in_dim: Per-point input feature dimension (3 for XYZ only).
        n_points: Number of input points per sample (default 768). Stored
            for reference; not used by ``forward``.
        npoints: Backbone downsampling schedule (4 entries); ``None`` uses
            the backbone default.
        radius: Backbone ball query radius.
        nsample: Backbone max neighbors per ball query.
    """

    def __init__(
        self,
        in_dim: int = 3,
        n_points: int = 768,
        npoints: Optional[List[int]] = None,
        radius: float = 0.2,
        nsample: int = 32,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.n_points = n_points

        self.backbone = PointTransformerBackbone(
            in_dim=in_dim, npoints=npoints, radius=radius, nsample=nsample,
        )
        backbone_dim = _FEAT_DIMS[-1]

        self.fc_refine = nn.Sequential(
            nn.Conv1d(backbone_dim, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )

        # Input is max-pooled (256) concatenated with avg-pooled (256).
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, in_dim) point cloud batch.

        Returns:
            (B,) predicted weight.

        Raises:
            ValueError: If ``x`` is not 3-D or its last dim is not ``in_dim``.
        """
        if x.dim() != 3 or x.size(-1) != self.in_dim:
            raise ValueError(
                f"expected input of shape (B, N, {self.in_dim}), "
                f"got {tuple(x.shape)}"
            )
        feat, _xyz = self.backbone(x)                   # (B, S, 512), (B, S, 3)
        feat = feat.permute(0, 2, 1)                    # (B, 512, S)
        feat = self.fc_refine(feat)                     # (B, 256, S)

        max_feat = feat.max(dim=2).values               # (B, 256)
        avg_feat = feat.mean(dim=2)                     # (B, 256)
        global_feat = torch.cat([max_feat, avg_feat], dim=1)  # (B, 512)

        # squeeze(-1) (not squeeze()) keeps a (1,) result when B == 1.
        return self.head(global_feat).squeeze(-1)       # (B,)