import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):
    """Encoder-decoder transformer that maps token-id batches to vocabulary logits.

    Pipeline: embed encoder/decoder token ids, run the encoder stack (with a
    boolean padding mask), cross-attend the decoder stack to the encoder
    output, then project to vocabulary logits via `DeToken`.
    """

    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        """Build embedders, encoder/decoder stacks, and the output projection.

        Args:
            vocabulary_size: number of distinct token ids (embedding rows and
                output logits).
            latent_space: model/embedding dimension.
            feed_forward_multiplier: FFN hidden size = latent_space * this.
            attention_heads: attention heads per layer.
            layer_number: number of encoder layers and of decoder layers.
        """
        super().__init__()

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # BUG FIX: the original used `[Encoder(...)] * layer_number`, which
        # repeats the SAME module object, so every "layer" shared one set of
        # weights (and nn.Sequential would register it once). Each layer must
        # be an independently constructed instance.
        self.__encoder = torch.nn.Sequential(
            *(
                Encoder(latent_space, feed_forward_latent_space, attention_heads)
                for _ in range(layer_number)
            )
        )
        self.__decoder = torch.nn.Sequential(
            *(
                Decoder(latent_space, feed_forward_latent_space, attention_heads)
                for _ in range(layer_number)
            )
        )

        # Final projection from latent space to per-token vocabulary logits.
        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(self, args: tuple[list[list[int]], list[list[bool]], list[list[int]]]):
        """Run a full encoder-decoder pass.

        Args:
            args: `(encoder_token_ids, padding_mask, decoder_token_ids)` —
                batched lists; the mask is per encoder position
                (NOTE(review): True's meaning — keep vs. ignore — is defined
                by `Encoder`, not visible here).

        Returns:
            torch.Tensor of logits over the vocabulary produced by `DeToken`.
        """
        encoder_embedder_input, padding_input, decoder_embedder_input = args

        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
        padding_tensor = torch.tensor(padding_input, dtype=torch.bool)
        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

        encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))

        # BUG FIX: the original passed the raw embeddings (`encoder_tensor`)
        # as the decoder's cross-attention key/value and left `encoder_output`
        # unused; the decoder must attend to the encoder stack's OUTPUT.
        decoder_output, _, _, _ = self.__decoder(
            (decoder_tensor, encoder_output, encoder_output, None)
        )

        logits: torch.Tensor = self.__detokener(decoder_output)

        return logits