import math
import random
import sys
from typing import Optional


class SpannedMasker:
    """Apply span corruption (T5-style sentinel masking) to token-id sequences.

    A random subset of the non-forbidden tokens is grouped into spans. In the
    masked input every run of masked tokens is replaced by one sentinel id; in
    the target each sentinel is followed by the tokens it hid. Sentinel ids
    start at ``max_vocabulary + 1`` and increase by one per span, so they never
    collide with regular ids in ``[0, max_vocabulary]``. A sentinel id means
    the same span in the masked input and in the target.
    """

    def __init__(
        self,
        max_vocabulary: int,
        forbidden_tokens: set[int],
        change_token_probability: float = 0.15,
        average_span: int = 1,
        seed: Optional[int] = None,
    ) -> None:
        """Configure the masker.

        Args:
            max_vocabulary: Largest regular token id; sentinel ids are
                allocated starting at ``max_vocabulary + 1``.
            forbidden_tokens: Token ids that are never masked.
            change_token_probability: Fraction of legal (non-forbidden) tokens
                to corrupt, in ``[0, 1]``. Very small positive values make the
                span search slow, because each visited position starts a span
                only with this probability.
            average_span: Mean span length. Span lengths are drawn from a
                normal distribution with this mean and standard deviation 1,
                clamped to at least 1 and rounded.
            seed: Seed for this instance's private RNG. When omitted, a fresh
                seed is drawn from the global ``random`` module.

        Raises:
            ValueError: If ``change_token_probability`` is outside ``[0, 1]``
                or ``average_span`` is not positive.
        """
        if change_token_probability < 0 or change_token_probability > 1:
            raise ValueError("received a value that is not between 0 and 1")
        if average_span <= 0:
            # 0 would divide by zero in __number_of_spans; negatives are meaningless.
            raise ValueError("average_span must be a positive number")
        if seed is None:
            # Drawn per instance. A default expression in the signature would be
            # evaluated once at class definition, so every instance would share a seed.
            seed = random.randint(0, sys.maxsize)

        self.__change_token_probability = change_token_probability
        self.__average_span = average_span
        self.__rng = random.Random(seed)
        self.__max_vocabulary = max_vocabulary
        self.__forbidden_tokens = forbidden_tokens

    def reseed(self, seed: int) -> None:
        """Replace the private RNG with a new one seeded by ``seed``."""
        self.__rng = random.Random(seed)

    def mask_sequence(
        self,
        token_sequence: list[int],
    ) -> tuple[list[int], list[int]]:
        """Corrupt ``token_sequence`` and return ``(masked_input, target)``.

        ``masked_input`` is the sequence with every masked run replaced by one
        sentinel id. ``target`` holds, for each masked run, its sentinel followed
        by the hidden tokens. When the sequence ends unmasked, ``target`` ends
        with one more sentinel.
        """
        mask = self.__create_mask(token_sequence, self.__forbidden_tokens)
        masked = self.__create_masked_input(token_sequence, mask, self.__max_vocabulary)
        target = self.__create_target(token_sequence, mask, self.__max_vocabulary)
        return (masked, target)

    def __number_of_spans(self, legal_token_number: int) -> int:
        """Return how many spans to create for ``legal_token_number`` legal tokens."""
        expected_corrupted = self.__number_of_corrupted_tokens(legal_token_number)
        return math.ceil(expected_corrupted / self.__average_span)

    def __number_of_corrupted_tokens(self, legal_token_number: int) -> int:
        """Return the expected number of corrupted tokens, rounded up."""
        return math.ceil(legal_token_number * self.__change_token_probability)

    def __create_mask(self, sequence: list[int], forbidden_tokens: set[int]) -> list[bool]:
        """Choose which positions of ``sequence`` to corrupt.

        The scan walks the sequence, wrapping around when it reaches the end.
        At each legal, not-yet-masked position it starts a span with
        probability ``change_token_probability``. A span masks up to a randomly
        drawn number of consecutive tokens, stopping early at a forbidden token
        or at the end of the sequence.

        Scanning stops when any of these happens:
        - the target number of spans is reached;
        - a span runs up to the end of the sequence;
        - no legal unmasked position is left.
        """
        seq_len = len(sequence)
        mask = [False] * seq_len
        legal_tokens = self.__count_legal_tokens(sequence, forbidden_tokens)
        target_spans = self.__number_of_spans(legal_tokens)
        if target_spans == 0:
            # Empty sequence, only forbidden tokens, or probability 0: no span can
            # ever be started, and the scan below would never terminate.
            return mask

        index = 0
        spans = 0
        while spans < target_spans:
            # Drawn before the eligibility checks so the RNG is consumed once per
            # visited position, whether or not a span can start there.
            should_mask = self.__random_mask()
            if (
                self.__is_illegal_token(sequence[index], forbidden_tokens)
                or mask[index]
                or not should_mask
            ):
                index = (index + 1) % seq_len
                # Once per full pass, stop if no position can start a span any more
                # (all legal tokens already masked); otherwise this would spin forever.
                if index == 0 and not self.__has_span_start(sequence, mask, forbidden_tokens):
                    break
                continue

            # A span always covers at least its (legal) start token and may run
            # up to and including the final position.
            span_length = min(self.__random_span(self.__average_span), seq_len - index)
            for _ in range(span_length):
                if self.__is_illegal_token(sequence[index], forbidden_tokens):
                    break
                mask[index] = True
                index += 1
            spans += 1
            # Skip one position so the next span does not start directly after this one.
            index += 1
            if index >= seq_len:
                break
        return mask

    def __has_span_start(
        self, sequence: list[int], mask: list[bool], forbidden_tokens: set[int]
    ) -> bool:
        """Return True if some position is legal and not yet masked."""
        return any(
            not masked and not self.__is_illegal_token(token, forbidden_tokens)
            for token, masked in zip(sequence, mask)
        )

    def __create_masked_input(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Copy ``sequence``, replacing each run of masked tokens by one sentinel id."""
        out: list[int] = []
        sentinel = max_voc + 1
        seq_len = len(sequence)
        index = 0
        while index < seq_len:
            if not mask[index]:
                out.append(sequence[index])
                index += 1
                continue
            out.append(sentinel)
            sentinel += 1
            # Bounds-checked: a masked run may extend to the last position.
            while index < seq_len and mask[index]:
                index += 1
        return out

    def __create_target(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Build the target: each masked run's sentinel followed by its hidden tokens.

        Each run of unmasked tokens becomes one sentinel, which introduces the
        masked run that follows it. When the sequence starts with a masked run,
        a leading sentinel is emitted first. This keeps the sentinel ids aligned
        with those produced by ``__create_masked_input``.
        """
        out: list[int] = []
        sentinel = max_voc + 1
        if mask and mask[0]:
            # The first masked run has no preceding unmasked run to carry its sentinel.
            out.append(sentinel)
            sentinel += 1
        seq_len = len(sequence)
        index = 0
        while index < seq_len:
            if mask[index]:
                out.append(sequence[index])
                index += 1
                continue
            out.append(sentinel)
            sentinel += 1
            while index < seq_len and not mask[index]:
                index += 1
        return out

    def __is_illegal_token(self, token: int, illegal_voc: set[int]) -> bool:
        """Return True if ``token`` must never be masked."""
        return token in illegal_voc

    def __count_legal_tokens(self, sequence: list[int], illegal_voc: set[int]) -> int:
        """Count the tokens of ``sequence`` that are not in ``illegal_voc``."""
        return sum(1 for token in sequence if not self.__is_illegal_token(token, illegal_voc))

    def __random_mask(self) -> bool:
        """Return True with probability ``change_token_probability``."""
        return self.__random_probability() <= self.__change_token_probability

    def __random_probability(self) -> float:
        """Return a uniform float in ``[0, 1)`` from the private RNG."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Return a uniform token id in ``[0, max_vocabulary]`` (inclusive)."""
        return self.__rng.randint(0, max_vocabulary)

    def __random_int_range(self, low: int, high: int) -> int:
        """Return a uniform int in ``[low, high]`` (inclusive)."""
        return self.__rng.randint(low, high)

    def __random_span(self, average: int) -> int:
        """Draw a span length: normal(mean=average, sd=1), clamped to >= 1, rounded."""
        # sigma is passed explicitly; its default only exists on Python >= 3.11.
        candidate_span = self.__rng.gauss(mu=average, sigma=1.0)
        candidate_span = max(1, candidate_span)
        return round(candidate_span)