import random
import sys
from typing import Optional


class TokenMasker:
    """Randomly replace tokens of a sequence with a mask id or a random token.

    Each token that is eligible for masking (its value is not greater than
    ``max_vocabulary``) is selected for modification with probability
    ``change_token_probability``.  A selected token is then:

    * replaced by ``mask_id`` with probability ``mask_token_probability``;
    * replaced by a random vocabulary token with probability
      ``random_token_prob``;
    * left unchanged with the remaining probability.
    """

    def __init__(
        self,
        change_token_probability: float = 0.15,
        mask_token_probability: float = 0.8,
        random_token_prob: float = 0.1,
        seed: Optional[int] = None,
    ) -> None:
        """Validate the probabilities and create the private random generator.

        Args:
            change_token_probability: Chance that an eligible token is
                selected for modification.
            mask_token_probability: Chance that a selected token becomes
                the mask id.
            random_token_prob: Chance that a selected token becomes a
                random vocabulary token.
            seed: Seed for the private ``random.Random`` instance.  When
                omitted, a fresh seed is drawn for every instance.

        Raises:
            ValueError: If any probability is outside ``[0, 1]`` (NaN
                included) or if ``mask_token_probability`` plus
                ``random_token_prob`` exceeds 1.
        """
        probabilities = (
            change_token_probability,
            mask_token_probability,
            random_token_prob,
        )
        for probability in probabilities:
            # The chained comparison is False for NaN, so NaN is rejected too.
            if not 0 <= probability <= 1:
                raise ValueError("received a value that is not between 0 or 1")

        if mask_token_probability + random_token_prob > 1:
            raise ValueError("The sum of probabilities is over 1")

        # A default written in the signature is evaluated only once, when the
        # function is defined, which would make every unseeded instance share
        # the same seed.  Draw it here so each instance gets its own.
        if seed is None:
            seed = random.randint(0, sys.maxsize)

        self.__change_token_probability = change_token_probability
        self.__mask_token_probability = mask_token_probability
        self.__random_token_prob = random_token_prob
        self.__rng = random.Random(seed)

    def mask_sequence(
        self, token_sequence: list[int], max_vocabulary: int, mask_id: int
    ) -> list[int]:
        """Return a new list with the masking procedure applied to each token.

        Tokens greater than ``max_vocabulary`` are copied through untouched
        and consume no random draws.

        Args:
            token_sequence: Token ids to process; the list is not modified.
            max_vocabulary: Highest token id that is eligible for masking.
            mask_id: Id written in place of masked tokens; must be greater
                than ``max_vocabulary``.

        Returns:
            A list of the same length as ``token_sequence``.

        Raises:
            ValueError: If ``mask_id`` is not greater than ``max_vocabulary``.
        """
        if mask_id <= max_vocabulary:
            raise ValueError("mask_id is a value of vocabulary")

        return [
            token
            if token > max_vocabulary
            else self.__mask(token, max_vocabulary, mask_id)
            for token in token_sequence
        ]

    def __mask(self, token: int, max_vocabulary: int, mask_id: int) -> int:
        """Decide the fate of a single eligible token.

        Returns ``token`` itself, ``mask_id``, or a random token in
        ``[0, max_vocabulary]``.
        """
        # random() yields values in [0.0, 1.0), so "draw < p" is true with
        # probability p, and is never true when p == 0.
        if self.__random_probability() >= self.__change_token_probability:
            return token

        mask_threshold = self.__mask_token_probability
        random_threshold = mask_threshold + self.__random_token_prob
        draw = self.__random_probability()

        # Below the masking threshold: mask the token.
        if draw < mask_threshold:
            return mask_id

        # Between the two thresholds: substitute a random vocabulary token.
        if draw < random_threshold:
            return self.__random_token(max_vocabulary)

        # Above both thresholds: keep the original token.
        return token

    def __random_probability(self) -> float:
        """Return the next float in ``[0.0, 1.0)`` from the private generator."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Return a random token id in ``[0, max_vocabulary]`` (inclusive)."""
        return self.__rng.randint(0, max_vocabulary)