78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
|
|
import random
|
||
|
|
import sys
|
||
|
|
|
||
|
|
|
||
|
|
class TokenMasker:
    """Randomly corrupt token-id sequences for masked-token training.

    Each in-vocabulary token is selected for corruption with probability
    ``change_token_probability``. A selected token is then replaced by
    ``mask_id`` with probability ``mask_token_probability``, by a uniformly
    random id in ``[0, max_vocabulary]`` (inclusive, possibly the same id)
    with probability ``random_token_prob``, and otherwise kept unchanged.
    The defaults (0.15 / 0.8 / 0.1) mirror the scheme popularised by BERT.

    All randomness comes from a private ``random.Random`` instance, so two
    maskers built with the same ``seed`` produce identical output for the
    same sequence of calls.
    """

    def __init__(
        self,
        change_token_probability: float = 0.15,
        mask_token_probability: float = 0.8,
        random_token_prob: float = 0.1,
        seed: "int | None" = None,
    ) -> None:
        """Validate the probabilities and create the random generator.

        Args:
            change_token_probability: Chance that an in-vocabulary token is
                selected for corruption. Must lie in [0, 1].
            mask_token_probability: Chance that a selected token becomes
                ``mask_id``. Must lie in [0, 1].
            random_token_prob: Chance that a selected token becomes a random
                vocabulary id. Must lie in [0, 1].
            seed: Seed for the internal generator. ``None`` (the default)
                seeds it from fresh system randomness, giving each instance
                an independent stream.

        Raises:
            ValueError: If any probability is outside [0, 1] (or is NaN), or
                if ``mask_token_probability + random_token_prob`` exceeds 1.
        """
        # The chained comparison also rejects NaN, which `p < 0 or p > 1`
        # would silently accept.
        for probability in (
            change_token_probability,
            mask_token_probability,
            random_token_prob,
        ):
            if not 0 <= probability <= 1:
                raise ValueError("received a value that is not between 0 or 1")

        if mask_token_probability + random_token_prob > 1:
            raise ValueError("The sum of probabilities is over 1")

        self.__change_token_probability = change_token_probability
        self.__mask_token_probability = mask_token_probability
        self.__random_token_prob = random_token_prob
        # `seed` deliberately defaults to None rather than a random int in the
        # signature: a default expression is evaluated only once, when the
        # method is defined, so every unseeded instance would share one seed.
        self.__rng = random.Random(seed)

    def mask_sequence(
        self, token_sequence: list[int], max_vocabulary: int, mask_id: int
    ) -> list[int]:
        """Return a corrupted copy of ``token_sequence``.

        Tokens greater than ``max_vocabulary`` are copied through untouched
        and consume no randomness. Every other token goes through the
        selection / mask / random-replace procedure described on the class.

        Args:
            token_sequence: Token ids to corrupt. The input is not modified.
            max_vocabulary: Largest vocabulary id. Random replacements are
                drawn from ``[0, max_vocabulary]`` inclusive.
            mask_id: Id substituted for masked tokens. Must be greater than
                ``max_vocabulary`` so it cannot collide with a real token.

        Returns:
            A new list with the same length as ``token_sequence``.

        Raises:
            ValueError: If ``mask_id <= max_vocabulary``.
        """
        if mask_id <= max_vocabulary:
            raise ValueError("mask_id is a value of vocabulary")

        return [
            token
            if token > max_vocabulary
            else self.__mask(token, max_vocabulary, mask_id)
            for token in token_sequence
        ]

    def __mask(self, token: int, max_vocabulary: int, mask_id: int) -> int:
        """Return ``token``, ``mask_id``, or a random vocabulary id."""
        # random() draws from the half-open range [0, 1). Comparing with a
        # strict `<` therefore makes each probability exact: 0 never fires
        # and 1 always does.
        if self.__random_probability() >= self.__change_token_probability:
            return token

        mask_threshold = self.__mask_token_probability
        random_threshold = mask_threshold + self.__random_token_prob
        chance = self.__random_probability()

        # Below the masking threshold: replace with the mask id.
        if chance < mask_threshold:
            return mask_id

        # Between the two thresholds: replace with a random vocabulary id.
        if chance < random_threshold:
            return self.__random_token(max_vocabulary)

        # Selected, but neither masked nor randomised: keep the original.
        return token

    def __random_probability(self) -> float:
        """Draw a float in [0, 1) from the instance generator."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Draw a token id uniformly from [0, max_vocabulary] inclusive."""
        return self.__rng.randint(0, max_vocabulary)
|