Update the batcher to fix a bug in the fourth construction

This commit is contained in:
GassiGiuseppe
2025-10-12 16:35:42 +02:00
parent 37a2501a79
commit 144f8724d6
3 changed files with 46 additions and 4 deletions


@@ -20,7 +20,6 @@ class Decoder(nn.Module):
        super().__init__()
        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
@@ -58,7 +57,7 @@ class Decoder(nn.Module):
        # build the attention mask
        # TODO: create a prefix causal mask if needed
        if decoder_only:
-           attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+           attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), src_padding_mask, self.__attention_heads)  # the correct mask would be tgt, however ...
        else:
            attention_mask = get_causal_attention_mask(x.size(1))
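
For reference, a minimal sketch of what the two mask helpers might look like. Their real implementations are not part of this diff, so the mask convention (True = key position blocked from attention), the assumption that the padding mask is True on prefix tokens, and the (batch * heads, seq, seq) output shape expected by nn.MultiheadAttention are all inferred from the call sites above.

import torch


def get_causal_attention_mask(sequence_length: int) -> torch.Tensor:
    # Standard causal mask: True marks key positions a query may NOT attend to.
    return torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool),
        diagonal=1,
    )


def get_prefix_causal_mask_from_padding_mask(
    sequence_length: int,
    padding_mask: torch.Tensor,  # (batch, seq); assumed True on prefix tokens
    number_of_heads: int,
) -> torch.Tensor:
    # Prefix-LM mask: prefix tokens are attended bidirectionally,
    # everything after the prefix stays causal.
    causal = torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool),
        diagonal=1,
    )                                             # (seq, seq)
    prefix_columns = padding_mask.unsqueeze(1)    # (batch, 1, seq)
    # Lift the causal restriction for keys that belong to the prefix.
    mask = causal.unsqueeze(0) & ~prefix_columns  # (batch, seq, seq)
    # nn.MultiheadAttention expects attn_mask of shape (batch * heads, L, L),
    # hence one copy of the mask per head.
    return mask.repeat_interleave(number_of_heads, dim=0)


# Example: batch of 2, length 5, prefixes of 2 and 3 tokens.
prefix = torch.tensor([[1, 1, 0, 0, 0],
                       [1, 1, 1, 0, 0]], dtype=torch.bool)
mask = get_prefix_causal_mask_from_padding_mask(5, prefix, number_of_heads=4)
print(mask.shape)  # torch.Size([8, 5, 5])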