Added attention_mask

GassiGiuseppe
2025-10-05 17:49:01 +02:00
parent b303affd18
commit 6f219f634f
5 changed files with 20 additions and 5 deletions

Decoder.py (file name inferred from the class; the diff shows 2 of the 5 changed files)

@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
+from ..Utils.attention_mask import get_attention_mask
+
 class Decoder(nn.Module):
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
         super().__init__()
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1
+            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
         )
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
         )
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)

     def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:  # k_x == v_x; the query is x (x_q = x)
         # 1) Masked Attention

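The `get_attention_mask` helper imported from `..Utils.attention_mask` is among the five changed files but is not shown in this excerpt. Below is a minimal sketch of what the call site implies, assuming the helper builds the boolean causal mask that `nn.MultiheadAttention` expects (`True` entries mark positions attention may not look at). Note that `attn_mask` must be shaped `(L, S)` over sequence positions, so sizing it by `embedding_dimension`, as the `Decoder` constructor above does, only lines up when the sequence length equals the embedding dimension:

```python
import torch

def get_attention_mask(size: int) -> torch.Tensor:
    # Hypothetical reconstruction: a causal (look-ahead) mask.
    # True strictly above the diagonal means "may not attend", matching
    # the boolean attn_mask convention of nn.MultiheadAttention.
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)
```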
TorchMultiHeadAttention.py (file name per the import in Decoder.py)

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
+from typing import Optional

 class TorchMultiHeadAttention(nn.Module):
@@ -9,6 +9,7 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -18,12 +19,13 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
+        self.__attention_mask = attention_mask

     def forward(
         self,
         x_q: torch.Tensor,
         x_k: torch.Tensor,
         x_v: torch.Tensor,
         attention_mask=None,
         key_padding_mask=None,
     ) -> torch.Tensor:
@@ -32,7 +34,7 @@ class TorchMultiHeadAttention(nn.Module):
         # x * Wv -> V
         y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
+            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
         )
         return y
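
A short usage sketch under the same assumptions (import paths and shapes are illustrative, not from the repo). As the final hunk shows, `forward` now applies the stored `self.__attention_mask` and ignores the `attention_mask` argument it still accepts:

```python
import torch
# hypothetical import paths; adjust to the actual package layout
from TorchMultiHeadAttention import TorchMultiHeadAttention
from attention_mask import get_attention_mask

batch, seq_len, embed_dim, heads = 4, 10, 32, 4
mha = TorchMultiHeadAttention(
    embed_dim,
    heads,
    dropout=0.1,
    attention_mask=get_attention_mask(seq_len),  # sized by sequence length
)
x = torch.randn(batch, seq_len, embed_dim)  # batch_first=True: (batch, seq, embed)
y = mha(x, x, x)  # causal self-attention; the stored mask is applied internally
print(y.shape)    # torch.Size([4, 10, 32])
```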