Fixed several bugs for task 4
This commit is contained in:
@@ -9,6 +9,7 @@ class NanoSocraDecoder(torch.nn.Module):
|
||||
decoder_embedder: Embedder.NanoSocratesEmbedder,
|
||||
decoder_layers: torch.nn.Sequential,
|
||||
detokener: DeToken
|
||||
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -17,14 +18,14 @@ class NanoSocraDecoder(torch.nn.Module):
|
||||
self.__decoder = decoder_layers
|
||||
self.__detokener = detokener
|
||||
|
||||
def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
|
||||
def forward(self, args: tuple[torch.Tensor,torch.Tensor, torch.Tensor]):
|
||||
|
||||
decoder_embedder_input, tgt_padding = args
|
||||
decoder_embedder_input, prefix_mask, tgt_padding = args
|
||||
|
||||
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
||||
|
||||
decoder_output, _, _, _, _, _ = self.__decoder(
|
||||
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
|
||||
(decoder_tensor, decoder_tensor, decoder_tensor, prefix_mask, tgt_padding, True)
|
||||
)
|
||||
|
||||
logits: torch.Tensor = self.__detokener(decoder_output)
|
||||
|
||||
Reference in New Issue
Block a user