import torch


class DeToken(torch.nn.Module):
    """Project embeddings back onto the vocabulary as unnormalized logits.

    This is a single learned affine map (``torch.nn.Linear``) from
    ``embedding_size`` to ``vocabulary_size``, applied over the last
    dimension of the input.

    No softmax is applied. The returned scores are raw logits, which is what
    losses such as ``torch.nn.functional.cross_entropy`` expect. Callers that
    need probabilities should apply ``torch.softmax(logits, dim=-1)``
    themselves.

    Args:
        embedding_size: Size of the last dimension of the input tensor.
        vocabulary_size: Number of output scores produced per position.
    """

    def __init__(self, embedding_size: int, vocabulary_size: int) -> None:
        super().__init__()
        # NOTE: the double underscore name-mangles this to `_DeToken__linear`.
        # That mangled name is the submodule key used in `state_dict()`, so
        # renaming it would break loading of existing checkpoints.
        self.__linear = torch.nn.Linear(embedding_size, vocabulary_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map latent vectors to vocabulary logits.

        Args:
            x: Tensor whose last dimension is ``embedding_size``. Any leading
                dimensions are preserved, e.g. (batch, seq, embedding) — the
                exact layout is up to the caller.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``vocabulary_size``, holding unnormalized logits.
        """
        # Latent space -> vocabulary space. The linear output already *is*
        # the logits; softmax is intentionally left to the caller/loss.
        return self.__linear(x)