From 948c3fd7ac6361d586ce44fb9ec920ba1c1c6741 Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Mon, 6 Oct 2025 13:03:03 +0200
Subject: [PATCH] update to batch attention mask

---
 Project_Model/Libs/Transformer/Classes/Decoder.py   | 12 ++++++++----
 .../Transformer/Classes/TorchMultiHeadAttention.py  | 10 +++++-----
 .../Libs/Transformer/Utils/attention_mask.py        | 10 ++++++----
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index a1f5074..11e9aa7 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask
 
+# Tensor shape convention: B (batch), L or T (sequence length), E_D (embedding dimension)
 
 
 class Decoder(nn.Module):
@@ -17,7 +18,7 @@ class Decoder(nn.Module):
         super().__init__()
 
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
+            embedding_dimension, number_of_attention_heads, dropout=0.1
         )
 
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -38,9 +39,12 @@ class Decoder(nn.Module):
     def forward(self, x, k_x, v_x, padding_mask = None) -> torch.Tensor:
         # k_x = v_x . While x_q = x
 
+        # Build the causal attention mask for the current sequence length
+        attention_mask = get_causal_attention_mask(x.size(1))
+
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
-            x, x, x, key_padding_mask=padding_mask
+            x, x, x, key_padding_mask=padding_mask, attention_mask=attention_mask
         )
 
         # 2) Dropout
@@ -57,7 +61,7 @@ class Decoder(nn.Module):
         x = self.__layer_norm_1(x)
 
         # 5) Encoder–decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x key_padding_mask=padding_mask)
+        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask)
 
         # 6) Dropout
         DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
diff --git a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
index d310ac8..52c0cc5 100644
--- a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
+++ b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
@@ -9,7 +9,6 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
-        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -19,7 +18,6 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
 
-        self.__attention_mask = attention_mask
 
     def forward(
         self,
@@ -27,14 +25,16 @@
         x_k: torch.Tensor,
         x_v: torch.Tensor,
         key_padding_mask=None,
+        attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
 
         # x * Wq -> Q
         # x * Wk -> K
         # x * Wv -> V
-
-        y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
+        # REMEMBER: torch's nn.MultiheadAttention broadcasts a 2-D attn_mask across the batch internally, so no 3-D batched mask is needed here
+        y, _ = self.attention(
+            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask,
+            need_weights=False
         )
 
         return y
diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
index cb0ddcf..b1e97f3 100644
--- a/Project_Model/Libs/Transformer/Utils/attention_mask.py
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -1,9 +1,11 @@
 import torch
 
-def get_causal_attention_mask(embedding_dimension: int) -> torch.Tensor:
-    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
+def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
+    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
 
-def get_causal_attention_mask_batched(embedding_dimension: int, batch_size: int ) -> torch.Tensor:
-    base_mask = get_causal_attention_mask(embedding_dimension)
+
+# No longer needed: torch's nn.MultiheadAttention broadcasts a 2-D mask across the batch internally
+def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
+    base_mask = get_causal_attention_mask(seq_len)
     return base_mask.unsqueeze(0).expand(batch_size, -1, -1) # add another dimension at the beginning, big as batch_size
     # the result is that z,x,y where x,y are repeated along z
\ No newline at end of file
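
Reviewer note (not part of the patch): a minimal sanity-check sketch of the new mask flow. It assumes only PyTorch; the helper is copied from the patched Utils/attention_mask.py, and the call mirrors the patched wrapper, passing a 2-D boolean causal mask as attn_mask to a batch_first nn.MultiheadAttention, which broadcasts the mask across the batch.

    import torch
    import torch.nn as nn

    def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
        # True above the diagonal marks positions that must not be attended to
        return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    batch_size, seq_len, embedding_dimension, heads = 2, 5, 16, 4
    attention = nn.MultiheadAttention(embedding_dimension, heads, batch_first=True)

    x = torch.randn(batch_size, seq_len, embedding_dimension)  # (B, L, E_D)
    attn_mask = get_causal_attention_mask(seq_len)              # (L, L), broadcast over the batch

    y, _ = attention(x, x, x, attn_mask=attn_mask, need_weights=False)
    print(y.shape)  # torch.Size([2, 5, 16])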