"""Position-wise feed-forward network (Transformer FFN sublayer)."""
# it is position wise!
# https://stackoverflow.com/questions/74979359/how-is-position-wise-feed-forward-neural-network-implemented-for-transformers
import torch
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network used in Transformer blocks.

    Applies the same two-layer MLP independently at every position:
    ``x -> Linear(d_model, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_model)``.

    Args:
        d_model: model (input/output) feature dimension.
        d_ff: hidden dimension of the expansion layer.
        dropout: dropout probability applied after the activation.
            Defaults to 0.1 (the value previously hard-coded). Active only
            in training mode; a no-op after ``.eval()``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand into the higher dimension
        self.activation = nn.ReLU()
        # During training some activations are dropped; deactivated under eval().
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)   # project back to the model dimension

    def forward(self, x):
        """Apply fc1 -> ReLU -> (dropout during training) -> fc2, position-wise.

        Args:
            x: tensor whose last dimension is ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.fc2(self.dropout(self.activation(self.fc1(x))))