Made model Batch ready

This commit is contained in:
Christian Risi
2025-10-07 16:37:20 +02:00
parent 109ad9f36b
commit fdece42462
4 changed files with 47 additions and 17 deletions

View File

@@ -0,0 +1,19 @@
import torch
class DeToken(torch.nn.Module):
def __init__(self, embedding_size: int, vocabulary_size: int) -> None:
super().__init__()
self.__linear = torch.nn.Linear(embedding_size, vocabulary_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 1) Go from latent space to vocabularu space
x = self.__linear(x)
# 2) Go to logits
x = torch.softmax(x, 2)
return x