Fixed several bugs for task 4

This commit is contained in:
Christian Risi
2025-10-12 16:30:30 +02:00
parent e0f8a36aa5
commit 07130ff489
5 changed files with 186 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ class NanoSocraDecoder(torch.nn.Module):
decoder_embedder: Embedder.NanoSocratesEmbedder,
decoder_layers: torch.nn.Sequential,
detokener: DeToken
) -> None:
super().__init__()
@@ -17,14 +18,14 @@ class NanoSocraDecoder(torch.nn.Module):
self.__decoder = decoder_layers
self.__detokener = detokener
def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
def forward(self, args: tuple[torch.Tensor,torch.Tensor, torch.Tensor]):
decoder_embedder_input, tgt_padding = args
decoder_embedder_input, prefix_mask, tgt_padding = args
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
decoder_output, _, _, _, _, _ = self.__decoder(
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
(decoder_tensor, decoder_tensor, decoder_tensor, prefix_mask, tgt_padding, True)
)
logits: torch.Tensor = self.__detokener(decoder_output)