# Source-viewer metadata captured with this file (not code):
#   2025-10-09 11:36:56 +02:00
#   20 lines
#   444 B
#   Python

import torch
class DeToken(torch.nn.Module):
    """Project latent vectors back onto the vocabulary.

    A single affine layer (``torch.nn.Linear``) maps each
    ``embedding_size``-dimensional vector to ``vocabulary_size`` unnormalized
    scores (logits), one per vocabulary entry. No softmax is applied. Callers
    that need probabilities should apply ``torch.softmax(logits, dim=-1)``
    themselves, and losses such as ``torch.nn.functional.cross_entropy``
    expect these raw logits.
    """

    def __init__(
        self,
        embedding_size: int,
        vocabulary_size: int,
        *,
        bias: bool = True,
    ) -> None:
        """Create the projection layer.

        Args:
            embedding_size: Size of the last dimension of the input tensor.
            vocabulary_size: Number of scores produced per input vector.
            bias: Whether the projection has a learnable additive bias.
                Defaults to ``True``, matching the previous behavior.
        """
        super().__init__()
        # Keep the double-underscore name: Python mangles it to
        # ``_DeToken__linear``, which prefixes this module's state_dict keys,
        # so renaming it would break loading of existing checkpoints.
        self.__linear = torch.nn.Linear(embedding_size, vocabulary_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map latent vectors to vocabulary logits.

        Args:
            x: Tensor whose last dimension is ``embedding_size``. Any leading
                dimensions (e.g. batch and sequence) are preserved.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``vocabulary_size``, holding raw, unnormalized logits.
        """
        # The linear output already *is* the logits. Applying softmax here
        # would turn them into probabilities and break logit-based losses.
        return self.__linear(x)