#!/usr/bin/env python3
"""
Landmark3D input builder: FDI mapping and build_landmark3d_input.
"""
from __future__ import annotations

import torch

# FDI (11-17, 21-27, 31-37, 41-47) -> 28-class one-hot index. Excludes 18, 28, 38, 48.
# Each of the 28 keys maps to a distinct index in [0, 27].
FDI_TO_CLASS_INDEX_28 = {
    17: 21, 16: 20, 15: 4, 14: 3, 13: 2, 12: 1, 11: 0,
    21: 5, 22: 6, 23: 7, 24: 8, 25: 9, 26: 22, 27: 23,
    37: 25, 36: 24, 35: 14, 34: 13, 33: 12, 32: 11, 31: 10,
    41: 15, 42: 16, 43: 17, 44: 18, 45: 19, 46: 26, 47: 27,
}

# Width of the one-hot tooth-class encoding appended to every point.
NUM_ONEHOT_CLASSES = 28

# NOTE(review): not referenced in this module; presumably the number of
# landmarks the model predicts — confirm against the model definition.
N_LANDMARKS = 11


def build_landmark3d_input(
    points: torch.Tensor,
    normals: torch.Tensor,
    tooth_class_28: torch.Tensor,
    n_points: int = 768,
) -> torch.Tensor:
    """
    Build the input tensor for the landmark3d model.

    Every sample's point cloud is resampled to exactly ``n_points`` points:
    evenly strided subsampling when ``N >= n_points``, otherwise cyclic
    repetition (indices ``0, 1, ..., N-1, 0, 1, ...``). The same indices
    are used for points and normals. A one-hot encoding of the sample's
    tooth class is then broadcast onto every point.

    Args:
        points: (B, N, 3) centered xyz
        normals: (B, N, 3) unit normals, index-aligned with ``points``
        tooth_class_28: (B,) integer class in [0, 27] (see
            ``FDI_TO_CLASS_INDEX_28``). Any integer dtype and device is
            accepted; values outside [0, 27] are clamped into range.
        n_points: Target number of points (768 or 784)

    Returns:
        (B, n_points, 34) tensor laid out per point as
        [xyz (3), normal (3), one-hot class (28)], on ``points.device``.

    Raises:
        ValueError: if ``n_points < 1`` or the point cloud is empty (N == 0).
    """
    _, N, _ = points.shape
    device = points.device

    if n_points < 1:
        raise ValueError(f"n_points must be >= 1, got {n_points}")
    if N == 0:
        raise ValueError("cannot build landmark3d input from an empty point cloud (N == 0)")

    if N >= n_points:
        # Evenly strided subsample: idx[i] ~= floor(i * N / n_points).
        # The clamp guards against float rounding pushing the last index to N.
        step = N / n_points
        idx = (torch.arange(n_points, device=device).float() * step).long().clamp(0, N - 1)
    else:
        # Pad by cycling through the available points.
        idx = torch.arange(n_points, device=device) % N

    pts = points[:, idx, :]   # (B, n_points, 3)
    nrm = normals[:, idx, :]  # (B, n_points, 3)

    # one_hot only accepts int64 indices, and the result must live on the
    # same device as the points for torch.cat below.
    cls = tooth_class_28.to(device=device, dtype=torch.long).clamp(0, NUM_ONEHOT_CLASSES - 1)
    oh = torch.nn.functional.one_hot(cls, NUM_ONEHOT_CLASSES).float()  # (B, 28)
    oh = oh.unsqueeze(1).expand(-1, n_points, -1)                      # (B, n_points, 28)

    return torch.cat([pts, nrm, oh], dim=-1)