295 lines
11 KiB
Python
295 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
Python-native Landmark3D Point Transformer - GPU trainable.

Replicates the landmark3d.pt architecture with proper device handling.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch.utils.checkpoint import checkpoint
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# Device-aware point cloud utils (from Pointnet2)
|
|
def _square_distance(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """Return the table of pairwise squared Euclidean distances.

    Relies on the expansion ``|s - d|^2 = |s|^2 + |d|^2 - 2 * s.d`` so the
    whole table comes out of a single batched matrix product.

    Args:
        src: Tensor of shape (B, N, C).
        dst: Tensor of shape (B, M, C).

    Returns:
        Tensor of shape (B, N, M) where entry ``[b, i, j]`` is the squared
        distance between ``src[b, i]`` and ``dst[b, j]``.
    """
    n_batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    src_sq = (src ** 2).sum(dim=-1).view(n_batch, n_src, 1)
    dst_sq = (dst ** 2).sum(dim=-1).view(n_batch, 1, n_dst)
    cross = torch.matmul(src, dst.transpose(1, 2))
    # Accumulate in the order (-2 s.d + |s|^2) + |d|^2.
    return (-2 * cross + src_sq) + dst_sq
|
|
|
|
|
|
def _index_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Gather rows of ``points`` per batch element.

    Args:
        points: Tensor of shape (B, N, C).
        idx: Long tensor whose first dimension is B (e.g. (B, S) or
            (B, S, K)); values index the N axis of ``points``.

    Returns:
        Tensor of shape ``idx.shape + (C,)`` with
        ``out[b, ...] = points[b, idx[b, ...]]``.
    """
    n_batch = points.shape[0]
    # (B, 1, ..., 1): one batch id per leading row, broadcast over the rest.
    lead_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = torch.arange(n_batch, dtype=torch.long, device=points.device)
    batch_ids = batch_ids.view(lead_shape).expand_as(idx)
    return points[batch_ids, idx, :]
|
|
|
|
|
|
def _farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    """Select ``npoint`` indices per batch by iterative farthest point sampling.

    The first index of every batch element is drawn at random with
    ``torch.randint``; each later index is the point whose squared distance
    to its nearest already-selected point is largest.

    Args:
        xyz: Floating-point tensor of shape (B, N, C) holding coordinates.
        npoint: Number of indices to select per batch element.

    Returns:
        Long tensor of shape (B, npoint) with indices into the N axis.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running squared distance from each point to its closest selected
    # centroid. Kept in xyz's dtype so it can be combined with `dist` for
    # non-float32 inputs, and seeded with +inf so the first pass always wins.
    distance = torch.full((B, N), float("inf"), dtype=xyz.dtype, device=device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate width instead of assuming 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.where(dist < distance, dist, distance)
        farthest = torch.max(distance, -1)[1]
    return centroids
|
|
|
|
|
|
def _query_ball_point(radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
    """Return up to ``nsample`` neighbour indices inside a ball around each query.

    For every query point the lowest-numbered indices of ``xyz`` whose
    squared distance is at most ``radius ** 2`` are kept. Slots that cannot
    be filled (fewer than ``nsample`` points in the ball, or fewer than
    ``nsample`` points in ``xyz`` altogether) repeat the first neighbour, so
    the result always has ``nsample`` columns.

    Args:
        radius: Ball radius.
        nsample: Number of neighbour slots per query point.
        xyz: Tensor of shape (B, N, C) with all points.
        new_xyz: Tensor of shape (B, S, C) with query points.

    Returns:
        Long tensor of shape (B, S, nsample) indexing the N axis of ``xyz``.
        If a query has no point inside its ball, its row holds the
        out-of-range value N.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = _square_distance(new_xyz, xyz)
    # N is one past the last valid index: it marks "outside the ball" and
    # sorts after every real index.
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # With N < nsample the slice is narrower than nsample; pad with the
    # sentinel so the substitution below fills those slots too.
    missing = nsample - group_idx.shape[-1]
    if missing > 0:
        filler = group_idx.new_full((B, S, missing), N)
        group_idx = torch.cat([group_idx, filler], dim=-1)
    group_first = group_idx[:, :, :1].expand(-1, -1, nsample)
    return torch.where(group_idx == N, group_first, group_idx)
|
|
|
|
|
|
def _sample_and_group(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample centroids and collect the local neighbourhood of each one.

    Args:
        npoint: Number of centroids chosen by farthest point sampling.
        radius: Ball-query radius around each centroid.
        nsample: Neighbours gathered per centroid.
        xyz: Coordinates, shape (B, N, C).
        points: Per-point features, shape (B, N, D).

    Returns:
        ``(new_xyz, new_points)`` where ``new_xyz`` is (B, npoint, C) and
        ``new_points`` is (B, npoint, nsample, C + D): centroid-relative
        coordinates concatenated with the neighbours' features.
    """
    n_batch, _, coord_dim = xyz.shape
    centers = _index_points(xyz, _farthest_point_sample(xyz, npoint))
    neighbour_idx = _query_ball_point(radius, nsample, xyz, centers)
    # Express neighbour coordinates relative to their centroid.
    local_xyz = _index_points(xyz, neighbour_idx) - centers.view(n_batch, npoint, 1, coord_dim)
    neighbour_feat = _index_points(points, neighbour_idx)
    return centers, torch.cat([local_xyz, neighbour_feat], dim=-1)
|
|
|
|
|
|
class SetAbstraction(nn.Module):
    """Downsampling stage: FPS centroids, ball-query grouping, shared MLP, max-pool."""

    def __init__(self, npoint: int, radius: float, nsample: int, in_channel: int, mlp: List[int]):
        """Build the stage.

        Args:
            npoint: Number of centroids kept.
            radius: Ball-query radius.
            nsample: Neighbours grouped per centroid.
            in_channel: Channel count of the grouped tensor fed to the first conv.
            mlp: Output widths of the successive 1x1 conv layers.
        """
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))

    def forward(self, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Downsample ``xyz``/``points`` to ``npoint`` centroids.

        Args:
            xyz: Coordinates, shape (B, N, 3).
            points: Features, shape (B, N, C).

        Returns:
            ``(new_xyz, new_points)`` in channel-first layout:
            (B, 3, npoint) and (B, mlp[-1], npoint).
        """
        centers, grouped = _sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # (B, npoint, nsample, C_in) -> (B, C_in, nsample, npoint) for Conv2d.
        feats = grouped.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            feats = F.relu(bn(conv(feats)))
        # Max-pool over the neighbour axis.
        pooled = feats.max(dim=2).values
        return centers.permute(0, 2, 1), pooled
|
|
|
|
|
|
class TransitionUp(nn.Module):
    """Decoder stage: resize the coarse features, concatenate the skip, apply a 2-layer MLP."""

    def __init__(self, in_channel: int, out_channel: int):
        """Build the two Linear+ReLU layers.

        Args:
            in_channel: Channels after concatenating coarse and skip features.
            out_channel: Output channel count.
        """
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(in_channel, out_channel), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(out_channel, out_channel), nn.ReLU())

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Fuse coarse features with a skip connection.

        Args:
            x: Coarse features, shape (B, C1, N1).
            skip: Skip features, shape (B, C2, N2).

        Returns:
            Tensor of shape (B, out_channel, N2).
        """
        _, _, n_coarse = x.shape
        _, _, n_fine = skip.shape
        if n_coarse != n_fine:
            # Linear interpolation along the point axis to match the skip length.
            x = F.interpolate(x, size=n_fine, mode="linear", align_corners=False)
        # Channel-last so the Linear layers act on the concatenated channels.
        fused = torch.cat([x.permute(0, 2, 1), skip.permute(0, 2, 1)], dim=-1)
        return self.fc2(self.fc1(fused)).permute(0, 2, 1)
|
|
|
|
|
|
class PointTransformerBlock(nn.Module):
    """Point Transformer block with vector attention.

    Every point attends to every other point. Queries are processed in
    slices of ``CHUNK_SIZE`` so the (B, chunk, N, D) intermediates stay small.
    """

    CHUNK_SIZE = 48  # Process queries in chunks to limit (B, chunk, N, D) memory; reduce if OOM

    def __init__(self, in_dim: int, out_dim: int = 256):
        """Build the projections.

        Args:
            in_dim: Channel count of the incoming (and returned) features.
            out_dim: Internal attention width D.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, in_dim)
        self.fc_delta = nn.Sequential(
            nn.Linear(3, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        # NOTE(review): fc_gamma is constructed (so its weights appear in the
        # state_dict) but forward() never applies it. Left untouched so
        # outputs and checkpoints stay compatible; confirm whether intended.
        self.fc_gamma = nn.Sequential(
            nn.Linear(out_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        self.w_qs = nn.Linear(out_dim, out_dim)
        self.w_ks = nn.Linear(out_dim, out_dim)
        self.w_vs = nn.Linear(out_dim, out_dim)

    def _attend(
        self,
        q: torch.Tensor,
        xyz_q: torch.Tensor,
        xyz: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Attend a slice of queries (B, Q, D) / (B, Q, 3) to all N points; returns (B, Q, D)."""
        # Positional encoding of pairwise offsets: (B, Q, N, 3) -> (B, Q, N, D).
        delta = self.fc_delta(xyz_q.unsqueeze(2) - xyz.unsqueeze(1))
        attn_logits = (q.unsqueeze(2) - k.unsqueeze(1) + delta) / (self.out_dim ** 0.5)
        # Per-channel softmax over the N key points.
        attn = F.softmax(attn_logits, dim=2)
        return (attn * (v.unsqueeze(1) + delta)).sum(dim=2)

    def forward(
        self,
        xyz: torch.Tensor,
        feat: torch.Tensor,
    ) -> torch.Tensor:
        """Apply vector attention with a residual connection.

        Args:
            xyz: Coordinates, shape (B, N, 3).
            feat: Features, shape (B, N, in_dim).

        Returns:
            Tensor of shape (B, N, in_dim): ``fc2(attention) + feat``.
        """
        _, N, _ = feat.shape
        # Shared input projection, computed once and reused for q, k and v.
        hidden = self.fc1(feat)
        q = self.w_qs(hidden)  # (B,N,D)
        k = self.w_ks(hidden)
        v = self.w_vs(hidden)
        chunk_size = self.CHUNK_SIZE
        if N <= chunk_size:
            # Small N: one pass over all queries.
            out = self._attend(q, xyz, xyz, k, v)
        else:
            # Large N: slice the queries to avoid OOM; keys/values stay whole.
            out = torch.cat(
                [
                    self._attend(
                        q[:, start:start + chunk_size],
                        xyz[:, start:start + chunk_size],
                        xyz,
                        k,
                        v,
                    )
                    for start in range(0, N, chunk_size)
                ],
                dim=1,
            )
        return self.fc2(out) + feat
|
|
|
|
|
|
class Backbone(nn.Module):
    """Encoder: input embedding, one attention block, then four downsample + attention stages."""

    def __init__(self):
        """Create the embedding MLP, the transformer blocks and the set-abstraction stages."""
        super().__init__()
        self.npoints = [384, 128, 32, 8]
        self.radius = 0.2
        self.nsample = 32
        self.fc1 = nn.Sequential(nn.Linear(34, 32), nn.ReLU(), nn.Linear(32, 32))
        self.transformer1 = PointTransformerBlock(32, 256)
        stage_widths = (64, 128, 256, 512)
        self.transformers = nn.ModuleList(
            [PointTransformerBlock(width, 256) for width in stage_widths]
        )
        # Each stage consumes the previous width (half its own) plus 3
        # relative-coordinate channels: 35, 67, 131, 259.
        self.transition_downs = nn.ModuleList(
            [
                SetAbstraction(n_keep, self.radius, self.nsample, width // 2 + 3, [width, width])
                for n_keep, width in zip(self.npoints, stage_widths)
            ]
        )

    def forward(self, x: torch.Tensor):
        """Encode a point cloud.

        Args:
            x: Tensor of shape (B, N, 34); the first three channels are used
                as coordinates.

        Returns:
            ``(feat, skips_xyz, skips_feat)``: the features of the last stage
            and, per stage, the channel-last coordinates and features.
        """
        xyz = x[:, :, :3]
        # Checkpoint the first block: its activations are recomputed in backward.
        feat = checkpoint(self.transformer1, xyz, self.fc1(x), use_reentrant=False)
        skips_xyz, skips_feat = [], []
        for down, block in zip(self.transition_downs, self.transformers):
            xyz_cf, feat_cf = down(xyz, feat)
            # SetAbstraction returns channel-first tensors; go back to channel-last.
            xyz = xyz_cf.permute(0, 2, 1)
            feat = block(xyz, feat_cf.permute(0, 2, 1))
            skips_xyz.append(xyz)
            skips_feat.append(feat)
        return feat, skips_xyz, skips_feat
|
|
|
|
|
|
class Landmark3DNative(nn.Module):
    """Python-native Landmark3D - GPU trainable.

    Encoder/decoder over a point cloud; every point of the finest skip level
    votes for 11 landmark positions, and votes are averaged with softmax
    weights taken over the points.
    """

    def __init__(self, n_points: int = 768):
        """Build backbone, bottleneck, decoder stages and the prediction head.

        Args:
            n_points: Stored on the instance as ``self.n_points``.
        """
        super().__init__()
        self.n_points = n_points
        self.backbone = Backbone()
        self.fc2 = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
        )
        self.transformer2 = PointTransformerBlock(512, 256)
        self.transformers = nn.ModuleList(
            [PointTransformerBlock(width, 256) for width in (256, 128, 64, 32)]
        )
        # Input is decoder + skip channels (2 * c); output halves c.
        self.transition_ups = nn.ModuleList(
            [TransitionUp(2 * c, c // 2) for c in (512, 256, 128, 64)]
        )
        self.fc3 = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 44),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict landmark coordinates.

        Args:
            x: Input tensor of shape (B, N, channels) as expected by the backbone.

        Returns:
            Tensor of shape (B, 11, 3).
        """
        n_batch, _, _ = x.shape
        feat, skips_xyz, skips_feat = self.backbone(x)

        # Bottleneck at the coarsest level (Conv1d wants channel-first).
        coarse_xyz = skips_xyz[-1]
        feat = self.fc2(feat.permute(0, 2, 1))
        feat = self.transformer2(coarse_xyz, feat.permute(0, 2, 1)).permute(0, 2, 1)

        # Decode from the coarsest skip back to the finest one.
        stages = zip(
            self.transition_ups,
            self.transformers,
            reversed(skips_xyz),
            reversed(skips_feat),
        )
        for up, block, level_xyz, level_feat in stages:
            feat = up(feat, level_feat.permute(0, 2, 1))
            feat = block(level_xyz, feat.permute(0, 2, 1)).permute(0, 2, 1)

        out = self.fc3(feat.permute(0, 2, 1))
        # 44 channels = 11 weight logits + 11 * 3 offsets.
        weights = F.softmax(out[:, :, :11], dim=1)
        offsets = out[:, :, 11:].reshape(n_batch, -1, 11, 3)
        votes = skips_xyz[0].unsqueeze(2) + offsets
        return (votes * weights.unsqueeze(-1)).sum(dim=1)
|