import torch
import torch.nn as nn
from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention


class Encoder(nn.Module):  # subclass nn.Module so the encoder layer inherits its training primitives (parameters, train()/eval(), etc.)

    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)  # norm of the first "Add and Normalize"
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)  # norm of the second "Add and Normalize"
        self.dropout = nn.Dropout(0.1)  # applied to each sub-layer output before the residual addition

    def forward(self, x):
        # -> attention -> dropout -> add and normalize -> feed-forward -> dropout -> add and normalize ->

        # Self-attention with residual connection [input + self-attention]
        x = x + self.dropout(self.attention(x, x, x))
        x = self.norm1(x)

        # Feed-forward with residual connection [normed self-attention output + feed-forward]
        x = x + self.dropout(self.ffn(x))
        x = self.norm2(x)
        return x


# call model.eval() at inference time to disable dropout, etc.
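

# Minimal usage sketch: the hyperparameter values (d_model=512, d_ff=2048, 8 heads)
# and the input shape [batch, seq_len, d_model] are illustrative assumptions, and it
# presumes MultiHeadAttention / FeedForwardNetwork return tensors shaped like their
# input (which the residual additions above already require).
if __name__ == "__main__":
    encoder = Encoder(d_model=512, d_ff=2048, attention_heads=8)
    encoder.eval()  # inference mode: disables dropout
    with torch.no_grad():
        x = torch.randn(2, 16, 512)  # [batch=2, seq_len=16, d_model=512]
        y = encoder(x)
    print(y.shape)  # expected: torch.Size([2, 16, 512])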