20 lines
444 B
Python
20 lines
444 B
Python
import torch
|
|
|
|
|
|
class DeToken(torch.nn.Module):
    """Map vectors from the embedding space back onto the vocabulary.

    This is a single affine projection (``torch.nn.Linear``) from
    ``embedding_size`` features to ``vocabulary_size`` features. No
    normalization is applied, so the output is raw, unnormalized scores
    (logits) rather than probabilities.

    Args:
        embedding_size: Size of the last dimension of the input tensor.
        vocabulary_size: Size of the last dimension of the output tensor.
    """

    def __init__(self, embedding_size: int, vocabulary_size: int) -> None:
        super().__init__()
        # The double underscore triggers Python name mangling, so the
        # parameters are registered as ``_DeToken__linear.weight`` and
        # ``_DeToken__linear.bias`` in ``state_dict()``. Keep this exact
        # attribute name so previously saved checkpoints still load.
        projection = torch.nn.Linear(embedding_size, vocabulary_size)
        self.__linear = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from latent space to vocabulary space.

        Args:
            x: Tensor whose last dimension is ``embedding_size``. Leading
                dimensions pass through unchanged. NOTE(review): callers
                presumably use a (batch, seq, embedding_size) layout; confirm.

        Returns:
            Logits with the same leading dimensions as ``x`` and a last
            dimension of ``vocabulary_size``. No softmax is applied here,
            presumably so a downstream loss or sampler can consume raw
            logits; confirm before normalizing.
        """
        logits = self.__linear(x)
        return logits
|