From 0f243eaac24fa54d6f2881ad608d07910d355afb Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Sun, 5 Oct 2025 18:46:06 +0200
Subject: [PATCH] Add padding_mask argument to Decoder and Encoder forward()

---
 Project_Model/Libs/Transformer/Classes/Decoder.py | 6 +++---
 Project_Model/Libs/Transformer/Classes/Encoder.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index 73fe5a0..a1f5074 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -36,11 +36,11 @@ class Decoder(nn.Module):
 
-    def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor: # k_x = v_x . While x_q = x
+    def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor: # k_x = v_x . While x_q = x
 
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
-            x, x, x, attention_mask=attention_mask
+            x, x, x, key_padding_mask=padding_mask
         )
 
         # 2) Dropout
@@ -57,7 +57,7 @@ class Decoder(nn.Module):
         x = self.__layer_norm_1(x)
 
         # 5) Encoder–decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x)
+        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask)
 
         # 6) Dropout
         DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
diff --git a/Project_Model/Libs/Transformer/Classes/Encoder.py b/Project_Model/Libs/Transformer/Classes/Encoder.py
index 8adfc76..cdec92a 100644
--- a/Project_Model/Libs/Transformer/Classes/Encoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Encoder.py
@@ -31,12 +31,12 @@ class Encoder(
         self.__dropout = nn.Dropout(0.1)
         # ... pass
 
-    def forward(self, x):
+    def forward(self, x, padding_mask=None):
         # -> ATTENTION -> dropout -> add and normalize -> FF -> dropout -> add and normalize ->
         # Attention with Residual Connection [ input + self-attention]
 
         # 1) Multi Head Attention
-        ATTENTION = self.__attention(x, x, x)
+        ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)
 
         # 2) Dropout
         DROPPED_ATTENTION = self.__dropout(ATTENTION)
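
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new padding_mask argument could be driven from padded token batches. It assumes the attention modules follow the torch.nn.MultiheadAttention convention for key_padding_mask, i.e. a boolean (batch, seq_len) tensor with True marking padded positions, and that Encoder/Decoder instances are constructed elsewhere in the project. PAD_ID, src_ids, tgt_ids, and src_emb are hypothetical placeholders, not names from this repo.

import torch

PAD_ID = 0  # hypothetical padding token id, not defined in this repo

# Two padded batches of token ids: (batch, src_len) and (batch, tgt_len).
src_ids = torch.tensor([[5, 7, 9, PAD_ID],
                        [3, 4, PAD_ID, PAD_ID]])
tgt_ids = torch.tensor([[1, 6, PAD_ID],
                        [2, 8, 5]])

# Boolean masks, True where the position is padding.
src_padding_mask = src_ids.eq(PAD_ID)   # shape (2, 4)
tgt_padding_mask = tgt_ids.eq(PAD_ID)   # shape (2, 3)

# With the encoder built as in this repo (construction and embedding not shown):
#   memory = encoder(src_emb, padding_mask=src_padding_mask)
# The decoder's forward(x, k_x, v_x, padding_mask=None) applies its single
# padding_mask to the keys of both the masked self-attention and the
# cross-attention, so under the MultiheadAttention convention the mask passed
# in has to match the key length of both of those calls.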