#!/usr/bin/env python3 """ Python-native Landmark3D Point Transformer - GPU trainable. Replicates the landmark3d.pt architecture with proper device handling. """ from __future__ import annotations from pathlib import Path from typing import List, Optional, Tuple import torch from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F # Device-aware point cloud utils (from Pointnet2) def _square_distance(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor: B, N, _ = src.shape _, M, _ = dst.shape dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(dst ** 2, -1).view(B, 1, M) return dist def _index_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: device = points.device B = points.shape[0] view_shape = list(idx.shape) view_shape[1:] = [1] * (len(view_shape) - 1) repeat_shape = list(idx.shape) repeat_shape[0] = 1 batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape) return points[batch_indices, idx, :] def _farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor: device = xyz.device B, N, C = xyz.shape centroids = torch.zeros(B, npoint, dtype=torch.long, device=device) distance = torch.ones(B, N, device=device) * 1e10 farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device) batch_indices = torch.arange(B, dtype=torch.long, device=device) for i in range(npoint): centroids[:, i] = farthest centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) dist = torch.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = torch.max(distance, -1)[1] return centroids def _query_ball_point(radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: device = xyz.device B, N, C = xyz.shape _, S, _ = new_xyz.shape group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1]) sqrdists = _square_distance(new_xyz, xyz) group_idx[sqrdists > radius ** 2] = N group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) mask = group_idx == N group_idx[mask] = group_first[mask] return group_idx def _sample_and_group(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: B, N, C = xyz.shape S = npoint fps_idx = _farthest_point_sample(xyz, npoint) new_xyz = _index_points(xyz, fps_idx) idx = _query_ball_point(radius, nsample, xyz, new_xyz) grouped_xyz = _index_points(xyz, idx) grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) grouped_points = _index_points(points, idx) new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) return new_xyz, new_points class SetAbstraction(nn.Module): """Set Abstraction: FPS + ball query + MLP.""" def __init__(self, npoint: int, radius: float, nsample: int, in_channel: int, mlp: List[int]): super().__init__() self.npoint = npoint self.radius = radius self.nsample = nsample self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel def forward(self, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # xyz (B,N,3), points (B,N,C) new_xyz, new_points = _sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) # new_points (B,npoint,nsample,in_channel) -> (B,in_channel,nsample,npoint) new_points = new_points.permute(0, 3, 2, 1) for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) new_points = torch.max(new_points, 2)[0] # (B,out,N) new_xyz = new_xyz.permute(0, 2, 1) # (B,3,N) return new_xyz, new_points class TransitionUp(nn.Module): """Transition Up: upsample if needed, concat skip + MLP.""" def __init__(self, in_channel: int, out_channel: int): super().__init__() self.fc1 = nn.Sequential(nn.Linear(in_channel, out_channel), nn.ReLU()) self.fc2 = nn.Sequential(nn.Linear(out_channel, out_channel), nn.ReLU()) def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: # x (B,C1,N1), skip (B,C2,N2), upsample x to N2 if needed, concat -> (B,C1+C2,N2) B, C1, N1 = x.shape _, C2, N2 = skip.shape if N1 != N2: x = F.interpolate(x, size=N2, mode="linear", align_corners=False) x = x.permute(0, 2, 1) skip = skip.permute(0, 2, 1) out = torch.cat([x, skip], dim=-1) out = self.fc1(out) out = self.fc2(out) return out.permute(0, 2, 1) class PointTransformerBlock(nn.Module): """Point Transformer block with vector attention. Uses chunked attention for large N to avoid OOM.""" CHUNK_SIZE = 48 # Process queries in chunks to limit (B, chunk, N, D) memory; reduce if OOM def __init__(self, in_dim: int, out_dim: int = 256): super().__init__() self.in_dim = in_dim self.out_dim = out_dim self.fc1 = nn.Linear(in_dim, out_dim) self.fc2 = nn.Linear(out_dim, in_dim) self.fc_delta = nn.Sequential( nn.Linear(3, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim), ) self.fc_gamma = nn.Sequential( nn.Linear(out_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim), ) self.w_qs = nn.Linear(out_dim, out_dim) self.w_ks = nn.Linear(out_dim, out_dim) self.w_vs = nn.Linear(out_dim, out_dim) def forward( self, xyz: torch.Tensor, feat: torch.Tensor, ) -> torch.Tensor: # xyz (B,N,3), feat (B,N,C) B, N, _ = feat.shape q = self.w_qs(self.fc1(feat)) # (B,N,D) k = self.w_ks(self.fc1(feat)) v = self.w_vs(self.fc1(feat)) chunk_size = self.CHUNK_SIZE if N <= chunk_size: # Small N: full attention xyz_i = xyz.unsqueeze(2) xyz_j = xyz.unsqueeze(1) delta_xyz = (xyz_i - xyz_j).reshape(B, N, N, 3) delta = self.fc_delta(delta_xyz) attn_logits = (q.unsqueeze(2) - k.unsqueeze(1) + delta) / (self.out_dim ** 0.5) attn = F.softmax(attn_logits, dim=2) out = (attn * (v.unsqueeze(1) + delta)).sum(dim=2) else: # Large N: chunked attention to avoid OOM out_list = [] for start in range(0, N, chunk_size): end = min(start + chunk_size, N) q_chunk = q[:, start:end] # (B, chunk, D) xyz_chunk = xyz[:, start:end] # (B, chunk, 3) xyz_i = xyz_chunk.unsqueeze(2) # (B, chunk, 1, 3) xyz_j = xyz.unsqueeze(1) # (B, 1, N, 3) delta_xyz = (xyz_i - xyz_j).reshape(B, end - start, N, 3) delta = self.fc_delta(delta_xyz) # (B, chunk, N, D) attn_logits = (q_chunk.unsqueeze(2) - k.unsqueeze(1) + delta) / (self.out_dim ** 0.5) attn = F.softmax(attn_logits, dim=2) out_chunk = (attn * (v.unsqueeze(1) + delta)).sum(dim=2) out_list.append(out_chunk) out = torch.cat(out_list, dim=1) out = self.fc2(out) + feat return out class Backbone(nn.Module): def __init__(self): super().__init__() self.npoints = [384, 128, 32, 8] self.radius = 0.2 self.nsample = 32 self.fc1 = nn.Sequential(nn.Linear(34, 32), nn.ReLU(), nn.Linear(32, 32)) self.transformer1 = PointTransformerBlock(32, 256) self.transformers = nn.ModuleList([ PointTransformerBlock(64, 256), PointTransformerBlock(128, 256), PointTransformerBlock(256, 256), PointTransformerBlock(512, 256), ]) self.transition_downs = nn.ModuleList([ SetAbstraction(self.npoints[0], self.radius, self.nsample, 35, [64, 64]), SetAbstraction(self.npoints[1], self.radius, self.nsample, 67, [128, 128]), SetAbstraction(self.npoints[2], self.radius, self.nsample, 131, [256, 256]), SetAbstraction(self.npoints[3], self.radius, self.nsample, 259, [512, 512]), ]) def forward(self, x: torch.Tensor): xyz = x[:, :, :3] feat = self.fc1(x) feat = checkpoint(self.transformer1, xyz, feat, use_reentrant=False) skips_xyz, skips_feat = [], [] for i, (td, tr) in enumerate(zip(self.transition_downs, self.transformers)): xyz_b, feat_b = td(xyz, feat) xyz_perm = xyz_b.permute(0, 2, 1) feat_perm = feat_b.permute(0, 2, 1) feat_b = tr(xyz_perm, feat_perm) skips_xyz.append(xyz_perm) skips_feat.append(feat_b) xyz, feat = xyz_perm, feat_b return feat, skips_xyz, skips_feat class Landmark3DNative(nn.Module): """Python-native Landmark3D - GPU trainable.""" def __init__(self, n_points: int = 768): super().__init__() self.n_points = n_points self.backbone = Backbone() self.fc2 = nn.Sequential( nn.Conv1d(512, 512, 1), nn.BatchNorm1d(512), nn.ReLU(), nn.Conv1d(512, 512, 1), nn.BatchNorm1d(512), nn.ReLU(), nn.Conv1d(512, 512, 1), ) self.transformer2 = PointTransformerBlock(512, 256) self.transformers = nn.ModuleList([ PointTransformerBlock(256, 256), PointTransformerBlock(128, 256), PointTransformerBlock(64, 256), PointTransformerBlock(32, 256), ]) self.transition_ups = nn.ModuleList([ TransitionUp(512 + 512, 256), TransitionUp(256 + 256, 128), TransitionUp(128 + 128, 64), TransitionUp(64 + 64, 32), ]) self.fc3 = nn.Sequential( nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 44), ) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, _ = x.shape feat, skips_xyz, skips_feat = self.backbone(x) xyz = skips_xyz[-1] feat = feat.permute(0, 2, 1) feat = self.fc2(feat) feat = self.transformer2(xyz, feat.permute(0, 2, 1)) feat = feat.permute(0, 2, 1) for i, (tu, tr) in enumerate(zip(self.transition_ups, self.transformers)): skip_xyz = skips_xyz[3 - i] skip_feat = skips_feat[3 - i] feat = tu(feat, skip_feat.permute(0, 2, 1)) feat = tr(skip_xyz, feat.permute(0, 2, 1)) feat = feat.permute(0, 2, 1) out = self.fc3(feat.permute(0, 2, 1)) weights = F.softmax(out[:, :, :11], dim=1) offset = out[:, :, 11:].reshape(B, -1, 11, 3) input_xyz = skips_xyz[0].unsqueeze(2) landmarks = ((input_xyz + offset) * weights.unsqueeze(-1)).sum(dim=1) return landmarks