2025-10-06 15:45:45 +02:00
|
|
|
import math
|
|
|
|
|
import random
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpannedMasker:
    """Randomly mask spans of a token sequence and build an input/target pair.

    A boolean mask is drawn over the sequence; tokens listed as forbidden are
    never masked.  The masked input replaces every run of masked tokens with a
    single placeholder id, and the target replaces every run of unmasked
    tokens with a single placeholder id.  Placeholder ids start at
    ``max_vocabulary + 1`` and increase by one for each run replaced.
    """

    def __init__(
        self,
        change_token_probability: float = 0.15,
        average_span: int = 1,
        seed: "int | None" = None,
    ) -> None:
        """Configure the masker.

        Args:
            change_token_probability: Probability, in ``[0, 1]``, used both
                for the per-token "start a span here" draw and to size the
                number of spans.
            average_span: Mean of the Gaussian used to draw span lengths.
            seed: Seed for the private random generator.  ``None`` (the
                default) lets ``random.Random`` seed itself, so each instance
                gets an independent stream.

        Raises:
            ValueError: If ``change_token_probability`` is outside ``[0, 1]``
                or ``average_span`` is not positive.
        """
        if change_token_probability < 0 or change_token_probability > 1:
            raise ValueError("received a value that is not between 0 or 1")

        # A zero span would divide by zero when sizing the number of spans.
        if average_span <= 0:
            raise ValueError("average_span must be greater than 0")

        self.__change_token_probability = change_token_probability
        self.__average_span = average_span
        # The seed is resolved per instance: a randomly drawn default in the
        # signature would be evaluated once and shared by every instance.
        self.__rng = random.Random(seed)

    def mask_sequence(
        self,
        token_sequence: list[int],
        max_vocabulary: int,
        forbidden_tokens: set[int]
    ) -> tuple[list[int], list[int]]:
        """Mask ``token_sequence`` and return ``(masked_input, target)``.

        Args:
            token_sequence: Token ids to corrupt.
            max_vocabulary: Placeholder ids are allocated from
                ``max_vocabulary + 1`` upwards.
            forbidden_tokens: Token ids that must never be masked.

        Returns:
            A tuple ``(masked_input, target)`` of token-id lists.
        """
        mask = self.__create_mask(token_sequence, forbidden_tokens)
        masked_input = self.__create_masked_input(token_sequence, mask, max_vocabulary)
        target = self.__create_target(token_sequence, mask, max_vocabulary)

        return (masked_input, target)

    def __number_of_spans(self, legal_token_number: int) -> int:
        """Return how many spans to draw for ``legal_token_number`` tokens."""
        corrupted = self.__number_of_corrupted_tokens(legal_token_number)
        return math.ceil(corrupted / self.__average_span)

    def __number_of_corrupted_tokens(self, legal_token_number: int) -> int:
        """Return the number of tokens expected to be corrupted (rounded up)."""
        return math.ceil(legal_token_number * self.__change_token_probability)

    def __create_mask(self, sequence: list[int], forbidden_tokens: set[int]) -> list[bool]:
        """Draw a boolean mask over ``sequence``; ``True`` marks a masked token.

        The sequence is scanned cyclically.  At every visited position one
        random draw decides whether a span may start there; forbidden or
        already masked positions are skipped.  A span stops early when it
        reaches a forbidden token and is clipped so that it never covers the
        last position of the sequence.  Scanning ends when the planned number
        of spans is reached, when a span ends at (or next to) the end of the
        sequence, or when no maskable token is left.
        """
        seq_len = len(sequence)
        legal_tokens = self.__count_legal_tokens(sequence, forbidden_tokens)
        planned_spans = self.__number_of_spans(legal_tokens)
        mask = [False] * seq_len

        # Nothing can (or should) be masked: empty input, only forbidden
        # tokens, or a zero probability.  The cyclic scan below would never
        # produce a span in these cases.
        if planned_spans <= 0:
            return mask

        maskable_left = legal_tokens
        mask_index = 0
        spans_done = 0

        # Stop as soon as every legal token is masked; otherwise the cyclic
        # scan would keep skipping forever.
        while maskable_left > 0:
            token = sequence[mask_index]
            already_masked = mask[mask_index]
            # Drawn on every visit, including positions skipped below, so the
            # random stream does not depend on which tokens are skipped.
            should_mask = self.__random_mask()

            if (
                self.__is_illegal_token(token, forbidden_tokens)
                or already_masked
                or not should_mask
            ):
                mask_index = (mask_index + 1) % seq_len
                continue

            candidate_span = self.__random_span(self.__average_span)
            # Clip the span so it never covers the last position.
            remaining = seq_len - (mask_index + 1)
            span_length = min(candidate_span, remaining)

            for _ in range(span_length):
                # A forbidden token ends the span.
                if self.__is_illegal_token(sequence[mask_index], forbidden_tokens):
                    break

                if not mask[mask_index]:
                    mask[mask_index] = True
                    maskable_left -= 1
                mask_index += 1

            spans_done += 1
            # Leave one position between this span and the next candidate.
            mask_index += 1

            if spans_done == planned_spans:
                break

            if mask_index >= seq_len - 1:
                break

        return mask

    def __create_masked_input(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Return ``sequence`` with each masked run replaced by one placeholder.

        Unmasked tokens are copied unchanged.  The first masked run becomes
        ``max_voc + 1``, the second ``max_voc + 2``, and so on.
        """
        out: list[int] = []
        placeholder_id = max_voc + 1
        index = 0

        while index < len(sequence):
            if not mask[index]:
                out.append(sequence[index])
                index += 1
                continue

            out.append(placeholder_id)

            # Skip the rest of the masked run; the bound check covers a run
            # that extends to the end of the mask.
            while index < len(mask) and mask[index]:
                index += 1

            placeholder_id += 1

        return out

    def __create_target(self, sequence: list[int], mask: list[bool], max_voc: int) -> list[int]:
        """Return ``sequence`` with each unmasked run replaced by one placeholder.

        Masked tokens are copied unchanged.  The first unmasked run becomes
        ``max_voc + 1``, the second ``max_voc + 2``, and so on.

        NOTE(review): placeholder ids are counted independently from the
        masked input.  When the sequence starts with a masked token, the id
        that precedes a span here differs from the id that replaced that span
        in the masked input -- confirm this is intended.
        """
        out: list[int] = []
        placeholder_id = max_voc + 1
        index = 0

        while index < len(sequence):
            if mask[index]:
                out.append(sequence[index])
                index += 1
                continue

            out.append(placeholder_id)

            # Skip the rest of the unmasked run.
            while index < len(mask) and not mask[index]:
                index += 1

            placeholder_id += 1

        return out

    def __is_illegal_token(self, token: int, illegal_voc: set[int]) -> bool:
        """Return ``True`` when ``token`` must not be masked."""
        return token in illegal_voc

    def __count_legal_tokens(self, sequence: list[int], illegal_voc: set[int]) -> int:
        """Count the tokens of ``sequence`` that are allowed to be masked."""
        return sum(
            1 for token in sequence
            if not self.__is_illegal_token(token, illegal_voc)
        )

    def __random_mask(self) -> bool:
        """Draw once and return ``True`` when a span may start at this token."""
        return not self.__random_probability() > self.__change_token_probability

    def __random_probability(self) -> float:
        """Return a random float in ``[0.0, 1.0)`` from the private generator."""
        return self.__rng.random()

    def __random_token(self, max_vocabulary: int) -> int:
        """Return a random token id in ``[0, max_vocabulary]`` (inclusive)."""
        return self.__rng.randint(0, max_vocabulary)

    def __random_int_range(self, min: int, max: int) -> int:
        """Return a random integer in ``[min, max]`` (inclusive)."""
        return self.__rng.randint(min, max)

    def __random_span(self, average: int) -> int:
        """Draw a span length of at least 1 from a Gaussian centred on ``average``.

        The standard deviation is passed explicitly as 1.0: ``random.gauss``
        only gained default arguments in Python 3.11.
        """
        candidate_span = self.__rng.gauss(average, 1.0)
        candidate_span = max(1, candidate_span)
        return round(candidate_span)
|