From 144f8724d6ceb3940a3731215b095bd8cb010d1f Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Sun, 12 Oct 2025 16:35:42 +0200
Subject: [PATCH] Update the batcher to resolve a bug in the 4th construction

---
 Project_Model/Libs/Batch/Classes/Batcher.py          | 45 ++++++++++++++++++-
 .../Libs/Transformer/Classes/Decoder.py              |  3 +-
 .../Libs/Transformer/Utils/attention_mask.py         |  2 +
 3 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/Project_Model/Libs/Batch/Classes/Batcher.py b/Project_Model/Libs/Batch/Classes/Batcher.py
index 7459fa0..be08e0f 100644
--- a/Project_Model/Libs/Batch/Classes/Batcher.py
+++ b/Project_Model/Libs/Batch/Classes/Batcher.py
@@ -170,13 +170,54 @@ class Batcher:
         X = []
         Y = []
         for rdf in batch["RDFs"]:
+            # first truncate to max_length
+            rdf = rdf[: self.__max_length]
             # truncator that uses "eot" so no problem
             x, y = self._completation_task_token_truncator(
                 rdf, 0.5, continue_triple_token, eot, minibatch_seed
             )
             X.append(x)
             Y.append(y)
-        return self.__normalization(X, Y)
+        return self.__token_completation_task_special_normalization(X, Y)
+
+    def __token_completation_task_special_normalization(self, X: list[list[int]], Y: list[list[int]]
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]], list[list[int]]]:
+
+        def continue_rdf_padding(sequence: list[int], pad_token: int):
+            for i, x in enumerate(sequence):
+                if x == pad_token:
+                    i = i + 1  # the continueRDF token itself is excluded from the mask
+                    # fill the tail with True and stop
+                    return [False] * i + [True] * (len(sequence) - i)
+            return [False] * len(sequence)  # no pad token found
+
+        pad_token = self._tokenizer.encode(SpecialToken.PAD.value)[0]
+        end_token = self._tokenizer.encode(SpecialToken.END_OF_SEQUENCE.value)[0]
+        continue_rdf = self._tokenizer.encode(SpecialToken.CONTINUE_RDF.value)[0]
+        out_X = []
+        padding_X = []
+        out_Y = []
+        padding_Y = []
+
+        for x in X:
+            out_x, _ = normalize_sequence(
+                x, self.__max_length, pad_token, end_token, True
+            )
+            out_X.append(out_x)
+            # padding_X.append(padding_x)
+            special_padding = continue_rdf_padding(out_x, continue_rdf)
+            padding_X.append(special_padding)
+
+        for y in Y:
+            out_y, padding_y = normalize_sequence(
+                y, self.__max_length, pad_token, end_token, True
+            )
+            out_Y.append(out_y)
+            # special padding
+            # special_padding = continue_rdf_padding(out_y, continue_rdf)
+            # padding_Y.append(special_padding)
+            padding_Y.append(padding_y)
+
+        return out_X, out_Y, padding_X, padding_Y
 
 
 if __name__ == "__main__":
@@ -194,6 +235,6 @@ if __name__ == "__main__":
     prova = "Cactus Flower is a 1969 American screwball comedy film directed by Gene Saks, and starring Walter Matthau, Ingrid Bergman and Goldie Hawn, who won an Academy Award for her performance.The screenplay was adapted by I. A. L. Diamond from the 1965 Broadway play of the same title written by Abe Burrows, which, in turn, is based on the French play Fleur de cactus by Pierre Barillet and Jean-Pierre Gredy. Cactus Flower was the ninth highest-grossing film of 1969."
     print(TOKENANO.encode(prova))
-    batcher = Batcher(DATASET_PATH, TOKENANO, MASKER)
+    batcher = Batcher(DATASET_PATH, 256, TOKENANO, MASKER)
     for batch in batcher.batch(8):
         print(batch)
diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index 10422b8..11abe06 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -20,7 +20,6 @@ class Decoder(nn.Module):
 
         super().__init__()
 
-
         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
         )
@@ -58,7 +57,7 @@
 
         # build of attention mask
         # TODO: create a prefix causal mask if needed
         if decoder_only:
-            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1),tgt_padding_mask,self.__attention_heads)
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), src_padding_mask, self.__attention_heads)  # the correct mask would be tgt, however ...
         else:
             attention_mask = get_causal_attention_mask(x.size(1))
diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
index a6c838a..6782504 100644
--- a/Project_Model/Libs/Transformer/Utils/attention_mask.py
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -24,5 +24,7 @@ def get_prefix_causal_mask_from_padding_mask(seq_len:int, prefix_mask, att_heads
     prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0)  # B*H,T,T
     return prefix_causal_mask
 
+#def get_prefix_causal_mask():
+#    continue_rdf =
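
Note on the new padding rule (outside the patch itself): the sketch below reproduces the continue_rdf_padding behaviour in isolation, with a hypothetical token ID standing in for the tokenizer's SpecialToken.CONTINUE_RDF encoding and assuming True marks positions the model should ignore. It shows the mask the special normalization is expected to emit: everything up to and including the continue-RDF token stays visible, the padded tail is masked.

    # Standalone sketch of the continue_rdf padding rule (hypothetical token IDs;
    # in the Batcher the real ID comes from the tokenizer's special-token encoding).
    CONTINUE_RDF = 7

    def continue_rdf_padding(sequence: list[int], split_token: int) -> list[bool]:
        for i, token in enumerate(sequence):
            if token == split_token:
                cut = i + 1  # keep the continue-RDF position itself unmasked
                return [False] * cut + [True] * (len(sequence) - cut)
        return [False] * len(sequence)  # token absent: nothing is masked

    if __name__ == "__main__":
        seq = [11, 12, 13, CONTINUE_RDF, 0, 0, 0]
        print(continue_rdf_padding(seq, CONTINUE_RDF))
        # -> [False, False, False, False, True, True, True]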
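For context on the Decoder change: a prefix-causal mask lets positions inside the prefix attend to each other bidirectionally while the rest of the sequence attends causally. The sketch below is a generic construction under assumed conventions (prefix_mask is a (B, T) boolean tensor with True on prefix positions, and True in the output means "attention allowed"); the repository's get_prefix_causal_mask_from_padding_mask may use different conventions, so this is only an illustration.

    import torch

    def prefix_causal_mask(seq_len: int, prefix_mask: torch.Tensor, att_heads: int) -> torch.Tensor:
        # prefix_mask: (B, T) bool, True where the position belongs to the prefix.
        causal = torch.tril(torch.ones(seq_len, seq_len)).bool()   # (T, T) lower-triangular
        prefix = prefix_mask[:, None, :].expand(-1, seq_len, -1)   # (B, T, T) prefix columns
        allowed = causal.unsqueeze(0) | prefix                     # bidirectional inside the prefix
        return allowed.repeat_interleave(att_heads, dim=0)         # (B*H, T, T), one copy per head

    if __name__ == "__main__":
        pm = torch.tensor([[True, True, False, False]])  # first two positions form the prefix
        print(prefix_causal_mask(4, pm, att_heads=1).int())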