import torch
import torch.nn as nn

from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention


class Decoder(nn.Module):
    """Single Transformer decoder block: masked self-attention, encoder-decoder
    (cross) attention, and a position-wise feed-forward network, each followed
    by a residual connection and layer normalization (post-LN)."""

    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        self._masked_attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, k_x, v_x, attention_mask):
        # x is the decoder input (queries); k_x and v_x come from the encoder
        # output, so k_x == v_x in the usual encoder-decoder setup.

        # 1) Masked self-attention, then residual connection + layer norm
        x = x + self.dropout(self._masked_attention(x, x, x, attention_mask=attention_mask))
        x = self.norm1(x)

        # 2) Encoder-decoder (cross) attention: queries from the decoder,
        #    keys/values from the encoder output
        x = x + self.dropout(self.attention(x, k_x, v_x))
        x = self.norm2(x)

        # 3) Position-wise feed-forward network
        x = x + self.dropout(self.ffn(x))
        x = self.norm3(x)
        return x

    # Call model.eval() at inference time to disable dropout, etc.
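
# Minimal usage sketch (illustrative only, not part of the module): it assumes
# TorchMultiHeadAttention returns a (batch, seq_len, d_model) tensor and accepts
# an attention_mask broadcastable to the attention scores, and that
# FeedForwardNetwork(d_model, d_ff) preserves the (batch, seq_len, d_model)
# shape. The mask convention below (boolean, True = may attend) is an
# assumption; adapt it to the actual attention implementation.
if __name__ == "__main__":
    batch, tgt_len, src_len, d_model = 2, 5, 7, 64

    decoder = Decoder(d_model=d_model, d_ff=256, attention_heads=8)
    decoder.eval()  # disable dropout for a deterministic forward pass

    x = torch.randn(batch, tgt_len, d_model)        # decoder input
    enc_out = torch.randn(batch, src_len, d_model)  # encoder output (k_x = v_x)

    # Causal mask: position i may only attend to positions <= i.
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))

    with torch.no_grad():
        out = decoder(x, enc_out, enc_out, attention_mask=causal_mask)
    print(out.shape)  # expected: torch.Size([2, 5, 64])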