Added attention_mask

This commit is contained in:
GassiGiuseppe 2025-10-05 17:49:01 +02:00
parent b303affd18
commit 6f219f634f
5 changed files with 20 additions and 5 deletions

View File

@ -2,6 +2,8 @@ import torch
import torch.nn as nn
from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_attention_mask
class Decoder(nn.Module):
@ -15,7 +17,7 @@ class Decoder(nn.Module):
super().__init__()
self.__masked_attention = MultiHeadAttention(
embedding_dimension, number_of_attention_heads, dropout=0.1
embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
)
self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@ -32,6 +34,8 @@ class Decoder(nn.Module):
)
self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor: # k_x = v_x . While x_q = x
# 1) Masked Attention

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from typing import Optional
class TorchMultiHeadAttention(nn.Module):
@ -9,6 +9,7 @@ class TorchMultiHeadAttention(nn.Module):
embedding_dimension: int,
number_of_attention_heads: int,
dropout: float = 0.0,
attention_mask: Optional[torch.Tensor] = None
):
super().__init__()
self.attention = nn.MultiheadAttention(
@ -18,12 +19,13 @@ class TorchMultiHeadAttention(nn.Module):
batch_first=True,
)
self.__attention_mask = attention_mask
def forward(
self,
x_q: torch.Tensor,
x_k: torch.Tensor,
x_v: torch.Tensor,
attention_mask=None,
key_padding_mask=None,
) -> torch.Tensor:
@ -32,7 +34,7 @@ class TorchMultiHeadAttention(nn.Module):
# x * Wv -> V
y, _ = self.attention.forward(
x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
)
return y

View File

@ -0,0 +1,3 @@
from .attention_mask import get_attention_mask
__all__ = ["get_attention_mask"]

View File

@ -0,0 +1,4 @@
import torch
def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)

View File

@ -1,3 +1,5 @@
from .Classes import *
from .Utils import *
from . import Classes
from . import Classes
from . import Utils