diff --git a/Project_Model/Libs/Transformer/Classes/SpannedMasker.py b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py
new file mode 100644
index 0000000..156f512
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py
@@ -0,0 +1,197 @@
+import math
+import random
+
+
+class SpannedMasker:
+    """Produces T5-style span corruption: contiguous spans of tokens are
+    replaced by incrementing sentinel ids placed just above the vocabulary."""
+
+    def __init__(
+        self,
+        change_token_probability: float = 0.15,
+        average_span: int = 1,
+        seed: int | None = None,
+    ) -> None:
+
+        if change_token_probability < 0 or change_token_probability > 1:
+            raise ValueError("change_token_probability must be between 0 and 1")
+
+        self.__change_token_probability = change_token_probability
+        self.__average_span = average_span
+        # random.Random(None) seeds itself from system entropy, so every
+        # instance gets a fresh generator unless an explicit seed is given.
+        self.__rng = random.Random(seed)
+
+    def mask_sequence(
+        self,
+        token_sequence: list[int],
+        max_vocabulary: int,
+        forbidden_tokens: set[int]
+    ) -> tuple[list[int], list[int]]:
+
+        MASK = self.__create_mask(token_sequence, forbidden_tokens)
+        MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
+        TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
+
+        return (MASKED, TARGET)
+
+    def __number_of_spans(self, legal_token_number: int) -> int:
+        EXPECTED_NUM_OF_CORRUPTED_TOKENS = self.__number_of_corrupted_tokens(legal_token_number)
+
+        return math.ceil(EXPECTED_NUM_OF_CORRUPTED_TOKENS / self.__average_span)
+
+    def __number_of_corrupted_tokens(self, legal_token_number: int) -> int:
+        EXPECTED_NUM_OF_CORRUPTED_TOKENS = math.ceil(
+            legal_token_number * self.__change_token_probability
+        )
+
+        return EXPECTED_NUM_OF_CORRUPTED_TOKENS
+
+    def __create_mask(self, sequence: list[int], forbidden_tokens: set[int]) -> list[bool]:
+
+        SEQ_LEN = len(sequence)
+        LEGAL_TOKENS = self.__count_legal_tokens(sequence, forbidden_tokens)
+        NUM_OF_SPANS = self.__number_of_spans(LEGAL_TOKENS)
+        MASK = [False] * SEQ_LEN
+
+        # Nothing to mask: an empty sequence or zero requested spans would
+        # otherwise make the loop below spin forever.
+        if SEQ_LEN == 0 or NUM_OF_SPANS == 0:
+            return MASK
+
+        mask_index = 0
+        number_of_spans = 0
+        exit_loop = False
+
+        while not exit_loop:
+
+            TOKEN = sequence[mask_index]
+            MASKED = MASK[mask_index]
+            SHOULD_MASK = self.__random_mask()
+            skip = False
+
+            if self.__is_illegal_token(TOKEN, forbidden_tokens):
+                skip = True
+
+            if MASKED:
+                skip = True
+
+            if not SHOULD_MASK:
+                skip = True
+
+            if skip:
+                mask_index = (mask_index + 1) % SEQ_LEN
+                continue
+
+            CANDIDATE_SPAN = self.__random_span(
+                self.__average_span
+            )
+
+            # Clamp the span so it cannot run past the end of the sequence.
+            REMAINING_TOKENS = SEQ_LEN - mask_index
+
+            SPAN_LENGTH = min(CANDIDATE_SPAN, REMAINING_TOKENS)
+
+            for _ in range(0, SPAN_LENGTH):
+                MASK[mask_index] = True
+                mask_index += 1
+
+            number_of_spans += 1
+            # Leave at least one unmasked token between consecutive spans.
+            mask_index += 1
+
+            if number_of_spans == NUM_OF_SPANS:
+                exit_loop = True
+                continue
+
+            if mask_index >= SEQ_LEN - 1:
+                exit_loop = True
+                continue
+
+        return MASK
+
+    def __create_masked_input(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
+
+        OUT: list[int] = []
+        mask_token_id = max_voc + 1
+        index = 0
+        while index < len(sequence):
+
+            TOKEN = sequence[index]
+            MASKED = mask[index]
+
+            if not MASKED:
+                OUT.append(TOKEN)
+                index += 1
+                continue
+
+            # Replace the whole masked span with a single sentinel token.
+            MASK_TOKEN = mask_token_id
+            OUT.append(MASK_TOKEN)
+
+            while index < len(mask) and mask[index]:
+                index += 1
+
+            mask_token_id += 1
+
+        return OUT
+
+    def __create_target(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
+
+        OUT: list[int] = []
+        mask_token_id = max_voc + 1
+        index = 0
+        while index < len(sequence):
+
+            TOKEN = sequence[index]
+            MASKED = mask[index]
+
+            if MASKED:
+                OUT.append(TOKEN)
+                index += 1
+                continue
+
+            # Collapse every unmasked run into a single sentinel token,
+            # mirroring __create_masked_input.
+            MASK_TOKEN = mask_token_id
+            OUT.append(MASK_TOKEN)
+
+            while index < len(mask) and not mask[index]:
+                index += 1
+
+            mask_token_id += 1
+
+        return OUT
+
+    def __is_illegal_token(self, token: int, illegal_voc: set[int]) -> bool:
+        return token in illegal_voc
+
+    def __count_legal_tokens(self, sequence: list[int], illegal_voc: set[int]) -> int:
+        legal_count = 0
+
+        for token in sequence:
+            if self.__is_illegal_token(token, illegal_voc):
+                continue
+            legal_count += 1
+
+        return legal_count
+
+    def __random_mask(self) -> bool:
+        return self.__random_probability() <= self.__change_token_probability
+
+    def __random_probability(self) -> float:
+        return self.__rng.random()
+
+    def __random_token(self, max_vocabulary: int) -> int:
+        return self.__rng.randint(0, max_vocabulary)
+
+    def __random_int_range(self, low: int, high: int) -> int:
+        return self.__rng.randint(low, high)
+
+    def __random_span(self, average: int) -> int:
+        # Sample a span length around the requested average (sigma = 1.0)
+        # and clamp it to at least one token.
+        candidate_span = self.__rng.gauss(mu=average, sigma=1.0)
+        candidate_span = max(1, candidate_span)
+        candidate_span = round(candidate_span)
+        return candidate_span
diff --git a/Project_Model/Libs/Transformer/Classes/TokenMasker.py b/Project_Model/Libs/Transformer/Classes/TokenMasker.py
new file mode 100644
index 0000000..9664328
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Classes/TokenMasker.py
@@ -0,0 +1,82 @@
+import random
+
+
+class TokenMasker:
+    """Applies BERT-style masking: each in-vocabulary token is selected for
+    corruption with change_token_probability and is then replaced by the mask
+    id, by a random token, or kept unchanged according to the configured split."""
+
+    def __init__(
+        self,
+        change_token_probability: float = 0.15,
+        mask_token_probability: float = 0.8,
+        random_token_prob: float = 0.1,
+        seed: int | None = None,
+    ) -> None:
+
+        if change_token_probability < 0 or change_token_probability > 1:
+            raise ValueError("change_token_probability must be between 0 and 1")
+
+        if mask_token_probability < 0 or mask_token_probability > 1:
+            raise ValueError("mask_token_probability must be between 0 and 1")
+
+        if random_token_prob < 0 or random_token_prob > 1:
+            raise ValueError("random_token_prob must be between 0 and 1")
+
+        if mask_token_probability + random_token_prob > 1:
+            raise ValueError("mask_token_probability + random_token_prob must not exceed 1")
+
+        self.__change_token_probability = change_token_probability
+        self.__mask_token_probability = mask_token_probability
+        self.__random_token_prob = random_token_prob
+        # random.Random(None) seeds itself from system entropy, so every
+        # instance gets a fresh generator unless an explicit seed is given.
+        self.__rng = random.Random(seed)
+
+    def mask_sequence(
+        self, token_sequence: list[int], max_vocabulary: int, mask_id: int
+    ) -> list[int]:
+
+        if mask_id <= max_vocabulary:
+            raise ValueError("mask_id must lie outside the vocabulary range")
+
+        MASKED_SEQUENCE: list[int] = []
+
+        for token in token_sequence:
+
+            # Tokens outside the vocabulary (e.g. special tokens) are never masked.
+            if token > max_vocabulary:
+                MASKED_SEQUENCE.append(token)
+                continue
+
+            MASKED_TOKEN = self.__mask(token, max_vocabulary, mask_id)
+            MASKED_SEQUENCE.append(MASKED_TOKEN)
+
+        return MASKED_SEQUENCE
+
+    def __mask(self, token: int, max_vocabulary: int, mask_id: int) -> int:
+
+        if self.__random_probability() > self.__change_token_probability:
+            return token
+
+        MASK_TOKEN_THRESHOLD = self.__mask_token_probability
+        RANDOM_TOKEN_THRESHOLD = MASK_TOKEN_THRESHOLD + self.__random_token_prob
+        CHANCE_PROBABILITY = self.__random_probability()
+
+        # Above both thresholds: keep the original token
+        if CHANCE_PROBABILITY > RANDOM_TOKEN_THRESHOLD:
+            return token
+
+        # Above the masking threshold but below the random-token threshold:
+        # return a random in-vocabulary token
+        if CHANCE_PROBABILITY > MASK_TOKEN_THRESHOLD:
+            return self.__random_token(max_vocabulary)
+
+        # Below the masking threshold: replace with the mask token
+        return mask_id
+
float: + return self.__rng.random() + + def __random_token(self, max_vocabulary: int) -> int: + return self.__rng.randint(0, max_vocabulary)