import torch
import torch.nn as nn

from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_causal_attention_mask


# Tensor shape convention: (B, L/T, E_D) = (batch, sequence length, embedding dimension)
class Decoder(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        feed_forward_hidden_layer_dimension: int,
        number_of_attention_heads: int,
    ) -> None:
        super().__init__()

        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)

        self.__cross_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_2 = nn.LayerNorm(embedding_dimension)

        self.__dropout = nn.Dropout(0.1)

        self.__feed_forward_network = FeedForwardNetwork(
            embedding_dimension, feed_forward_hidden_layer_dimension
        )
        self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)

    def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor:
        # In cross-attention the keys and values come from the encoder output
        # (k_x = v_x), while the queries come from the decoder state x.

        # Build the causal attention mask for the masked self-attention.
        attention_mask = get_causal_attention_mask(x.size(1))

        # 1) Masked (causal) self-attention
        MASKED_ATTENTION = self.__masked_attention(
            x, x, x, key_padding_mask=padding_mask, attn_mask=attention_mask
        )

        # 2) Dropout
        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
        del MASKED_ATTENTION

        # 3) Residual connection
        x = x + DROPPED_MASKED_ATTENTION
        del DROPPED_MASKED_ATTENTION

        # 4) Layer normalization
        x = self.__layer_norm_1(x)

        # 5) Encoder-decoder (cross) attention
        CROSS_ATTENTION = self.__cross_attention(
            x, k_x, v_x, key_padding_mask=padding_mask
        )

        # 6) Dropout
        DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
        del CROSS_ATTENTION

        # 7) Residual connection
        x = x + DROPPED_CROSS_ATTENTION
        del DROPPED_CROSS_ATTENTION

        # 8) Layer normalization
        x = self.__layer_norm_2(x)

        # 9) Position-wise feed-forward network
        FEED_FORWARD = self.__feed_forward_network(x)

        # 10) Dropout
        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
        del FEED_FORWARD

        # 11) Residual connection
        x = x + DROPPED_FEED_FORWARD
        del DROPPED_FEED_FORWARD

        # 12) Layer normalization
        x = self.__layer_norm_3(x)

        return x
        # Call model.eval() at inference time to disable dropout etc.
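

# ---------------------------------------------------------------------------
# Usage sketch (not from the source): a minimal smoke test of a single
# decoder layer. Assumptions: this file is run as a module from the package
# root (e.g. `python -m <package>.Decoder`) so the relative imports resolve,
# and the tensor shapes below are purely illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, target_length, source_length = 2, 10, 12
    embedding_dimension = 64

    decoder = Decoder(
        embedding_dimension=embedding_dimension,
        feed_forward_hidden_layer_dimension=256,
        number_of_attention_heads=8,
    )
    decoder.eval()  # disable dropout for a deterministic forward pass

    # Decoder input (queries) and encoder output (keys/values for cross-attention).
    x = torch.randn(batch_size, target_length, embedding_dimension)
    encoder_output = torch.randn(batch_size, source_length, embedding_dimension)

    with torch.no_grad():
        output = decoder(x, encoder_output, encoder_output)

    # The output keeps the query (target) sequence length and embedding size.
    print(output.shape)  # expected: torch.Size([2, 10, 64])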