Moved the SpannedMasker parameters `max_vocabulary` and `forbidden_tokens` from `mask_sequence` into `__init__` for better reliability; also updated the tests to match.
This commit is contained in:
parent
96cbf4eabb
commit
fc44929a7b
@ -7,9 +7,12 @@ class SpannedMasker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_vocabulary: int,
|
||||
forbidden_tokens: set[int],
|
||||
change_token_probability: float = 0.15,
|
||||
average_span: int = 1,
|
||||
seed: int = random.randint(0, sys.maxsize),
|
||||
|
||||
) -> None:
|
||||
|
||||
if change_token_probability < 0 or change_token_probability > 1:
|
||||
@ -18,17 +21,18 @@ class SpannedMasker:
|
||||
self.__change_token_probability = change_token_probability
|
||||
self.__average_span = average_span
|
||||
self.__rng = random.Random(seed)
|
||||
self.__max_vocabulary = max_vocabulary
|
||||
self.__forbidden_tokens = forbidden_tokens
|
||||
|
||||
|
||||
def mask_sequence(
|
||||
self,
|
||||
token_sequence: list[int],
|
||||
max_vocabulary: int,
|
||||
forbidden_tokens: set[int]
|
||||
) -> tuple[list[int], list[int]]:
|
||||
|
||||
MASK = self.__create_mask(token_sequence, forbidden_tokens)
|
||||
MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
|
||||
TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
|
||||
MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
|
||||
MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
|
||||
TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
|
||||
|
||||
return (MASKED, TARGET)
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class TestSpannedMasker:
|
||||
TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
|
||||
VOCABULARY_SIZE = TOKENIZER.vocabulary_size
|
||||
|
||||
MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
|
||||
|
||||
|
||||
TOKENS = TOKENIZER.encode(TEXT)
|
||||
|
||||
@ -31,10 +31,12 @@ class TestSpannedMasker:
|
||||
|
||||
ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
|
||||
|
||||
MASKER = Transformer.SpannedMasker(VOCABULARY_SIZE,ILLEGAL_TOKENS,CORRUPTION_PERCENTAGE, 3)
|
||||
|
||||
SPECIAL_FORMATTER = TOKENIZER.encode("*<SOT>")[0]
|
||||
END_FORMATTER = TOKENIZER.encode("<EOT>")[0]
|
||||
|
||||
OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
|
||||
OUTPUT, TARGET = MASKER.mask_sequence(TOKENS)
|
||||
|
||||
UNCORRUPTED_TOKENS = list(
|
||||
filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user