From 6f219f634f9268cab9a2a41598eb7ee0a00d4636 Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Sun, 5 Oct 2025 17:49:01 +0200
Subject: [PATCH] Added attention_mask

---
 Project_Model/Libs/Transformer/Classes/Decoder.py        | 6 +++++-
 .../Libs/Transformer/Classes/TorchMultiHeadAttention.py  | 8 +++++---
 Project_Model/Libs/Transformer/Utils/__init__.py         | 3 +++
 Project_Model/Libs/Transformer/Utils/attention_mask.py   | 4 ++++
 Project_Model/Libs/Transformer/__init__.py               | 4 +++-
 5 files changed, 20 insertions(+), 5 deletions(-)
 create mode 100644 Project_Model/Libs/Transformer/Utils/__init__.py
 create mode 100644 Project_Model/Libs/Transformer/Utils/attention_mask.py

diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index d21a9ea..73fe5a0 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
+from ..Utils.attention_mask import get_attention_mask
+
 
 
 class Decoder(nn.Module):
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
         super().__init__()
 
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1
+            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
         )
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
 
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
         )
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
 
+
+
     def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:
         # k_x = v_x . While x_q = x
         # 1) Masked Attention
diff --git a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
index 6081f75..d310ac8 100644
--- a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
+++ b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-
+from typing import Optional
 
 
 class TorchMultiHeadAttention(nn.Module):
@@ -9,6 +9,7 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -18,12 +19,13 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
 
+        self.__attention_mask = attention_mask
+
     def forward(
         self,
         x_q: torch.Tensor,
         x_k: torch.Tensor,
         x_v: torch.Tensor,
-        attention_mask=None,
         key_padding_mask=None,
     ) -> torch.Tensor:
 
@@ -32,7 +34,7 @@ class TorchMultiHeadAttention(nn.Module):
         # x * Wv -> V
 
         y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
+            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
         )
 
         return y
diff --git a/Project_Model/Libs/Transformer/Utils/__init__.py b/Project_Model/Libs/Transformer/Utils/__init__.py
new file mode 100644
index 0000000..856b51f
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Utils/__init__.py
@@ -0,0 +1,3 @@
+from .attention_mask import get_attention_mask
+
+__all__ = ["get_attention_mask"]
\ No newline at end of file
diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
new file mode 100644
index 0000000..a6c595e
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -0,0 +1,4 @@
+import torch
+
+def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
\ No newline at end of file
diff --git a/Project_Model/Libs/Transformer/__init__.py b/Project_Model/Libs/Transformer/__init__.py
index e384727..d906699 100644
--- a/Project_Model/Libs/Transformer/__init__.py
+++ b/Project_Model/Libs/Transformer/__init__.py
@@ -1,3 +1,5 @@
 from .Classes import *
+from .Utils import *
 
-from . import Classes
\ No newline at end of file
+from . import Classes
+from . import Utils
\ No newline at end of file
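
For reference, nn.MultiheadAttention applies a 2-D boolean attn_mask to attention scores of shape (target sequence length, source sequence length), where True marks positions a query may not attend to. Below is a minimal, self-contained sketch of how such a causal mask is consumed, assuming the mask is sized by the token sequence length; the sequence_length, batch, and heads values are illustrative and not part of this patch.

    import torch
    import torch.nn as nn

    # Illustrative sizes only.
    sequence_length, batch, embedding_dimension, heads = 5, 2, 16, 4

    # Upper-triangular boolean mask: True blocks attention to future tokens
    # (causal masking), mirroring get_attention_mask but keyed on sequence length.
    causal_mask = torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1
    )

    attention = nn.MultiheadAttention(embedding_dimension, heads, batch_first=True)
    x = torch.randn(batch, sequence_length, embedding_dimension)

    # With batch_first=True the inputs are (batch, sequence, embedding);
    # attn_mask stays (sequence, sequence) regardless of batch size.
    y, _ = attention(x, x, x, attn_mask=causal_mask)
    print(y.shape)  # torch.Size([2, 5, 16])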