Added 2 types of masking
This commit is contained in:
parent
49f0beb6ea
commit
c217f5dec9
204
Project_Model/Libs/Transformer/Classes/SpannedMasker.py
Normal file
204
Project_Model/Libs/Transformer/Classes/SpannedMasker.py
Normal file
@ -0,0 +1,204 @@
|
||||
import math
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
class SpannedMasker:
    """Build span-corrupted (masked input, target) pairs from a token sequence.

    A boolean mask is drawn over the sequence in contiguous spans. The masked
    input keeps the unmasked tokens and replaces every masked run with one
    sentinel id; the target keeps the masked tokens and replaces every
    unmasked run with one sentinel id. Sentinel ids start at
    ``max_vocabulary + 1`` and grow by one per replaced run.
    """

    def __init__(
        self,
        change_token_probability: float = 0.15,
        average_span: int = 1,
        seed: int = random.randint(0, sys.maxsize),
    ) -> None:
        """Configure the masker.

        Args:
            change_token_probability: Fraction of legal tokens expected to be
                corrupted; also the per-position chance of starting a span.
            average_span: Mean of the Gaussian used to draw span lengths.
            seed: Seed of the private random generator.
                NOTE: the default is computed once, when the class is
                defined, so every instance built without an explicit seed in
                the same process shares that seed.

        Raises:
            ValueError: If ``change_token_probability`` is outside [0, 1] or
                ``average_span`` is not positive.
        """
        if change_token_probability < 0 or change_token_probability > 1:
            raise ValueError("received a value that is not between 0 or 1")

        # A zero span would later divide by zero in __number_of_spans.
        if average_span <= 0:
            raise ValueError("average_span must be greater than 0")

        self.__change_token_probability = change_token_probability
        self.__average_span = average_span
        self.__rng = random.Random(seed)

    def mask_sequence(
        self,
        token_sequence: list[int],
        max_vocabulary: int,
        forbidden_tokens: set[int]
    ) -> tuple[list[int], list[int]]:
        """Corrupt ``token_sequence`` and return ``(masked_input, target)``.

        Args:
            token_sequence: Token ids to corrupt.
            max_vocabulary: Highest regular token id; sentinel ids are
                allocated from ``max_vocabulary + 1`` upwards.
            forbidden_tokens: Token ids that can never start a masked span
                and are not counted when sizing the corruption budget.

        Returns:
            The masked input and the matching target, in that order. When no
            span can be placed (empty sequence, no legal token, probability
            0) the input is returned unchanged and the target holds a single
            sentinel (or nothing for an empty sequence).
        """
        MASK = self.__create_mask(token_sequence, forbidden_tokens)
        MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
        TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)

        return (MASKED, TARGET)

    def __number_of_spans(self, legal_token_number: int) -> int:
        """Return how many spans are needed to cover the corruption budget."""
        EXPECTED_NUM_OF_CORRUPTED_TOKENS = self.__number_of_corrupted_tokens(legal_token_number)

        return math.ceil(EXPECTED_NUM_OF_CORRUPTED_TOKENS / self.__average_span)

    def __number_of_corrupted_tokens(self, legal_token_number: int) -> int:
        """Return the expected number of corrupted tokens, rounded up."""
        return math.ceil(legal_token_number * self.__change_token_probability)

    def __create_mask(self, sequence: list[int], forbidden_tokens: set[int]) -> list[bool]:
        """Draw the boolean mask (``True`` = corrupted) for ``sequence``.

        The sequence is scanned cyclically. At each position one random draw
        decides whether a span may start there; forbidden or already masked
        positions never start a span. The scan stops when the span quota is
        reached, when a span ends at (or next to) the end of the sequence, or
        when no legal unmasked token is left.
        """
        SEQ_LEN = len(sequence)
        LEGAL_TOKENS = self.__count_legal_tokens(sequence, forbidden_tokens)
        NUM_OF_SPANS = self.__number_of_spans(LEGAL_TOKENS)
        MASK = [False] * SEQ_LEN

        # Nothing to corrupt: empty sequence, no legal token or probability 0.
        # Without this guard the scan below could never place a span and
        # would spin forever (or index into an empty list).
        if NUM_OF_SPANS <= 0:
            return MASK

        mask_index = 0
        number_of_spans = 0
        # Legal tokens that are still unmasked; once this hits zero no span
        # can ever start again, so the scan must stop.
        eligible_tokens = LEGAL_TOKENS

        while eligible_tokens > 0:

            TOKEN = sequence[mask_index]
            MASKED = MASK[mask_index]
            # Drawn for every visited position, even skipped ones, so the
            # random stream does not depend on the skip reason.
            SHOULD_MASK = self.__random_mask()

            if (
                self.__is_illegal_token(TOKEN, forbidden_tokens)
                or MASKED
                or not SHOULD_MASK
            ):
                mask_index = (mask_index + 1) % SEQ_LEN
                continue

            CANDIDATE_SPAN = self.__random_span(self.__average_span)

            # NOTE: the clamp leaves the final position out, so the last
            # token of the sequence is never masked (a span starting there
            # has length 0 but still counts towards the quota).
            REMAINING_MASK = SEQ_LEN - (mask_index + 1)
            SPAN_LENGTH = min(CANDIDATE_SPAN, REMAINING_MASK)

            # NOTE(review): only the first token of a span is checked against
            # forbidden_tokens; forbidden tokens inside the span get masked.
            for _ in range(SPAN_LENGTH):
                if not MASK[mask_index] and not self.__is_illegal_token(
                    sequence[mask_index], forbidden_tokens
                ):
                    eligible_tokens -= 1
                MASK[mask_index] = True
                mask_index += 1

            number_of_spans += 1
            # Leave a one-token gap after the span.
            mask_index += 1

            if number_of_spans == NUM_OF_SPANS:
                break

            if mask_index >= SEQ_LEN - 1:
                break

        return MASK

    def __create_masked_input(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Return ``sequence`` with every masked run collapsed to one sentinel."""
        OUT: list[int] = []
        SEQ_LEN = len(sequence)
        mask_token_id = max_voc + 1
        index = 0

        while index < SEQ_LEN:

            if not mask[index]:
                OUT.append(sequence[index])
                index += 1
                continue

            OUT.append(mask_token_id)

            # Skip the rest of the masked run; the bounds check keeps a run
            # that reaches the end of the sequence from indexing past it.
            while index < SEQ_LEN and mask[index]:
                index += 1

            mask_token_id += 1

        return OUT

    def __create_target(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Return the masked tokens, with every unmasked run collapsed to one sentinel."""
        OUT: list[int] = []
        SEQ_LEN = len(sequence)
        mask_token_id = max_voc + 1
        index = 0

        while index < SEQ_LEN:

            if mask[index]:
                OUT.append(sequence[index])
                index += 1
                continue

            OUT.append(mask_token_id)

            # Skip the rest of the unmasked run.
            while index < SEQ_LEN and not mask[index]:
                index += 1

            mask_token_id += 1

        return OUT

    def __is_illegal_token(self, token: int, illegal_voc: set[int]) -> bool:
        """Return ``True`` when ``token`` belongs to ``illegal_voc``."""
        return token in illegal_voc

    def __count_legal_tokens(self, sequence: list[int], illegal_voc: set[int]) -> int:
        """Count the tokens of ``sequence`` that are not in ``illegal_voc``."""
        return sum(
            1 for token in sequence
            if not self.__is_illegal_token(token, illegal_voc)
        )

    def __random_mask(self) -> bool:
        """Return ``True`` with probability ``change_token_probability``."""
        return not self.__random_probability() > self.__change_token_probability

    def __random_probability(self) -> float:
        """Draw a float in [0, 1) from the private generator."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Draw a token id in ``[0, max_vocabulary]`` (both ends included)."""
        return self.__rng.randint(0, max_vocabulary)

    def __random_int_range(self, min: int, max: int) -> int:
        """Draw an integer in ``[min, max]`` (both ends included)."""
        return self.__rng.randint(min, max)

    def __random_span(self, average: int) -> int:
        """Draw a span length of at least 1 from a Gaussian centred on ``average``.

        The standard deviation is passed explicitly (1.0, the library
        default) because ``sigma`` is only optional from Python 3.11.
        """
        candidate_span = self.__rng.gauss(average, 1.0)
        candidate_span = max(1, candidate_span)

        return round(candidate_span)
|
||||
77
Project_Model/Libs/Transformer/Classes/TokenMasker.py
Normal file
77
Project_Model/Libs/Transformer/Classes/TokenMasker.py
Normal file
@ -0,0 +1,77 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
class TokenMasker:
    """Token-level masker: each in-vocabulary token is independently kept,
    replaced by a mask id, or replaced by a random vocabulary token.
    """

    def __init__(
        self,
        change_token_probability: float = 0.15,
        mask_token_probability: float = 0.8,
        random_token_prob: float = 0.1,
        seed: int = random.randint(0, sys.maxsize),
    ) -> None:
        """Store the probabilities and seed the private random generator.

        Args:
            change_token_probability: Chance that a token is selected for
                corruption at all.
            mask_token_probability: Among selected tokens, the share that
                becomes the mask id.
            random_token_prob: Among selected tokens, the share that becomes
                a random vocabulary token; the rest stay unchanged.
            seed: Seed of the private generator. The default is computed
                once, when the class is defined.

        Raises:
            ValueError: If any probability is outside [0, 1], or if
                ``mask_token_probability + random_token_prob`` exceeds 1.
        """
        # Checked in declaration order so the first offending value raises.
        for probability in (
            change_token_probability,
            mask_token_probability,
            random_token_prob,
        ):
            if probability < 0 or probability > 1:
                raise ValueError("received a value that is not between 0 or 1")

        if mask_token_probability + random_token_prob > 1:
            raise ValueError("The sum of probabilities is over 1")

        self.__rng = random.Random(seed)
        self.__change_token_probability = change_token_probability
        self.__mask_token_probability = mask_token_probability
        self.__random_token_prob = random_token_prob

    def mask_sequence(
        self, token_sequence: list[int], max_vocabulary: int, mask_id: int
    ) -> list[int]:
        """Return a corrupted copy of ``token_sequence``.

        Tokens above ``max_vocabulary`` are copied through untouched; every
        other token goes through the random corruption step.

        Raises:
            ValueError: If ``mask_id`` is not above ``max_vocabulary``.
        """
        if mask_id <= max_vocabulary:
            raise ValueError("mask_id is a value of vocabulary")

        return [
            token
            if token > max_vocabulary
            else self.__mask(token, max_vocabulary, mask_id)
            for token in token_sequence
        ]

    def __mask(self, token: int, max_vocabulary: int, mask_id: int) -> int:
        """Corrupt a single token according to the configured probabilities."""
        # First draw: is this token selected for corruption?
        if self.__random_probability() > self.__change_token_probability:
            return token

        # Second draw: which corruption applies to the selected token.
        roll = self.__random_probability()
        mask_limit = self.__mask_token_probability
        random_limit = mask_limit + self.__random_token_prob

        # Above both limits: the token is left as it is.
        if roll > random_limit:
            return token

        # Between the limits: random token; at or below the mask limit: mask.
        return self.__random_token(max_vocabulary) if roll > mask_limit else mask_id

    def __random_probability(self) -> float:
        """Draw a float in [0, 1) from the private generator."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Draw a token id in ``[0, max_vocabulary]`` (both ends included)."""
        return self.__rng.randint(0, max_vocabulary)
|
||||
Loading…
x
Reference in New Issue
Block a user