import torch

import Project_Model.Libs.Embedder as Embedder
from ..Classes import Encoder, Decoder, DeToken
from ..Utils import get_decoder_input
from Project_Model.Libs.Batch import TaskType


class NanoSocratesCore(torch.nn.Module):

    def __init__(
        self,
        vocabulary_size: int,
        sentence_max_length: int,
        sos: int,
        pad: int,
        eos: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        self.__sos = sos
        self.__pad = pad
        self.__eos = eos
        self.__sentence_len = sentence_max_length

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # Build one independent Encoder/Decoder instance per layer.
        # List multiplication ([Encoder(...)] * layer_number) would repeat the
        # same module object and silently share its weights across layers.
        TMP_ENCODERS = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        TMP_DECODERS = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        self.__encoder = torch.nn.Sequential(*TMP_ENCODERS)
        self.__decoder = torch.nn.Sequential(*TMP_DECODERS)

        self.__detokener = DeToken(latent_space, vocabulary_size)
        self.__encoder_detokener = DeToken(latent_space, vocabulary_size)

    def forward(
        self,
        args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ):
        encoder_embedder_input, src_padding, decoder_embedder_input, tgt_padding = args

        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

        encoder_output, _ = self.__encoder((encoder_tensor, src_padding))

        decoder_output, _, _, _, _, _ = self.__decoder(
            (
                decoder_tensor,
                encoder_output,
                encoder_output,
                src_padding,
                tgt_padding,
                False,
            )
        )

        logits: torch.Tensor = self.__detokener(decoder_output)

        return logits

    def inference(
        self, input: tuple[torch.Tensor, torch.Tensor], task_type: TaskType
    ) -> torch.Tensor:

        if task_type == TaskType.MASKING:
            return self.__masking(input)

        if task_type == TaskType.COMPLETATION:
            return self.__continue_rdf(input)

        return self.__text_generation(input)

    def __text_generation(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:

        x, padding = args

        encoder_tensor = self.__encoder_embedder(x)

        # x holds token ids of shape [BATCH, SEQ_LEN]
        BATCH, SEQ_LEN = x.shape

        encoder_output, _ = self.__encoder((encoder_tensor, padding))

        # Start from [SOS, PAD, PAD, ...] and fill the sequence greedily.
        decoder_in = get_decoder_input(BATCH, self.__sos, self.__pad, SEQ_LEN)
        decoder_in_pad_mask = decoder_in.eq(self.__pad)

        continue_generating = True
        token_idx = 0

        while continue_generating:

            # Embed into a separate tensor so decoder_in keeps holding token ids.
            decoder_tensor = self.__decoder_embedder(decoder_in)

            decoder_output, _, _, _, _, _ = self.__decoder(
                (
                    decoder_tensor,
                    encoder_output,
                    encoder_output,
                    padding,
                    decoder_in_pad_mask,
                    False,
                )
            )

            logits: torch.Tensor = self.__detokener(decoder_output)
            probabilities = torch.softmax(logits, 2)
            tokens = torch.argmax(probabilities, dim=2)

            # Stop as soon as a single-sample batch predicts EOS at the current position.
            if tokens.shape[0] == 1 and tokens[0, token_idx] == self.__eos:
                continue_generating = False
                continue

            # Write the predicted token into the next slot and refresh the padding mask.
            if token_idx < self.__sentence_len - 1:
                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
                decoder_in_pad_mask = decoder_in.eq(self.__pad)

            token_idx += 1

            # Stop once the sequence is full.
            if token_idx >= self.__sentence_len - 1:
                continue_generating = False

        return decoder_in

    def __masking(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:

        x, padding = args

        encoder_tensor = self.__encoder_embedder(x)
        x, _ = self.__encoder((encoder_tensor, padding))

        logits: torch.Tensor = self.__encoder_detokener(x)
        del x

        probabilities = torch.softmax(logits, 2)
        tokens = torch.argmax(probabilities, dim=2)

        return tokens

    def __continue_rdf(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:

        decoder_in, _ = args

        # The decoder attends to its own padded prefix instead of an encoder
        # output, so both masks are derived from the same padding positions.
        decoder_in_prefix_mask = decoder_in.eq(self.__pad)
        decoder_in_pad_mask = decoder_in.eq(self.__pad)

        continue_generating = True
        token_idx = 0

        while continue_generating:

            # Embed into a separate tensor so decoder_in keeps holding token ids.
            decoder_tensor = self.__decoder_embedder(decoder_in)

            decoder_output, _, _, _, _, _ = self.__decoder(
                (
                    decoder_tensor,
                    decoder_tensor,
                    decoder_tensor,
                    decoder_in_prefix_mask,
                    decoder_in_pad_mask,
                    False,
                )
            )

            logits: torch.Tensor = self.__detokener(decoder_output)
            probabilities = torch.softmax(logits, 2)
            tokens = torch.argmax(probabilities, dim=2)

            # Stop as soon as a single-sample batch predicts EOS at the current position.
            if tokens.shape[0] == 1 and tokens[0, token_idx] == self.__eos:
                continue_generating = False
                continue

            # Write the predicted token into the next slot and refresh the padding mask.
            if token_idx < self.__sentence_len - 1:
                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
                decoder_in_pad_mask = decoder_in.eq(self.__pad)

            token_idx += 1

            # Stop once the sequence is full.
            if token_idx >= self.__sentence_len - 1:
                continue_generating = False

        return decoder_in

    def take_pieces(self):
        return (
            (self.__encoder_embedder, self.__encoder),
            (self.__decoder_embedder, self.__decoder, self.__detokener),
        )
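

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The vocabulary
    # size, sequence length, and special-token ids below are illustrative
    # assumptions; shapes follow the conventions used by forward() above
    # (token-id tensors of shape [batch, seq_len] plus boolean padding masks).
    # Because this file uses relative imports, run it as a module
    # (e.g. `python -m <package>.<module>`), not as a standalone script.
    VOCABULARY_SIZE = 1000
    SENTENCE_MAX_LENGTH = 32
    SOS, PAD, EOS = 1, 0, 2

    model = NanoSocratesCore(VOCABULARY_SIZE, SENTENCE_MAX_LENGTH, SOS, PAD, EOS)

    encoder_input = torch.randint(3, VOCABULARY_SIZE, (2, SENTENCE_MAX_LENGTH))
    decoder_input = torch.randint(3, VOCABULARY_SIZE, (2, SENTENCE_MAX_LENGTH))
    src_padding = encoder_input.eq(PAD)
    tgt_padding = decoder_input.eq(PAD)

    logits = model((encoder_input, src_padding, decoder_input, tgt_padding))
    print(logits.shape)  # expected: [2, SENTENCE_MAX_LENGTH, VOCABULARY_SIZE]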