Moved the spanned-mask variables (max_vocabulary, forbidden_tokens) into __init__ for better reliability; also tested.
This commit is contained in:
@@ -7,9 +7,12 @@ class SpannedMasker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_vocabulary: int,
|
||||
forbidden_tokens: set[int],
|
||||
change_token_probability: float = 0.15,
|
||||
average_span: int = 1,
|
||||
seed: int = random.randint(0, sys.maxsize),
|
||||
|
||||
) -> None:
|
||||
|
||||
if change_token_probability < 0 or change_token_probability > 1:
|
||||
@@ -18,17 +21,18 @@ class SpannedMasker:
|
||||
self.__change_token_probability = change_token_probability
|
||||
self.__average_span = average_span
|
||||
self.__rng = random.Random(seed)
|
||||
self.__max_vocabulary = max_vocabulary
|
||||
self.__forbidden_tokens = forbidden_tokens
|
||||
|
||||
|
||||
def mask_sequence(
|
||||
self,
|
||||
token_sequence: list[int],
|
||||
max_vocabulary: int,
|
||||
forbidden_tokens: set[int]
|
||||
) -> tuple[list[int], list[int]]:
|
||||
|
||||
MASK = self.__create_mask(token_sequence, forbidden_tokens)
|
||||
MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
|
||||
TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
|
||||
MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
|
||||
MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
|
||||
TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
|
||||
|
||||
return (MASKED, TARGET)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user