import torch

import Project_Model.Libs.Embedder as Embedder

from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):
    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # Build each layer separately so every layer gets its own parameters.
        # Multiplying a one-element list ([Encoder(...)] * layer_number) would
        # repeat the *same* module instance, making all layers share weights.
        encoders = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        decoders = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        self.__encoder = torch.nn.Sequential(*encoders)
        self.__decoder = torch.nn.Sequential(*decoders)
        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(self, args: tuple[list[list[int]], list[list[bool]], list[list[int]]]):

        encoder_embedder_input, padding_input, decoder_embedder_input = args

        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
        # Keep the padding mask on the same device as the embeddings.
        padding_tensor = torch.tensor(
            padding_input, dtype=torch.bool, device=encoder_tensor.device
        )
        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

        encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))

        # Cross-attention must read the encoder *output*, not the raw
        # embeddings; feeding encoder_tensor here would bypass the encoder
        # stack entirely. No cross-attention mask is forwarded (None).
        decoder_output, _, _, _ = self.__decoder(
            (decoder_tensor, encoder_output, encoder_output, None)
        )

        logits: torch.Tensor = self.__detokener(decoder_output)

        return logits
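
# A minimal smoke-test sketch (illustrative only): it assumes this module is
# importable as part of its package and that Encoder, Decoder, and DeToken
# follow the tuple signatures used in forward() above. VOCABULARY_SIZE, the
# token ids, and the padding flags below are made-up example values.
if __name__ == "__main__":
    VOCABULARY_SIZE = 1000

    model = TrainingModel(VOCABULARY_SIZE)

    # One batch entry of sequence length 4: source tokens, padding flags
    # (True marks a padded position), and target tokens.
    source = [[1, 2, 3, 4]]
    padding = [[False, False, False, True]]
    target = [[1, 2, 3, 4]]

    logits = model((source, padding, target))
    # Expected shape: (batch, sequence_length, vocabulary_size).
    print(logits.shape)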