Files
FishServer/FishMeasure/weight_estimator/landmark3d_model.py

63 lines
1.7 KiB
Python

#!/usr/bin/env python3
"""
Landmark3D input builder: FDI mapping and build_landmark3d_input.
"""
from __future__ import annotations
import torch
# Lookup from FDI tooth number to its slot in the 28-way one-hot class vector.
# Only positions 1-7 of each quadrant are mapped (11-17, 21-27, 31-37, 41-47);
# 18, 28, 38 and 48 have no class and are absent from the table.
# Index layout: positions 1-5 of the four quadrants fill indices 0-19 in
# quadrant order, and positions 6-7 fill indices 20-27.
FDI_TO_CLASS_INDEX_28 = {
    # Quadrant 1
    17: 21,
    16: 20,
    15: 4,
    14: 3,
    13: 2,
    12: 1,
    11: 0,
    # Quadrant 2
    21: 5,
    22: 6,
    23: 7,
    24: 8,
    25: 9,
    26: 22,
    27: 23,
    # Quadrant 3
    37: 25,
    36: 24,
    35: 14,
    34: 13,
    33: 12,
    32: 11,
    31: 10,
    # Quadrant 4
    41: 15,
    42: 16,
    43: 17,
    44: 18,
    45: 19,
    46: 26,
    47: 27,
}

# Width of the one-hot class vector; equals the number of entries above.
NUM_ONEHOT_CLASSES = 28

# NOTE(review): not referenced by build_landmark3d_input; presumably the
# landmark count predicted by the model — confirm against the model code.
N_LANDMARKS = 11
def build_landmark3d_input(
    points: torch.Tensor,
    normals: torch.Tensor,
    tooth_class_28: torch.Tensor,
    n_points: int = 768,
) -> torch.Tensor:
    """
    Build input tensor for landmark3d model.

    Every sample's point cloud is resampled to exactly ``n_points`` points
    (evenly strided subsampling when there are enough points, cyclic
    repetition otherwise) and each point is tagged with the sample's one-hot
    tooth class.

    Args:
        points: (B, N, 3) centered xyz
        normals: (B, N, 3) unit normals, row-aligned with ``points``
        tooth_class_28: (B,) integer class in [0, 27]; out-of-range values
            are clamped to the nearest end of that range.
        n_points: Target number of points (768 or 784)

    Returns:
        (B, n_points, 34) tensor laid out per point as
        [xyz (3), normal (3), one-hot class (28)].

    Raises:
        ValueError: if ``n_points`` is not positive, ``points`` contains no
            points, or ``normals`` differs from ``points`` in its first two
            dimensions.
    """
    if n_points <= 0:
        raise ValueError(f"n_points must be positive, got {n_points}")
    _, n_in, _ = points.shape
    if n_in == 0:
        raise ValueError("points must contain at least one point")
    if normals.shape[:2] != points.shape[:2]:
        raise ValueError(
            f"normals shape {tuple(normals.shape)} does not match "
            f"points shape {tuple(points.shape)}"
        )
    device = points.device

    # Subsample or pad to n_points.
    positions = torch.arange(n_points, device=device)
    if n_in >= n_points:
        # Evenly strided subsample: idx[i] = floor(i * N / n_points).
        # Integer math keeps this exact; a float32 product can land just
        # below an integer boundary and pick the previous point instead.
        idx = torch.div(positions * n_in, n_points, rounding_mode="floor")
    else:
        # Too few points: cycle through them (0, 1, ..., N-1, 0, 1, ...).
        idx = positions % n_in
    pts = points[:, idx, :]  # (B, n_points, 3)
    nrm = normals[:, idx, :]  # (B, n_points, 3)

    # One-hot: (B, n_points, 28). one_hot requires an int64 index tensor, and
    # the result must live on the same device as the points for torch.cat.
    cls = tooth_class_28.to(device=device, dtype=torch.long)
    oh = torch.nn.functional.one_hot(
        cls.clamp(0, NUM_ONEHOT_CLASSES - 1),
        NUM_ONEHOT_CLASSES,
    ).float()
    # expand() broadcasts the per-sample vector to every point without copying.
    oh = oh.unsqueeze(1).expand(-1, n_points, -1)
    return torch.cat([pts, nrm, oh], dim=-1)