import torch.nn as nn

from Project_Model.Libs.Transformer.Classes.FeedForwardNetwork import FeedForwardNetwork
from Project_Model.Libs.Transformer.Classes.TorchMultiHeadAttention import (
    TorchMultiHeadAttention as MultiHeadAttention,
)
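

# A minimal sketch (assumption, illustrative only) of what the imported FeedForwardNetwork is
# expected to look like: the position-wise feed-forward block of the original Transformer,
# i.e. Linear -> ReLU -> Linear, mapping (batch, sequence, embedding_dimension) back to the
# same shape. The real implementation lives in Project_Model.Libs and may differ; this
# reference class is not used by the Encoder below.
class _ReferenceFeedForwardNetwork(nn.Module):
    def __init__(self, embedding_dimension: int, hidden_layer_dimension: int) -> None:
        super().__init__()
        self.__expand = nn.Linear(embedding_dimension, hidden_layer_dimension)
        self.__activation = nn.ReLU()
        self.__project = nn.Linear(hidden_layer_dimension, embedding_dimension)

    def forward(self, x):
        # (batch, sequence, embedding_dimension) -> (batch, sequence, embedding_dimension)
        return self.__project(self.__activation(self.__expand(x)))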


# Subclassing nn.Module exposes its training primitives (registered parameters,
# train()/eval() switching, state_dict, ...) for this layer.
class Encoder(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        feed_forward_hidden_layer_dimension: int,
        number_of_attention_heads: int,
    ) -> None:
        super().__init__()

        self.__attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_1 = nn.LayerNorm(
            embedding_dimension
        )  # norm of the first "Add and Normalize"
        self.__feed_forward = FeedForwardNetwork(
            embedding_dimension, feed_forward_hidden_layer_dimension
        )
        self.__layer_norm_2 = nn.LayerNorm(
            embedding_dimension
        )  # norm of the second "Add and Normalize"
        self.__dropout = nn.Dropout(0.1)  # shared dropout, applied after each sublayer in forward()

    def forward(self, x, padding_mask=None):
        # -> attention -> dropout -> add and normalize -> feed forward -> dropout -> add and normalize ->

        # Attention with residual connection [input + self-attention]

        # 1) Multi-head self-attention
        ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)

        # 2) Dropout
        DROPPED_ATTENTION = self.__dropout(ATTENTION)
        del ATTENTION

        # 3) Residual connection
        x = x + DROPPED_ATTENTION

        # 4) Layer normalization
        x = self.__layer_norm_1(x)

        # 5) Feed forward
        FEED_FORWARD = self.__feed_forward(x)

        # 6) Dropout
        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
        del FEED_FORWARD

        # 7) Residual connection
        x = x + DROPPED_FEED_FORWARD
        del DROPPED_FEED_FORWARD

        # 8) Layer normalization
        x = self.__layer_norm_2(x)

        return x, padding_mask


# Call .eval() on the model at inference time to disable dropout, etc.
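

# A minimal usage sketch (illustrative only). The sizes and the padding-mask convention below are
# assumptions, not part of the original module: it presumes the attention wrapper behaves like
# torch.nn.MultiheadAttention with batch_first inputs, where key_padding_mask is a boolean
# (batch, sequence_length) tensor and True marks padding positions to be ignored.
if __name__ == "__main__":
    import torch

    encoder = Encoder(
        embedding_dimension=512,
        feed_forward_hidden_layer_dimension=2048,
        number_of_attention_heads=8,
    )
    encoder.eval()  # disable dropout for deterministic inference

    batch_size, sequence_length = 2, 10
    tokens = torch.randn(batch_size, sequence_length, 512)  # already-embedded input
    padding_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    padding_mask[:, 7:] = True  # pretend the last three positions of each sequence are padding

    with torch.no_grad():
        output, forwarded_mask = encoder(tokens, padding_mask)

    print(output.shape)          # expected: torch.Size([2, 10, 512])
    print(forwarded_mask.shape)  # torch.Size([2, 10]); returned unchanged so encoder layers can be chained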