Added 2 types of masking
This commit is contained in:
77
Project_Model/Libs/Transformer/Classes/TokenMasker.py
Normal file
77
Project_Model/Libs/Transformer/Classes/TokenMasker.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
class TokenMasker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
change_token_probability: float = 0.15,
|
||||
mask_token_probability: float = 0.8,
|
||||
random_token_prob: float = 0.1,
|
||||
seed: int = random.randint(0, sys.maxsize),
|
||||
) -> None:
|
||||
|
||||
if change_token_probability < 0 or change_token_probability > 1:
|
||||
raise ValueError("received a value that is not between 0 or 1")
|
||||
|
||||
if mask_token_probability < 0 or mask_token_probability > 1:
|
||||
raise ValueError("received a value that is not between 0 or 1")
|
||||
|
||||
if random_token_prob < 0 or random_token_prob > 1:
|
||||
raise ValueError("received a value that is not between 0 or 1")
|
||||
|
||||
if mask_token_probability + random_token_prob > 1:
|
||||
raise ValueError("The sum of probabilities is over 1")
|
||||
|
||||
self.__change_token_probability = change_token_probability
|
||||
self.__mask_token_probability = mask_token_probability
|
||||
self.__random_token_prob = random_token_prob
|
||||
self.__rng = random.Random(seed)
|
||||
|
||||
def mask_sequence(
|
||||
self, token_sequence: list[int], max_vocabulary: int, mask_id: int
|
||||
) -> list[int]:
|
||||
|
||||
if mask_id <= max_vocabulary:
|
||||
raise ValueError("mask_id is a value of vocabulary")
|
||||
|
||||
MASKED_SEQUENCE: list[int] = []
|
||||
|
||||
for token in token_sequence:
|
||||
|
||||
if token > max_vocabulary:
|
||||
MASKED_SEQUENCE.append(token)
|
||||
continue
|
||||
|
||||
MASKED_TOKEN = self.__mask(token, max_vocabulary, mask_id)
|
||||
MASKED_SEQUENCE.append(MASKED_TOKEN)
|
||||
|
||||
return MASKED_SEQUENCE
|
||||
|
||||
def __mask(self, token: int, max_vocabulary: int, mask_id: int) -> int:
|
||||
|
||||
if self.__random_probability() > self.__change_token_probability:
|
||||
return token
|
||||
|
||||
MASK_TOKEN_TRESH = self.__mask_token_probability
|
||||
RANDOM_TOKEN_TRESH = MASK_TOKEN_TRESH + self.__random_token_prob
|
||||
CHANCE_PROBABILITY = self.__random_probability()
|
||||
|
||||
# It's over both probabilities, return same token
|
||||
if CHANCE_PROBABILITY > RANDOM_TOKEN_TRESH:
|
||||
return token
|
||||
|
||||
# It's over masking treshold, but lower than random
|
||||
# return random token
|
||||
if CHANCE_PROBABILITY > MASK_TOKEN_TRESH:
|
||||
return self.__random_token(max_vocabulary)
|
||||
|
||||
# It's below masking treshold, mask token
|
||||
return mask_id
|
||||
|
||||
def __random_probability(self) -> float:
|
||||
return self.__rng.random()
|
||||
|
||||
def __random_token(self, max_vocabulary: int) -> int:
|
||||
return self.__rng.randint(0, max_vocabulary)
|
||||
Reference in New Issue
Block a user