Added attention_mask
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
+from ..Utils.attention_mask import get_attention_mask


 class Decoder(nn.Module):
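The implementation of the newly imported helper is not part of this commit. As a rough sketch only, assuming get_attention_mask builds a square causal (look-ahead) mask in the boolean attn_mask convention used by torch.nn.MultiheadAttention, it could look like this:

    import torch

    # Hypothetical sketch; the real helper lives in ..Utils/attention_mask.py and is not shown here.
    def get_attention_mask(size: int) -> torch.Tensor:
        # Upper-triangular boolean mask: True marks positions a query must not attend to,
        # so position i can only look at positions 0..i.
        return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)

The mask is square in whatever size is passed in; the constructor change in the next hunk passes embedding_dimension as that size.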
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
         super().__init__()

         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1
+            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
         )

         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
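TorchMultiHeadAttention is the repository's own wrapper and its signature is not visible in this commit, so the following is only an assumption about what the new attention_mask keyword might do: store the mask on the module and forward it as attn_mask on every call to torch.nn.MultiheadAttention. Names and structure below are illustrative, not taken from the repository.

    import torch
    import torch.nn as nn

    class MaskedMultiHeadAttention(nn.Module):
        # Illustrative stand-in for TorchMultiHeadAttention, not the actual class.
        def __init__(self, embedding_dimension, number_of_attention_heads,
                     dropout=0.1, attention_mask=None):
            super().__init__()
            self.attention = nn.MultiheadAttention(
                embedding_dimension, number_of_attention_heads,
                dropout=dropout, batch_first=True
            )
            # A buffer keeps the (non-trainable) mask on the same device as the module.
            self.register_buffer("attention_mask", attention_mask)

        def forward(self, query, key, value):
            output, _ = self.attention(query, key, value, attn_mask=self.attention_mask)
            return output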
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
         )
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)


     def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:  # k_x = v_x, while the query x_q = x

         # 1) Masked Attention
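The diff view cuts off right after the "# 1) Masked Attention" comment, so the rest of the updated forward body is not shown. Purely as a generic illustration of that step (not this repository's code), masked self-attention in a decoder block is usually applied with a residual connection and layer normalization:

    import torch
    import torch.nn as nn

    # Generic decoder-block step, assumed for illustration only.
    def masked_attention_step(x: torch.Tensor,
                              attention: nn.MultiheadAttention,
                              layer_norm: nn.LayerNorm,
                              attention_mask: torch.Tensor) -> torch.Tensor:
        # Self-attention: query, key and value are all x, restricted by the causal mask.
        attended, _ = attention(x, x, x, attn_mask=attention_mask)
        # Residual connection followed by layer normalization (post-norm style).
        return layer_norm(x + attended)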