from typing import Optional

import torch
import torch.nn as nn

from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import (
    get_causal_attention_mask,
    get_prefix_causal_mask_from_padding_mask,
)


# Tensor shapes: (B, L(T), E_D) = (batch, target sequence length, embedding dimension).
class Decoder(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        feed_forward_hidden_layer_dimension: int,
        number_of_attention_heads: int,
    ) -> None:
        super().__init__()
        self.__attention_heads = number_of_attention_heads

        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)

        self.__cross_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_2 = nn.LayerNorm(embedding_dimension)

        self.__dropout = nn.Dropout(0.1)

        self.__feed_forward_network = FeedForwardNetwork(
            embedding_dimension, feed_forward_hidden_layer_dimension
        )
        self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)

    def forward(
        self,
        args: tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            Optional[bool],
        ],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool]:
        # k_x = v_x (the encoder output), while x is the query input.
        # WARNING: all inputs are packed into a single tuple so the layer can be
        # chained with nn.Sequential, which forwards only one positional argument.
        if len(args) < 6:
            # Default to encoder-decoder mode when decoder_only is not provided.
            args = args + (False,)
        x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only = args

        # Build the attention mask: prefix-causal (derived from the target padding
        # mask) in decoder-only mode, plain causal otherwise.
        if decoder_only:
            attention_mask = get_prefix_causal_mask_from_padding_mask(
                x.size(1), tgt_padding_mask, self.__attention_heads
            )
        else:
            attention_mask = get_causal_attention_mask(x.size(1))

        # 1) Masked self-attention
        MASKED_ATTENTION = self.__masked_attention(
            x, x, x, key_padding_mask=tgt_padding_mask, attention_mask=attention_mask
        )
        # 2) Dropout
        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
        del MASKED_ATTENTION
        # 3) Residual connection
        x = x + DROPPED_MASKED_ATTENTION
        del DROPPED_MASKED_ATTENTION
        # 4) Layer normalization
        x = self.__layer_norm_1(x)

        if not decoder_only:
            # 5) Encoder-decoder (cross) attention
            CROSS_ATTENTION = self.__cross_attention(
                x, k_x, v_x, key_padding_mask=src_padding_mask
            )
            # 6) Dropout
            DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
            del CROSS_ATTENTION
            # 7) Residual connection
            x = x + DROPPED_CROSS_ATTENTION
            del DROPPED_CROSS_ATTENTION
            # 8) Layer normalization
            x = self.__layer_norm_2(x)

        # 9) Position-wise feed-forward network
        FEED_FORWARD = self.__feed_forward_network(x)
        # 10) Dropout
        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
        del FEED_FORWARD
        # 11) Residual connection
        x = x + DROPPED_FEED_FORWARD
        del DROPPED_FEED_FORWARD
        # 12) Layer normalization
        x = self.__layer_norm_3(x)

        # Use model.eval() to disable dropout etc. at inference time.
        return (x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only)
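

# --- Minimal usage sketch (assumption-heavy, not part of the layer itself) ---
# Assumes the package-local FeedForwardNetwork, TorchMultiHeadAttention and
# attention-mask utilities behave as they are called above, and that padding
# masks are boolean tensors of shape (batch, sequence) as in
# torch.nn.MultiheadAttention. The shapes below (batch=2, target length=5,
# source length=7, embedding=16) are illustrative only. Because of the relative
# imports, run this as a module, e.g. `python -m <package>.Decoder`.
if __name__ == "__main__":
    decoder_stack = nn.Sequential(
        Decoder(
            embedding_dimension=16,
            feed_forward_hidden_layer_dimension=64,
            number_of_attention_heads=4,
        ),
        Decoder(
            embedding_dimension=16,
            feed_forward_hidden_layer_dimension=64,
            number_of_attention_heads=4,
        ),
    )

    tgt = torch.randn(2, 5, 16)     # decoder input, (B, T, E_D)
    memory = torch.randn(2, 7, 16)  # encoder output, (B, S, E_D)
    src_padding_mask = torch.zeros(2, 7, dtype=torch.bool)  # no padded source positions
    tgt_padding_mask = torch.zeros(2, 5, dtype=torch.bool)  # no padded target positions

    # The tuple packing is what lets nn.Sequential thread all six arguments
    # through every Decoder in the stack.
    out, *_ = decoder_stack(
        (tgt, memory, memory, src_padding_mask, tgt_padding_mask, False)
    )
    print(out.shape)  # expected: torch.Size([2, 5, 16])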