Fixed a bug where I took encoder embeddings rather than encoder output
parent ba592c3480
commit 0158db2dce
@@ -38,7 +38,7 @@ class TrainingModel(torch.nn.Module):
         self.__detokener = DeToken(latent_space, vocabulary_size)

     def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):

         encoder_embedder_input, padding_tensor, decoder_embedder_input = args

         encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
@@ -47,7 +47,7 @@ class TrainingModel(torch.nn.Module):
         encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))

         decoder_output, _, _, _ = self.__decoder(
-            (decoder_tensor, encoder_tensor, encoder_tensor, None)
+            (decoder_tensor, encoder_output, encoder_output, None)
         )

         logits: torch.Tensor = self.__detokener(decoder_output)
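For context, here is a minimal sketch using stock PyTorch nn.Transformer* modules as stand-ins for this repo's encoder/decoder classes (whose exact APIs are not shown here). It illustrates the distinction the commit fixes: the decoder's cross-attention should take the contextualised encoder output as memory, not the raw encoder embeddings.

import torch
import torch.nn as nn

# Hypothetical stand-in modules, not this repo's TrainingModel/DeToken classes.
vocab_size, d_model, seq_len, batch = 100, 32, 10, 2

embed = nn.Embedding(vocab_size, d_model)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True), num_layers=2
)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, nhead=4, batch_first=True), num_layers=2
)
detoken = nn.Linear(d_model, vocab_size)

src = torch.randint(0, vocab_size, (batch, seq_len))
tgt = torch.randint(0, vocab_size, (batch, seq_len))

encoder_tensor = embed(src)               # encoder embeddings (before the encoder stack)
encoder_output = encoder(encoder_tensor)  # encoder output (after the self-attention layers)

decoder_tensor = embed(tgt)
# Before this commit: memory was the embeddings, so the decoder never saw the encoder stack.
# After this commit:  memory is the encoder output.
decoder_output = decoder(decoder_tensor, memory=encoder_output)
logits = detoken(decoder_output)          # (batch, seq_len, vocab_size)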