Pipeline fix and added a util to decode
This commit is contained in:
@@ -37,17 +37,17 @@ class TrainingModel(torch.nn.Module):
|
||||
|
||||
self.__detokener = DeToken(latent_space, vocabulary_size)
|
||||
|
||||
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
|
||||
encoder_embedder_input, padding_tensor, decoder_embedder_input = args
|
||||
encoder_embedder_input, src_padding, decoder_embedder_input, tgt_padding = args
|
||||
|
||||
encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
|
||||
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
||||
|
||||
encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))
|
||||
encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
|
||||
|
||||
decoder_output, _, _, _ = self.__decoder(
|
||||
(decoder_tensor, encoder_output, encoder_output, None)
|
||||
decoder_output, _, _, _, _ = self.__decoder(
|
||||
(decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding)
|
||||
)
|
||||
|
||||
logits: torch.Tensor = self.__detokener(decoder_output)
|
||||
|
||||
Reference in New Issue
Block a user