import torch

import Project_Model.Libs.Embedder as Embedder

from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):

    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # do NOT share layer weights: build a fresh Encoder/Decoder per layer
        enc_layers = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        dec_layers = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        self.__encoder = torch.nn.Sequential(*enc_layers)
        self.__decoder = torch.nn.Sequential(*dec_layers)

        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(
        self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        # Returns logits for the LAST decoder position only -> [B, V]
        (
            encoder_embedder_input,   # [B, S]  encoder tokens
            encoder_padding_mask,     # [B, S]  True where the encoder input is PAD
            decoder_embedder_prefix,  # [B, Tp] decoder prefix (start token + tokens so far)
            decoder_padding_mask,     # [B, Tp] True where the decoder prefix is PAD
        ) = args

        # 1) embeddings
        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)   # [B, S, E]
        decoder_tensor = self.__decoder_embedder(decoder_embedder_prefix)  # [B, Tp, E]

        # 2) encode
        encoder_output, _ = self.__encoder(
            (encoder_tensor, encoder_padding_mask)
        )  # [B, S, E], [B, S]

        # 3) decode (causal mask is built inside the decoder)
        decoder_output, _, _, _, _ = self.__decoder(
            (
                decoder_tensor,
                encoder_output,
                encoder_output,
                decoder_padding_mask,
                encoder_padding_mask,
            )
        )  # [B, Tp, E], ...

        # 4) project only the last time step
        last_hidden = decoder_output[:, -1:, :]      # [B, 1, E]
        step_logits = self.__detokener(last_hidden)  # [B, 1, V]
        step_logits = step_logits[:, -1, :]          # [B, V]

        return step_logits  # logits for one token
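

# --- Minimal usage sketch (illustrative, not part of the module) ---
# Assumes the PAD-mask convention documented above (True marks padding) and that
# the project-internal Encoder/Decoder/DeToken classes accept the tuple signatures
# used in forward(). Vocabulary size, batch size, and sequence lengths are made up.
if __name__ == "__main__":
    VOCAB = 1_000
    model = TrainingModel(vocabulary_size=VOCAB)

    batch, src_len, prefix_len = 2, 16, 5
    encoder_tokens = torch.randint(0, VOCAB, (batch, src_len))            # [B, S]
    encoder_pad_mask = torch.zeros(batch, src_len, dtype=torch.bool)      # [B, S], no padding
    decoder_prefix = torch.randint(0, VOCAB, (batch, prefix_len))         # [B, Tp]
    decoder_pad_mask = torch.zeros(batch, prefix_len, dtype=torch.bool)   # [B, Tp], no padding

    # One autoregressive step: the model returns next-token logits for the last
    # decoder position only, so generation would append the argmax (or a sample)
    # to decoder_prefix and call the model again.
    step_logits = model(
        (encoder_tokens, encoder_pad_mask, decoder_prefix, decoder_pad_mask)
    )
    print(step_logits.shape)  # expected: torch.Size([2, 1000])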