diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index b441ebd..97cd148 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_causal_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask
 
 
 # B, L(T), E_D
@@ -16,8 +16,9 @@ class Decoder(nn.Module):
         feed_forward_hidden_layer_dimension: int,
         number_of_attention_heads: int,
     ) -> None:
         super().__init__()
+        self.__attention_heads = number_of_attention_heads
 
         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
         )
@@ -54,7 +55,11 @@ class Decoder(nn.Module):
 
         # build of attention mask
-        # TODO: create a prefix causal mask if needed
-        attention_mask = get_causal_attention_mask(x.size(1))
+        if decoder_only:
+            attention_mask = get_prefix_causal_mask_from_padding_mask(
+                x.size(1), tgt_padding_mask, self.__attention_heads
+            )
+        else:
+            attention_mask = get_causal_attention_mask(x.size(1))
 
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
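
Note: the new get_prefix_causal_mask_from_padding_mask helper is imported from ..Utils.attention_mask, but its implementation is not part of this diff. For context, below is a minimal sketch of what such a helper could look like. It assumes that tgt_padding_mask is a (B, L) boolean tensor marking the prefix tokens that should attend to each other bidirectionally, and that the helper returns the 3-D (batch * num_heads, L, L) boolean attn_mask layout accepted by torch.nn.MultiheadAttention (True = blocked). Parameter names here are illustrative; the repository's actual implementation may differ.

import torch


def get_prefix_causal_mask_from_padding_mask(
    sequence_length: int,
    prefix_mask: torch.Tensor,  # (B, L) bool; True where the token belongs to the prefix (assumption)
    number_of_attention_heads: int,
) -> torch.Tensor:
    # Standard causal mask: True above the diagonal means "blocked".
    causal = torch.triu(
        torch.ones(
            sequence_length, sequence_length,
            dtype=torch.bool, device=prefix_mask.device,
        ),
        diagonal=1,
    )
    # Keys inside the prefix stay visible to every query (prefix-LM behaviour).
    prefix_keys = prefix_mask.unsqueeze(1)        # (B, 1, L)
    mask = causal.unsqueeze(0) & ~prefix_keys     # (B, L, L)
    # Expand per head so the layout matches nn.MultiheadAttention's
    # (batch * num_heads, L, L) convention, where index = b * num_heads + h.
    return mask.repeat_interleave(number_of_attention_heads, dim=0)

Under this convention the call in forward, get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads), yields one mask slice per attention head, which is why the diff caches number_of_attention_heads on the module.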