import torch
import torch.nn as nn

from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention


class Encoder(nn.Module):
    # Subclassing nn.Module exposes its training primitives (parameters, train/eval, etc.).
    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)  # norm of the first "Add and Normalize"
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)  # norm of the second "Add and Normalize"
        self.dropout = nn.Dropout(0.1)
        # ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # -> attention -> dropout -> add and normalize -> feed-forward -> dropout -> add and normalize ->
        # Self-attention with residual connection [input + self-attention]
        x = x + self.dropout(self.attention(x, x, x))
        x = self.norm1(x)
        # Feed-forward with residual connection [normed self-attention output + feed-forward]
        x = x + self.dropout(self.ffn(x))
        x = self.norm2(x)
        return x
        # call .eval() to disable dropout etc. at inference time
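

# Minimal usage sketch. Assumptions: FeedForwardNetwork and TorchMultiHeadAttention accept the
# constructor arguments used above, the attention wrapper returns a single tensor for the
# (query, key, value) call, and both sub-layers preserve the (batch, seq_len, d_model) shape.
# The dimensions below are illustrative, not prescribed by this module.
if __name__ == "__main__":
    encoder = Encoder(d_model=512, d_ff=2048, attention_heads=8)
    encoder.eval()  # disable dropout for deterministic inference
    dummy = torch.randn(2, 10, 512)  # (batch, sequence_length, d_model)
    with torch.no_grad():
        out = encoder(dummy)
    print(out.shape)  # expected: torch.Size([2, 10, 512])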