Fixed several bugs for task 4
This commit is contained in:
@@ -36,6 +36,7 @@ class TrainingModel(torch.nn.Module):
|
||||
self.__decoder = torch.nn.Sequential(*TMP_DECODERS)
|
||||
|
||||
self.__detokener = DeToken(latent_space, vocabulary_size)
|
||||
self.__encoder_detokener = DeToken(latent_space, vocabulary_size)
|
||||
|
||||
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
|
||||
@@ -57,6 +58,6 @@ class TrainingModel(torch.nn.Module):
|
||||
def take_pieces(self):
|
||||
|
||||
return (
|
||||
(self.__encoder_embedder, self.__encoder),
|
||||
(self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
|
||||
(self.__decoder_embedder, self.__decoder, self.__detokener)
|
||||
)
|
||||
Reference in New Issue
Block a user