From fc44929a7b81d34de3337088326c5672cf34252f Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Tue, 7 Oct 2025 23:15:50 +0200
Subject: [PATCH] Move spanned-mask configuration into __init__ for better
 reliability; tests updated

---
 .../Libs/Transformer/Classes/SpannedMasker.py | 14 +++++++++-----
 Project_Model/Tests/spanned_masker_test.py    |  6 ++++--
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Classes/SpannedMasker.py b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py
index 441a3d8..4be18be 100644
--- a/Project_Model/Libs/Transformer/Classes/SpannedMasker.py
+++ b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py
@@ -7,9 +7,12 @@ class SpannedMasker:
 
     def __init__(
         self,
+        max_vocabulary: int,
+        forbidden_tokens: set[int],
         change_token_probability: float = 0.15,
         average_span: int = 1,
         seed: int = random.randint(0, sys.maxsize),
+
     ) -> None:
 
         if change_token_probability < 0 or change_token_probability > 1:
@@ -18,17 +21,18 @@ class SpannedMasker:
         self.__change_token_probability = change_token_probability
         self.__average_span = average_span
         self.__rng = random.Random(seed)
+        self.__max_vocabulary = max_vocabulary
+        self.__forbidden_tokens = forbidden_tokens
+
     def mask_sequence(
         self,
         token_sequence: list[int],
-        max_vocabulary: int,
-        forbidden_tokens: set[int]
     ) -> tuple[list[int], list[int]]:
-        MASK = self.__create_mask(token_sequence, forbidden_tokens)
-        MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
-        TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
+        MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
+        MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
+        TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
         return (MASKED, TARGET)
 
diff --git a/Project_Model/Tests/spanned_masker_test.py b/Project_Model/Tests/spanned_masker_test.py
index a29bdfa..a256203 100644
--- a/Project_Model/Tests/spanned_masker_test.py
+++ b/Project_Model/Tests/spanned_masker_test.py
@@ -21,7 +21,7 @@ class TestSpannedMasker:
         TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
         VOCABULARY_SIZE = TOKENIZER.vocabulary_size
 
-        MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
+
         TOKENS = TOKENIZER.encode(TEXT)
 
@@ -31,10 +31,12 @@ class TestSpannedMasker:
 
         ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
 
+        MASKER = Transformer.SpannedMasker(VOCABULARY_SIZE, ILLEGAL_TOKENS, CORRUPTION_PERCENTAGE, 3)
+
         SPECIAL_FORMATTER = TOKENIZER.encode("*")[0]
         END_FORMATTER = TOKENIZER.encode("")[0]
 
-        OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
+        OUTPUT, TARGET = MASKER.mask_sequence(TOKENS)
 
         UNCORRUPTED_TOKENS = list(
             filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)
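
Reviewer note (not part of the patch): a minimal usage sketch of the API after
this change. Only Transformer.SpannedMasker, its parameter order, and
mask_sequence come from the diff; the import path and all concrete values
below are illustrative assumptions, not taken from the repository.

    from Project_Model.Libs import Transformer  # assumed import path, mirroring the test module

    # Hypothetical stand-ins for the tokenizer-derived values used in the test.
    VOCABULARY_SIZE = 1000
    ILLEGAL_TOKENS: set[int] = {997, 998, 999}  # e.g. special tokens that must never be masked
    CORRUPTION_PERCENTAGE = 0.15

    # The vocabulary bound and forbidden tokens are now fixed at construction,
    # so every mask_sequence call uses one consistent configuration.
    masker = Transformer.SpannedMasker(
        VOCABULARY_SIZE,        # max_vocabulary
        ILLEGAL_TOKENS,         # forbidden_tokens
        CORRUPTION_PERCENTAGE,  # change_token_probability
        3,                      # average_span
    )

    tokens = list(range(32))  # placeholder token sequence
    masked, target = masker.mask_sequence(tokens)  # no per-call vocab/forbidden args

Binding this state in __init__ removes the risk of callers passing mismatched
vocabulary sizes or forbidden-token sets across calls, which is the
reliability gain the commit message refers to.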