Added attention_mask

GassiGiuseppe 2025-10-05 17:49:01 +02:00
parent b303affd18
commit 6f219f634f
5 changed files with 20 additions and 5 deletions

View File

@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
+from ..Utils.attention_mask import get_attention_mask
 class Decoder(nn.Module):
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
         super().__init__()
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1
+            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
         )
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
         )
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
     def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:  # k_x = v_x, while x_q = x
         # 1) Masked Attention
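Taken in isolation, the constructor change wires a precomputed causal mask into the masked self-attention block. A minimal sketch of the new wiring, using the aliases imported above; the size 512 and the head count are illustrative, not taken from the repository:

# Sketch only: class names come from the diff, sizes are assumptions.
masked_attention = MultiHeadAttention(            # alias for TorchMultiHeadAttention
    embedding_dimension=512,
    number_of_attention_heads=8,
    dropout=0.1,
    attention_mask=get_attention_mask(512),       # built once, reused on every forward pass
)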

View File

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
+from typing import Optional
 class TorchMultiHeadAttention(nn.Module):
@@ -9,6 +9,7 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -18,12 +19,13 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
+        self.__attention_mask = attention_mask
     def forward(
         self,
         x_q: torch.Tensor,
         x_k: torch.Tensor,
         x_v: torch.Tensor,
-        attention_mask=None,
         key_padding_mask=None,
     ) -> torch.Tensor:
@@ -32,7 +34,7 @@ class TorchMultiHeadAttention(nn.Module):
         # x * Wv -> V
         y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
+            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
         )
         return y
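For reference, a hedged usage sketch of the updated wrapper; the shapes and sizes are assumptions, not values from the repository. Note that nn.MultiheadAttention expects attn_mask to match the sequence lengths, shape (target_len, source_len), so the mask in this sketch is sized by sequence length:

import torch
# Assumes TorchMultiHeadAttention is importable as defined in the file above.

seq_len, batch_size, embedding_dimension, heads = 10, 2, 64, 8
# Boolean causal mask: True marks positions a query may NOT attend to.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

attention = TorchMultiHeadAttention(
    embedding_dimension, heads, dropout=0.0, attention_mask=causal_mask
)
x = torch.rand(batch_size, seq_len, embedding_dimension)   # batch_first=True layout
y = attention(x, x, x)                                      # self-attention with the stored mask
assert y.shape == (batch_size, seq_len, embedding_dimension)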

View File

@@ -0,0 +1,3 @@
+from .attention_mask import get_attention_mask
+
+__all__ = ["get_attention_mask"]

View File

@@ -0,0 +1,4 @@
+import torch
+
+def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
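For a small size, the helper produces a strictly upper-triangular boolean matrix; under PyTorch's boolean-mask convention, True entries are the (future) positions that attention is not allowed to look at:

import torch

mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # same expression as above, n = 4
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])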

View File

@@ -1,3 +1,5 @@
 from .Classes import *
+from .Utils import *
 from . import Classes
+from . import Utils