import torch
import torch.nn as nn

from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_causal_attention_mask


# Tensors flowing through this layer have shape (B, L_target, E_dim):
# batch size, target sequence length, embedding dimension.
class Decoder(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        feed_forward_hidden_layer_dimension: int,
        number_of_attention_heads: int,
    ) -> None:
        super().__init__()

        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)

        self.__cross_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_2 = nn.LayerNorm(embedding_dimension)

        self.__dropout = nn.Dropout(0.1)

        self.__feed_forward_network = FeedForwardNetwork(
            embedding_dimension, feed_forward_hidden_layer_dimension
        )
        self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)

    def forward(
        self,
        args: tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
        ],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # The whole state is packed into a single tuple so that decoder layers
        # can be stacked with nn.Sequential, which passes only one argument.
        # k_x and v_x are the encoder output (keys and values); x is the query.
        x, k_x, v_x, padding_mask, encoder_padding_mask = args

        # Build the causal attention mask for the target sequence.
        attention_mask = get_causal_attention_mask(x.size(1))

        # 1) Masked self-attention
        masked_attention = self.__masked_attention(
            x, x, x, key_padding_mask=padding_mask, attention_mask=attention_mask
        )
        # 2) Dropout + 3) Residual connection
        x = x + self.__dropout(masked_attention)
        del masked_attention
        # 4) Layer normalization
        x = self.__layer_norm_1(x)

        # 5) Encoder-decoder (cross) attention
        cross_attention = self.__cross_attention(
            x, k_x, v_x, key_padding_mask=encoder_padding_mask
        )
        # 6) Dropout + 7) Residual connection
        x = x + self.__dropout(cross_attention)
        del cross_attention
        # 8) Layer normalization
        x = self.__layer_norm_2(x)

        # 9) Position-wise feed-forward network
        feed_forward = self.__feed_forward_network(x)
        # 10) Dropout + 11) Residual connection
        x = x + self.__dropout(feed_forward)
        del feed_forward
        # 12) Layer normalization
        x = self.__layer_norm_3(x)

        # Masks and encoder output are returned unchanged so the next layer in
        # an nn.Sequential stack receives the same tuple structure.
        # Call model.eval() at inference time to disable dropout, etc.
        return (x, k_x, v_x, padding_mask, encoder_padding_mask)
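

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): because forward accepts and returns a
# single tuple, several Decoder layers can be chained with nn.Sequential.
# The hyperparameters, tensor names, and stack depth below are assumptions
# for illustration, not values defined elsewhere in this project.
#
#   decoder_stack = nn.Sequential(
#       *[
#           Decoder(
#               embedding_dimension=512,
#               feed_forward_hidden_layer_dimension=2048,
#               number_of_attention_heads=8,
#           )
#           for _ in range(6)
#       ]
#   )
#   decoder_stack.eval()  # disables dropout for inference
#   x, k_x, v_x, padding_mask, encoder_padding_mask = decoder_stack(
#       (target_embeddings, encoder_output, encoder_output,
#        target_padding_mask, encoder_padding_mask)
#   )
# ---------------------------------------------------------------------------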