diff --git a/Project_Model/Libs/Transformer/Classes/DeToken.py b/Project_Model/Libs/Transformer/Classes/DeToken.py
new file mode 100644
index 0000000..c0b961e
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Classes/DeToken.py
@@ -0,0 +1,19 @@
+import torch
+
+
+class DeToken(torch.nn.Module):
+
+    def __init__(self, embedding_size: int, vocabulary_size: int) -> None:
+        super().__init__()
+
+        self.__linear = torch.nn.Linear(embedding_size, vocabulary_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        # 1) Go from latent space to vocabulary space (logits)
+        x = self.__linear(x)
+
+        # 2) Convert the logits to probabilities
+        x = torch.softmax(x, dim=2)
+
+        return x
diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index 0a818ee..a9c7907 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -35,22 +35,28 @@ class Decoder(nn.Module):
         )
 
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
-
-
-    def forward(self, x, k_x, v_x, padding_mask = None): #-> list[torch.Tensor]: # k_x = v_x . While x_q = x
+    def forward(
+        self,
+        args: tuple[
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor
+        ]
+    ):  # -> list[torch.Tensor]: # k_x = v_x, while x_q = x
+        # WARNING: a single args tuple is required so the block can be chained in nn.Sequential
+        x, k_x, v_x, padding_mask = args
 
         # build the causal attention mask
         attention_mask = get_causal_attention_mask(x.size(1))
 
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
-            x, x, x, key_padding_mask=padding_mask, attn_mask=attention_mask
+            x, x, x, key_padding_mask=padding_mask, attention_mask=attention_mask
         )
 
         # 2) Dropout
-        DROPPED_MASKED_ATTENTION = self.__dropout(
-            MASKED_ATTENTION
-        )
+        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
         del MASKED_ATTENTION
 
         # 3) Residual Connection
@@ -61,7 +67,9 @@ class Decoder(nn.Module):
         x = self.__layer_norm_1(x)
 
         # 5) Encoder–decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask)
+        CROSS_ATTENTION = self.__cross_attention(
+            x, k_x, v_x, key_padding_mask=padding_mask
+        )
 
         # 6) Dropout
         DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
@@ -88,7 +96,7 @@ class Decoder(nn.Module):
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
 
-        return x, k_x, v_x, padding_mask
+        return (x, k_x, v_x, padding_mask)
 
 
 # use eval to disable dropout etc.
diff --git a/Project_Model/Libs/Transformer/Classes/Encoder.py b/Project_Model/Libs/Transformer/Classes/Encoder.py
index 0c46fe0..e232a18 100644
--- a/Project_Model/Libs/Transformer/Classes/Encoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Encoder.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 from Project_Model.Libs.Transformer.Classes.FeedForwardNetwork import FeedForwardNetwork
 from Project_Model.Libs.Transformer.Classes.TorchMultiHeadAttention import (
@@ -29,14 +30,17 @@ class Encoder(
             embedding_dimension
         )  # norm of second "Add and Normalize"
         self.__dropout = nn.Dropout(0.1)  # ...
-        pass
-    def forward(self, x, padding_mask = None):
+
+    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
+        # WARNING: a single args tuple is required so the block can be chained in nn.Sequential
+        x, padding_mask = args
+
         # -> ATTENTION -> dropout -> add and normalize -> FF -> dropout -> add and normalize ->
 
         # Attention with Residual Connection [ input + self-attention]
         # 1) Multi Head Attention
-        ATTENTION = self.__attention(x, x, x,key_padding_mask= padding_mask)
+        ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)
 
         # 2) Dropout
         DROPPED_ATTENTION = self.__dropout(ATTENTION)
@@ -62,7 +66,7 @@ class Encoder(
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
 
-        return x,padding_mask
+        return (x, padding_mask)
 
 
 # use eval to disable dropout etc.
diff --git a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
index 52c0cc5..38aeb6d 100644
--- a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
+++ b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py
@@ -8,17 +8,16 @@ class TorchMultiHeadAttention(nn.Module):
         self,
         embedding_dimension: int,
         number_of_attention_heads: int,
-        dropout: float = 0.0,
+        dropout: float = 0.0
     ):
         super().__init__()
-        self.attention = nn.MultiheadAttention(
+        self.attention = torch.nn.MultiheadAttention(
             embedding_dimension,
-            number_of_attention_heads,
+            num_heads=number_of_attention_heads,
             dropout=dropout,
             batch_first=True,
         )
-
     def forward(
         self,
         x_q: torch.Tensor,
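Side note on the tuple-based forward() signatures above: torch.nn.Sequential threads a single positional value from one module to the next, which is why Encoder and Decoder now pack their inputs into one args tuple and return a tuple of the same shape. Below is a minimal sketch of that contract, assuming a hypothetical TupleBlock stand-in (its constructor and linear layer are placeholders, not the real Encoder API).

import torch
import torch.nn as nn


class TupleBlock(nn.Module):
    """Hypothetical stand-in with the same tuple-in / tuple-out contract."""

    def __init__(self, embedding_dimension: int) -> None:
        super().__init__()
        self.linear = nn.Linear(embedding_dimension, embedding_dimension)

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        # nn.Sequential passes exactly one object, so unpack it here ...
        x, padding_mask = args
        x = self.linear(x)  # placeholder for attention + feed-forward
        # ... and re-pack it so the next block can unpack it the same way.
        return (x, padding_mask)


stack = nn.Sequential(*[TupleBlock(512) for _ in range(6)])

x = torch.randn(2, 10, 512)                          # (batch, sequence, embedding)
padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # no padded positions
x, padding_mask = stack((x, padding_mask))
print(x.shape)  # torch.Size([2, 10, 512])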