diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index 73fe5a0..a1f5074 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -36,11 +36,11 @@ class Decoder(nn.Module):
 
-    def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:  # k_x = v_x . While x_q = x
+    def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor:  # k_x = v_x . While x_q = x
 
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
-            x, x, x, attention_mask=attention_mask
+            x, x, x, key_padding_mask=padding_mask
         )
 
         # 2) Dropout
@@ -57,7 +57,7 @@ class Decoder(nn.Module):
         x = self.__layer_norm_1(x)
 
         # 5) Encoder–decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x)
+        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask)
 
         # 6) Dropout
         DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
diff --git a/Project_Model/Libs/Transformer/Classes/Encoder.py b/Project_Model/Libs/Transformer/Classes/Encoder.py
index 8adfc76..cdec92a 100644
--- a/Project_Model/Libs/Transformer/Classes/Encoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Encoder.py
@@ -31,12 +31,12 @@ class Encoder(
         self.__dropout = nn.Dropout(0.1)
         # ... pass
 
-    def forward(self, x):
+    def forward(self, x, padding_mask=None):
         # -> ATTENTION -> dropout -> add and normalize -> FF -> dropout -> add and normalize ->
         # Attention with Residual Connection [ input + self-attention]
 
         # 1) Multi Head Attention
-        ATTENTION = self.__attention(x, x, x)
+        ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)
 
         # 2) Dropout
         DROPPED_ATTENTION = self.__dropout(ATTENTION)
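
Usage sketch (not part of the patch): the snippet below shows how a caller might build the boolean `padding_mask` that the new `forward` signatures accept and pass it through. It assumes the project's attention wrappers forward `key_padding_mask` to `torch.nn.MultiheadAttention`, where `True` marks padded positions to be ignored; `PAD_ID`, the tensor shapes, and the commented-out `Encoder`/`Decoder` calls are illustrative assumptions, not code from this repository.

```python
import torch

PAD_ID = 0                                  # assumed padding token id (illustrative)
tokens = torch.tensor([[5, 7, 9, PAD_ID],   # batch of 2 sequences, max length 4
                       [3, 8, PAD_ID, PAD_ID]])

# Boolean key padding mask, shape (batch, seq_len): True where the token is padding.
padding_mask = tokens.eq(PAD_ID)

x = torch.randn(2, 4, 512)                  # stand-in for the embedded input (batch-first)

# With the patched signatures (hypothetical calls, construction details omitted):
# encoder = Encoder(...)
# decoder = Decoder(...)
# memory = encoder(x, padding_mask=padding_mask)
# out = decoder(tgt, memory, memory, padding_mask=padding_mask)

# Same masking semantics demonstrated with PyTorch's built-in attention:
mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
attn_out, _ = mha(x, x, x, key_padding_mask=padding_mask)
print(attn_out.shape)                       # torch.Size([2, 4, 512])
```

Note that the patched `Decoder.forward` reuses the same `padding_mask` for both the masked self-attention and the cross-attention; if source and target sequence lengths differ, separate masks for the two calls would likely be needed.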