diff --git a/Project_Model/Libs/Transformer/Models/TrainingModel.py b/Project_Model/Libs/Transformer/Models/TrainingModel.py index 2a72717..c88ba6c 100644 --- a/Project_Model/Libs/Transformer/Models/TrainingModel.py +++ b/Project_Model/Libs/Transformer/Models/TrainingModel.py @@ -38,7 +38,7 @@ class TrainingModel(torch.nn.Module): self.__detokener = DeToken(latent_space, vocabulary_size) def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - + encoder_embedder_input, padding_tensor, decoder_embedder_input = args encoder_tensor = self.__encoder_embedder(encoder_embedder_input) @@ -47,7 +47,7 @@ class TrainingModel(torch.nn.Module): encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor)) decoder_output, _, _, _ = self.__decoder( - (decoder_tensor, encoder_tensor, encoder_tensor, None) + (decoder_tensor, encoder_output, encoder_output, None) ) logits: torch.Tensor = self.__detokener(decoder_output)