# 18 lines
# 820 B
# Python

# it is position wise!
# https://stackoverflow.com/questions/74979359/how-is-position-wise-feed-forward-neural-network-implemented-for-transformers
import torch
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network from the Transformer architecture.

    The same two-layer MLP is applied independently at every position
    (it acts only on the last dimension of the input):

        x -> Linear(d_model, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_model)

    Args:
        d_model: Size of the model (input and output) dimension.
        d_ff: Size of the inner (expanded) hidden dimension.
        dropout: Dropout probability applied after the activation.
            Defaults to 0.1 (the original hard-coded value). Dropout is
            only active in train mode; it is a no-op after ``.eval()``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand into the higher dimension
        self.activation = nn.ReLU()
        # Active only during training; deactivated by model.eval().
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to the model dimension

    def forward(self, x):
        """Apply the MLP position-wise.

        Args:
            x: Tensor of shape (..., d_model).

        Returns:
            Tensor of the same shape (..., d_model).
        """
        # -> NN1 -> ReLU -> (Dropout during training) -> NN2 ->
        return self.fc2(self.dropout(self.activation(self.fc1(x))))