Files
FishServer/FishMeasure/weight_estimator/landmark3d_native.py
2026-04-08 19:32:23 +08:00

295 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Python-native Landmark3D Point Transformer - GPU trainable.
Replicates the landmark3d.pt architecture with proper device handling.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
import torch
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
# Device-aware point cloud utils (from Pointnet2)
def _square_distance(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def _index_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
return points[batch_indices, idx, :]
def _farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
distance = torch.ones(B, N, device=device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
batch_indices = torch.arange(B, dtype=torch.long, device=device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
def _query_ball_point(radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
    """Return up to ``nsample`` neighbour indices within ``radius`` of each query centre.

    Args:
        radius: ball radius (compared against squared distances as ``radius ** 2``).
        nsample: number of indices returned per query.
        xyz: (B, N, C) candidate points.
        new_xyz: (B, S, C) query centres.

    Returns:
        (B, S, nsample) long tensor of indices into the N axis of ``xyz``.
        In-ball indices come first in ascending order; remaining slots repeat
        the first in-ball index. If a ball is empty, every slot holds the
        out-of-range sentinel N (in this file's ``_sample_and_group`` the
        centres are themselves points of ``xyz``, so balls are not empty).
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = _square_distance(new_xyz, xyz)
    # Mark out-of-ball candidates with the sentinel N so sorting pushes them last.
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # With fewer candidates than nsample (N < nsample) the slice above is only N
    # wide; pad with the sentinel so the output width is always nsample instead of
    # crashing on a shape mismatch below.
    n_missing = nsample - group_idx.shape[-1]
    if n_missing > 0:
        pad = group_idx.new_full((B, S, n_missing), N)
        group_idx = torch.cat([group_idx, pad], dim=-1)
    # Replace every sentinel with the first (smallest) in-ball index.
    group_first = group_idx[:, :, :1].expand_as(group_idx)
    group_idx = torch.where(group_idx == N, group_first, group_idx)
    return group_idx
def _sample_and_group(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """FPS-sample ``npoint`` centres and gather a ball-query neighbourhood around each.

    Args:
        npoint: number of centres chosen by farthest point sampling.
        radius: ball-query radius.
        nsample: neighbours gathered per centre.
        xyz: (B, N, C) coordinates.
        points: (B, N, D) per-point features.

    Returns:
        new_xyz: (B, npoint, C) sampled centres.
        new_points: (B, npoint, nsample, C + D) neighbour coordinates expressed
            relative to their centre, concatenated with the neighbour features.
    """
    B, _, C = xyz.shape
    centre_idx = _farthest_point_sample(xyz, npoint)
    centres = _index_points(xyz, centre_idx)
    neighbour_idx = _query_ball_point(radius, nsample, xyz, centres)
    # Translate each neighbourhood so its centre sits at the origin.
    local_xyz = _index_points(xyz, neighbour_idx) - centres.view(B, npoint, 1, C)
    neighbour_feat = _index_points(points, neighbour_idx)
    return centres, torch.cat([local_xyz, neighbour_feat], dim=-1)
class SetAbstraction(nn.Module):
    """Set abstraction: FPS + ball query + shared 1x1-conv MLP + max-pool.

    Args:
        npoint: number of centres kept (output point count).
        radius: ball-query radius.
        nsample: neighbours gathered per centre.
        in_channel: width of each grouped neighbour vector, i.e. relative
            coordinate channels plus feature channels.
        mlp: output widths of the successive Conv2d(1x1) + BatchNorm2d + ReLU layers.
    """

    def __init__(self, npoint: int, radius: float, nsample: int, in_channel: int, mlp: List[int]):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))

    def forward(self, xyz: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Downsample to ``npoint`` centres and pool each neighbourhood.

        Args:
            xyz: (B, N, 3) coordinates.
            points: (B, N, D) features.

        Returns:
            new_xyz: (B, 3, npoint) sampled centres, channels-first.
            new_points: (B, mlp[-1], npoint) max-pooled neighbourhood features.
        """
        new_xyz, grouped = _sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # (B, npoint, nsample, C) -> (B, C, nsample, npoint) for the 1x1 Conv2d stack.
        grouped = grouped.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))
        # Max over the neighbour axis -> (B, C_out, npoint).
        pooled = grouped.max(dim=2).values
        return new_xyz.permute(0, 2, 1), pooled
class TransitionUp(nn.Module):
    """Decoder step: resize coarse features to the skip's point count, concat, then MLP.

    Args:
        in_channel: C1 + C2, the width of the concatenated ``[x, skip]`` features.
        out_channel: output width of both Linear + ReLU layers.
    """

    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(in_channel, out_channel), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(out_channel, out_channel), nn.ReLU())

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Fuse coarse features ``x`` (B, C1, N1) with ``skip`` (B, C2, N2).

        Returns:
            (B, out_channel, N2) tensor.
        """
        _, _, n_coarse = x.shape
        _, _, n_fine = skip.shape
        if n_coarse != n_fine:
            # NOTE: linear interpolation runs along the point-index axis, not in
            # 3-D space, so adjacent indices are treated as adjacent points.
            x = F.interpolate(x, size=n_fine, mode="linear", align_corners=False)
        fused = torch.cat([x.permute(0, 2, 1), skip.permute(0, 2, 1)], dim=-1)
        return self.fc2(self.fc1(fused)).permute(0, 2, 1)
class PointTransformerBlock(nn.Module):
    """Residual vector-attention block in which every point attends to all points.

    For query i and key j the logits are ``(q_i - k_j + delta_ij) / sqrt(out_dim)``
    with ``delta_ij = fc_delta(xyz_i - xyz_j)``. The softmax runs over j
    separately for every one of the ``out_dim`` channels, and the attended
    values are ``v_j + delta_ij``. The result is projected back with ``fc2`` and
    added to the input features.

    Queries are processed in chunks of ``CHUNK_SIZE`` so only a
    (B, chunk, N, out_dim) slab is materialised at a time. When autograd is
    recording, every chunk's intermediates are still saved for backward, so
    chunking mainly lowers peak memory for no-grad inference.

    Args:
        in_dim: channel width of the input (and output) features.
        out_dim: internal attention width D.
    """

    CHUNK_SIZE = 48  # queries per chunk; reduce if the attention slab OOMs

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, in_dim)
        self.fc_delta = nn.Sequential(
            nn.Linear(3, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        # NOTE(review): fc_gamma is registered (so its weights are saved/loaded with
        # checkpoints) but forward() never applies it. The reference Point
        # Transformer passes the attention logits through an MLP like this one --
        # confirm against landmark3d.pt before enabling it, since that changes outputs.
        self.fc_gamma = nn.Sequential(
            nn.Linear(out_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        self.w_qs = nn.Linear(out_dim, out_dim)
        self.w_ks = nn.Linear(out_dim, out_dim)
        self.w_vs = nn.Linear(out_dim, out_dim)

    def forward(
        self,
        xyz: torch.Tensor,
        feat: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention block.

        Args:
            xyz: (B, N, 3) coordinates.
            feat: (B, N, in_dim) features.

        Returns:
            (B, N, in_dim) tensor: ``fc2(attention output) + feat``.
        """
        B, N, _ = feat.shape
        # fc1 feeds all three projections; compute it once instead of three times.
        hidden = self.fc1(feat)
        q = self.w_qs(hidden)  # (B, N, D)
        k = self.w_ks(hidden)
        v = self.w_vs(hidden)
        scale = self.out_dim ** 0.5
        chunk_size = max(1, int(self.CHUNK_SIZE))
        key_xyz = xyz.unsqueeze(1)  # (B, 1, N, 3)
        keys = k.unsqueeze(1)  # (B, 1, N, D)
        values = v.unsqueeze(1)  # (B, 1, N, D)
        out_chunks = []
        # A single chunk (N <= chunk_size) is exactly full attention. max(N, 1)
        # keeps one empty iteration when N == 0 so the output is still (B, 0, D).
        for start in range(0, max(N, 1), chunk_size):
            end = min(start + chunk_size, N)
            rel_xyz = xyz[:, start:end].unsqueeze(2) - key_xyz  # (B, chunk, N, 3)
            delta = self.fc_delta(rel_xyz)  # (B, chunk, N, D)
            logits = (q[:, start:end].unsqueeze(2) - keys + delta) / scale
            attn = F.softmax(logits, dim=2)  # normalise over the N keys
            out_chunks.append((attn * (values + delta)).sum(dim=2))
        out = torch.cat(out_chunks, dim=1)  # (B, N, D)
        return self.fc2(out) + feat
class Backbone(nn.Module):
    """Encoder: per-point embedding, one attention block, then four down-sampling stages.

    Each stage is a SetAbstraction down to ``npoints[i]`` centres followed by a
    PointTransformerBlock at that resolution; every stage's coordinates and
    features are kept as skip connections for the decoder.
    """

    def __init__(self):
        super().__init__()
        self.npoints = [384, 128, 32, 8]
        self.radius = 0.2
        self.nsample = 32
        # Input rows carry 34 channels; forward() also reads the first 3 as xyz.
        self.fc1 = nn.Sequential(nn.Linear(34, 32), nn.ReLU(), nn.Linear(32, 32))
        self.transformer1 = PointTransformerBlock(32, 256)
        stage_widths = (64, 128, 256, 512)
        prev_widths = (32, 64, 128, 256)
        self.transformers = nn.ModuleList(
            [PointTransformerBlock(width, 256) for width in stage_widths]
        )
        # SetAbstraction input = 3 relative-xyz channels + previous stage's width.
        self.transition_downs = nn.ModuleList(
            [
                SetAbstraction(npoint, self.radius, self.nsample, 3 + prev, [width, width])
                for npoint, prev, width in zip(self.npoints, prev_widths, stage_widths)
            ]
        )

    def forward(self, x: torch.Tensor):
        """Encode a batch of point clouds.

        Args:
            x: (B, N, 34) per-point inputs; ``x[..., :3]`` are the coordinates.

        Returns:
            feat: features of the coarsest stage, (B, npoints[-1], 512).
            skips_xyz: list of per-stage (B, P_i, 3) centre coordinates, finest first.
            skips_feat: list of per-stage (B, P_i, width_i) features, finest first.
        """
        xyz = x[:, :, :3]
        embedded = self.fc1(x)
        # Gradient checkpointing on the full-resolution block: its activations are
        # recomputed during backward instead of being stored.
        feat = checkpoint(self.transformer1, xyz, embedded, use_reentrant=False)
        skips_xyz, skips_feat = [], []
        for down, attend in zip(self.transition_downs, self.transformers):
            centres, pooled = down(xyz, feat)
            xyz = centres.permute(0, 2, 1)
            feat = attend(xyz, pooled.permute(0, 2, 1))
            skips_xyz.append(xyz)
            skips_feat.append(feat)
        return feat, skips_xyz, skips_feat
class Landmark3DNative(nn.Module):
    """Python-native Landmark3D - GPU trainable.

    Backbone encoder + decoder that climbs back to the finest encoder stage
    (``Backbone.npoints[0]`` points). Each landmark is then predicted as a
    softmax-weighted vote over those points: every point proposes
    ``point + offset`` and the proposals are averaged with weights that sum to 1
    over the points.

    Args:
        n_points: stored as ``self.n_points``; forward() does not read it
            (presumably the expected input point count -- TODO confirm with callers).
        n_landmarks: number of 3-D landmarks predicted. The default 11 reproduces
            the original 44-channel head (11 vote logits + 11 * 3 offsets), so
            existing checkpoints load unchanged.
    """

    def __init__(self, n_points: int = 768, n_landmarks: int = 11):
        super().__init__()
        self.n_points = n_points
        self.n_landmarks = n_landmarks
        self.backbone = Backbone()
        # Bottleneck MLP at the coarsest level (channels-first Conv1d layers).
        self.fc2 = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
        )
        self.transformer2 = PointTransformerBlock(512, 256)
        # Decoder stages, ordered coarsest -> finest to mirror the encoder skips.
        self.transformers = nn.ModuleList([
            PointTransformerBlock(256, 256),
            PointTransformerBlock(128, 256),
            PointTransformerBlock(64, 256),
            PointTransformerBlock(32, 256),
        ])
        self.transition_ups = nn.ModuleList([
            TransitionUp(512 + 512, 256),
            TransitionUp(256 + 256, 128),
            TransitionUp(128 + 128, 64),
            TransitionUp(64 + 64, 32),
        ])
        # Per point: n_landmarks vote logits followed by n_landmarks * 3 xyz offsets.
        self.fc3 = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 4 * n_landmarks),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict landmarks for a batch of point clouds.

        Args:
            x: (B, N, 34) per-point inputs, as consumed by Backbone.

        Returns:
            (B, n_landmarks, 3) landmark coordinates.
        """
        B = x.shape[0]
        n_lm = self.n_landmarks
        feat, skips_xyz, skips_feat = self.backbone(x)
        xyz = skips_xyz[-1]
        # Bottleneck MLP (channels-first), then attention at the coarsest level.
        feat = self.fc2(feat.permute(0, 2, 1))
        feat = self.transformer2(xyz, feat.permute(0, 2, 1)).permute(0, 2, 1)  # (B, C, P)
        # Walk the encoder skips from coarsest to finest.
        for tu, tr, skip_xyz, skip_feat in zip(
            self.transition_ups, self.transformers, reversed(skips_xyz), reversed(skips_feat)
        ):
            feat = tu(feat, skip_feat.permute(0, 2, 1))
            feat = tr(skip_xyz, feat.permute(0, 2, 1)).permute(0, 2, 1)
        out = self.fc3(feat.permute(0, 2, 1))  # (B, P, 4 * n_lm)
        # Softmax over the point axis: each landmark's weights sum to 1 across points.
        weights = F.softmax(out[:, :, :n_lm], dim=1)  # (B, P, n_lm)
        offset = out[:, :, n_lm:].reshape(B, -1, n_lm, 3)  # (B, P, n_lm, 3)
        # Vote anchors are the finest encoder centres (the decoder's output
        # resolution), not the raw input points.
        anchor_xyz = skips_xyz[0].unsqueeze(2)  # (B, P, 1, 3)
        landmarks = ((anchor_xyz + offset) * weights.unsqueeze(-1)).sum(dim=1)
        return landmarks