Fixed several bugs for task 4

This commit is contained in:
Christian Risi
2025-10-12 16:30:30 +02:00
parent e0f8a36aa5
commit 07130ff489
5 changed files with 186 additions and 9 deletions

View File

@@ -36,6 +36,7 @@ class TrainingModel(torch.nn.Module):
self.__decoder = torch.nn.Sequential(*TMP_DECODERS)
self.__detokener = DeToken(latent_space, vocabulary_size)
self.__encoder_detokener = DeToken(latent_space, vocabulary_size)
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
@@ -57,6 +58,6 @@ class TrainingModel(torch.nn.Module):
def take_pieces(self):
return (
(self.__encoder_embedder, self.__encoder),
(self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
(self.__decoder_embedder, self.__decoder, self.__detokener)
)