moved SpannedMasker's mask variables into __init__ for better reliability, also updated tests

GassiGiuseppe 2025-10-07 23:15:50 +02:00
parent 96cbf4eabb
commit fc44929a7b
2 changed files with 13 additions and 7 deletions


@@ -7,9 +7,12 @@ class SpannedMasker:
     def __init__(
         self,
+        max_vocabulary: int,
+        forbidden_tokens: set[int],
         change_token_probability: float = 0.15,
         average_span: int = 1,
         seed: int = random.randint(0, sys.maxsize),
     ) -> None:
         if change_token_probability < 0 or change_token_probability > 1:
@@ -18,17 +21,18 @@ class SpannedMasker:
         self.__change_token_probability = change_token_probability
         self.__average_span = average_span
         self.__rng = random.Random(seed)
+        self.__max_vocabulary = max_vocabulary
+        self.__forbidden_tokens = forbidden_tokens

     def mask_sequence(
         self,
         token_sequence: list[int],
-        max_vocabulary: int,
-        forbidden_tokens: set[int]
     ) -> tuple[list[int], list[int]]:
-        MASK = self.__create_mask(token_sequence, forbidden_tokens)
-        MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
-        TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
+        MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
+        MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
+        TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
         return (MASKED, TARGET)
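
After this change the masker owns the vocabulary bound and the forbidden-token set, so mask_sequence only needs the token sequence itself. A minimal usage sketch of the new API follows; the import path, vocabulary size, token ids, and sequence are illustrative assumptions, not values from the repository.

    # Hedged sketch of the post-commit API; import path and all values are assumed.
    from Transformer import SpannedMasker  # hypothetical import path

    masker = SpannedMasker(
        max_vocabulary=32_000,       # assumed vocabulary size
        forbidden_tokens={0, 1, 2},  # assumed special-token ids that must never be masked
        change_token_probability=0.15,
        average_span=3,
    )

    tokens = [17, 254, 9, 4096, 88]                # hypothetical encoded sequence
    masked, target = masker.mask_sequence(tokens)  # state now lives on the instance

Binding these values once per instance avoids passing mismatched vocabulary/forbidden-token pairs on different calls, which appears to be the reliability gain the commit message refers to.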


@@ -21,7 +21,7 @@ class TestSpannedMasker:
     TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
     VOCABULARY_SIZE = TOKENIZER.vocabulary_size
-    MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
     TOKENS = TOKENIZER.encode(TEXT)
@@ -31,10 +31,12 @@ class TestSpannedMasker:
     ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
+    MASKER = Transformer.SpannedMasker(VOCABULARY_SIZE, ILLEGAL_TOKENS, CORRUPTION_PERCENTAGE, 3)
     SPECIAL_FORMATTER = TOKENIZER.encode("*<SOT>")[0]
     END_FORMATTER = TOKENIZER.encode("<EOT>")[0]
-    OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
+    OUTPUT, TARGET = MASKER.mask_sequence(TOKENS)
     UNCORRUPTED_TOKENS = list(
         filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)