
import torch
import torch.nn as nn
from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention
class Encoder(nn.Module):  # subclassing nn.Module exposes its training primitives (parameters, train/eval, state_dict, ...)
    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)  # norm of the first "Add & Normalize"
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)  # norm of the second "Add & Normalize"
        self.dropout = nn.Dropout(0.1)  # dropout applied to each sub-layer output before the residual add

    def forward(self, x):
        # -> attention -> dropout -> add & normalize -> feed-forward -> dropout -> add & normalize ->
        # Self-attention with residual connection [input + self-attention]
        x = x + self.dropout(self.attention(x, x, x))
        x = self.norm1(x)
        # Feed-forward with residual connection [normed self-attention + feed-forward]
        x = x + self.dropout(self.ffn(x))
        x = self.norm2(x)
        return x
# call model.eval() to disable dropout etc. at inference time
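

# Minimal usage sketch (illustrative, not part of the original module): it assumes the
# constructor signatures above and an input tensor of shape (batch, seq_len, d_model);
# the concrete sizes (d_model=512, d_ff=2048, 8 heads) are example values, not taken
# from the repo.
if __name__ == "__main__":
    encoder = Encoder(d_model=512, d_ff=2048, attention_heads=8)
    encoder.eval()  # disable dropout for deterministic inference
    x = torch.randn(2, 10, 512)  # (batch=2, seq_len=10, d_model=512)
    with torch.no_grad():
        out = encoder(x)
    print(out.shape)  # torch.Size([2, 10, 512]) -- the encoder layer preserves the input shape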