Added attention_mask

This commit is contained in:
GassiGiuseppe
2025-10-05 17:49:01 +02:00
parent b303affd18
commit 6f219f634f
5 changed files with 20 additions and 5 deletions

View File

@@ -2,6 +2,8 @@ import torch
import torch.nn as nn
from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_attention_mask
class Decoder(nn.Module):
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
super().__init__()
self.__masked_attention = MultiHeadAttention(
embedding_dimension, number_of_attention_heads, dropout=0.1
embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
)
self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
)
self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor: # k_x = v_x . While x_q = x
# 1) Masked Attention