63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Landmark3D input builder: FDI mapping and build_landmark3d_input.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Lookup: FDI tooth number -> slot in the 28-way one-hot tooth-class vector.
# Covers FDI 11-17, 21-27, 31-37, 41-47; the x8 teeth (18, 28, 38, 48) have no
# slot. Index layout, as encoded by the values below:
#   positions 1-5 of quadrants 1, 2, 3, 4 -> 0-4, 5-9, 10-14, 15-19
#   positions 6-7 of quadrants 1, 2, 3, 4 -> 20-21, 22-23, 24-25, 26-27
# Within each quadrant, entries are listed in arch order (quadrants 1 and 3
# run x7 -> x1, quadrants 2 and 4 run x1 -> x7); dict iteration follows it.
FDI_TO_CLASS_INDEX_28 = {
    # Quadrant 1
    17: 21,
    16: 20,
    15: 4,
    14: 3,
    13: 2,
    12: 1,
    11: 0,
    # Quadrant 2
    21: 5,
    22: 6,
    23: 7,
    24: 8,
    25: 9,
    26: 22,
    27: 23,
    # Quadrant 3
    37: 25,
    36: 24,
    35: 14,
    34: 13,
    33: 12,
    32: 11,
    31: 10,
    # Quadrant 4
    41: 15,
    42: 16,
    43: 17,
    44: 18,
    45: 19,
    46: 26,
    47: 27,
}
# Width of the one-hot tooth-class vector (one slot per entry above).
NUM_ONEHOT_CLASSES = 28
# Landmark count; not referenced in this section. Presumably the number of
# landmarks the model predicts per tooth -- TODO confirm against the model.
N_LANDMARKS = 11
|
||
|
|
|
||
|
|
|
||
|
|
def build_landmark3d_input(
    points: torch.Tensor,
    normals: torch.Tensor,
    tooth_class_28: torch.Tensor,
    n_points: int = 768,
) -> torch.Tensor:
    """
    Build input tensor for landmark3d model.

    Each sample's point cloud is resampled to exactly ``n_points`` points and
    every point is tagged with the sample's tooth-class one-hot vector.

    Resampling is deterministic and shared by points and normals:
      * N >= n_points: strided subsample, index ``int(i * (N / n_points))``
        (computed in float32, clamped to N - 1).
      * N <  n_points: the cloud is tiled cyclically, index ``i % N``.

    Args:
        points: (B, N, 3) centered xyz
        normals: (B, N, 3) unit normals; must match ``points`` in (B, N)
        tooth_class_28: (B,) integer class in [0, 27]. Out-of-range values are
            clamped to the nearest valid class rather than rejected. Any
            integer dtype/device is accepted; it is converted to int64 on
            ``points.device`` because ``one_hot`` requires int64 indices.
        n_points: Target number of points (e.g. 768 or 784); must be positive.

    Returns:
        (B, n_points, 34) tensor laid out as [xyz(3), normal(3), one-hot(28)];
        the one-hot part is cast to ``points.dtype``.

    Raises:
        ValueError: if ``n_points`` is not positive, the cloud is empty
            (N == 0), ``normals`` does not match ``points`` in (B, N), or
            ``tooth_class_28`` is not of shape (B,).
    """
    # Validate up front: these cases previously surfaced as ZeroDivisionError
    # or opaque indexing/cat errors (or, for a longer normals tensor, silently
    # misaligned normals).
    if n_points <= 0:
        raise ValueError(f"n_points must be positive, got {n_points}")
    B, N, _ = points.shape
    if N == 0:
        raise ValueError("points must contain at least one point (N == 0)")
    if normals.shape[:2] != points.shape[:2]:
        raise ValueError(
            f"normals shape {tuple(normals.shape)} does not match points shape "
            f"{tuple(points.shape)} in the (B, N) dimensions"
        )
    if tooth_class_28.shape != (B,):
        raise ValueError(
            f"tooth_class_28 must have shape ({B},), "
            f"got {tuple(tooth_class_28.shape)}"
        )
    device = points.device

    # Subsample or pad to n_points with a single index vector, so each normal
    # stays paired with its point.
    if N >= n_points:
        step = N / n_points
        # Defensive clamp: keeps float rounding from producing an index >= N.
        idx = (torch.arange(n_points, device=device).float() * step).long().clamp(0, N - 1)
    else:
        # Too few points: repeat the cloud cyclically until n_points are filled.
        idx = torch.arange(n_points, device=device) % N
    pts = points[:, idx, :]  # (B, n_points, 3)
    nrm = normals[:, idx, :]  # (B, n_points, 3)

    # One-hot: (B, 28) -> broadcast view (B, n_points, 28).
    classes = tooth_class_28.to(device=device, dtype=torch.long)
    oh = torch.nn.functional.one_hot(
        classes.clamp(0, NUM_ONEHOT_CLASSES - 1),
        NUM_ONEHOT_CLASSES,
    ).to(pts.dtype)
    oh = oh.unsqueeze(1).expand(-1, n_points, -1)

    return torch.cat([pts, nrm, oh], dim=-1)
|