moved spanned-masker variables into __init__ for better reliability; also updated the tests
parent 96cbf4eabb
commit fc44929a7b
@@ -7,9 +7,12 @@ class SpannedMasker:
     def __init__(
         self,
+        max_vocabulary: int,
+        forbidden_tokens: set[int],
         change_token_probability: float = 0.15,
         average_span: int = 1,
         seed: int = random.randint(0, sys.maxsize),
     ) -> None:
 
         if change_token_probability < 0 or change_token_probability > 1:
@@ -18,17 +21,18 @@ class SpannedMasker:
         self.__change_token_probability = change_token_probability
         self.__average_span = average_span
         self.__rng = random.Random(seed)
+        self.__max_vocabulary = max_vocabulary
+        self.__forbidden_tokens = forbidden_tokens
 
     def mask_sequence(
         self,
         token_sequence: list[int],
-        max_vocabulary: int,
-        forbidden_tokens: set[int]
     ) -> tuple[list[int], list[int]]:
 
-        MASK = self.__create_mask(token_sequence, forbidden_tokens)
-        MASKED = self.__create_masked_input(token_sequence, MASK, max_vocabulary)
-        TARGET = self.__create_target(token_sequence, MASK, max_vocabulary)
+        MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
+        MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
+        TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
 
         return (MASKED, TARGET)
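Pulled together, the refactored class now reads roughly as below. Only the signatures, the probability check, and the five assignments are attested by the hunks above; the imports are implied by the seed default. Everything else, the ValueError, the span-drawing logic, and the single sentinel id max_vocabulary + 1, is an illustrative assumption about the __create_* helpers, whose bodies are not part of this commit.

import random
import sys


class SpannedMasker:
    """Masks random spans of a token sequence for denoising-style training."""

    def __init__(
        self,
        max_vocabulary: int,
        forbidden_tokens: set[int],
        change_token_probability: float = 0.15,
        average_span: int = 1,
        # Note: this default is evaluated once, at class-definition time, so
        # every default-constructed instance shares the same seed.
        seed: int = random.randint(0, sys.maxsize),
    ) -> None:
        if change_token_probability < 0 or change_token_probability > 1:
            raise ValueError("change_token_probability must lie in [0, 1]")
        self.__change_token_probability = change_token_probability
        self.__average_span = average_span
        self.__rng = random.Random(seed)
        self.__max_vocabulary = max_vocabulary
        self.__forbidden_tokens = forbidden_tokens

    def mask_sequence(
        self,
        token_sequence: list[int],
    ) -> tuple[list[int], list[int]]:
        # The stored vocabulary size and forbidden set now feed the helpers
        # in place of the old per-call arguments.
        MASK = self.__create_mask(token_sequence, self.__forbidden_tokens)
        MASKED = self.__create_masked_input(token_sequence, MASK, self.__max_vocabulary)
        TARGET = self.__create_target(token_sequence, MASK, self.__max_vocabulary)
        return (MASKED, TARGET)

    # The helpers below are NOT shown in the diff; placeholder logic only.

    def __create_mask(self, tokens: list[int], forbidden: set[int]) -> list[bool]:
        # Start a masked span at each eligible position with the configured
        # probability; span lengths are drawn uniformly around average_span.
        mask = [False] * len(tokens)
        i = 0
        while i < len(tokens):
            eligible = tokens[i] not in forbidden
            if eligible and self.__rng.random() < self.__change_token_probability:
                span = 1 + self.__rng.randrange(2 * self.__average_span - 1)
                for j in range(i, min(i + span, len(tokens))):
                    if tokens[j] in forbidden:
                        break  # never corrupt a forbidden (special) token
                    mask[j] = True
                i += span
            else:
                i += 1
        return mask

    def __create_masked_input(self, tokens, mask, max_vocabulary):
        # Masked positions are replaced by a sentinel id above the vocabulary.
        return [max_vocabulary + 1 if m else t for t, m in zip(tokens, mask)]

    def __create_target(self, tokens, mask, max_vocabulary):
        # The target keeps originals at masked positions, sentinel elsewhere.
        return [t if m else max_vocabulary + 1 for t, m in zip(tokens, mask)]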
@@ -21,7 +21,7 @@ class TestSpannedMasker:
     TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
     VOCABULARY_SIZE = TOKENIZER.vocabulary_size
 
-    MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
+
 
     TOKENS = TOKENIZER.encode(TEXT)
 
@@ -31,10 +31,12 @@ class TestSpannedMasker:
 
     ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
 
+    MASKER = Transformer.SpannedMasker(VOCABULARY_SIZE,ILLEGAL_TOKENS,CORRUPTION_PERCENTAGE, 3)
+
     SPECIAL_FORMATTER = TOKENIZER.encode("*<SOT>")[0]
     END_FORMATTER = TOKENIZER.encode("<EOT>")[0]
 
-    OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
+    OUTPUT, TARGET = MASKER.mask_sequence(TOKENS)
 
     UNCORRUPTED_TOKENS = list(
         filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)
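The payoff at the call site is what the test change shows: configuration happens once, at construction, and every subsequent call passes only the sequence, so a caller can no longer hand mask_sequence a vocabulary size or forbidden set that disagrees with the masker's own. A minimal usage sketch against the class reconstructed above (standing in for the Transformer module import); the ids and sizes here are invented stand-ins for the test fixtures:

# Toy stand-ins for the test fixtures; values are illustrative only.
VOCABULARY_SIZE = 1000
ILLEGAL_TOKENS = {998, 999}  # e.g. the ids of specials like <SOT>/<EOT>
CORRUPTION_PERCENTAGE = 0.15
TOKENS = [5, 17, 998, 42, 7, 999]

# Configure once: vocabulary size and forbidden tokens bound at construction.
masker = SpannedMasker(VOCABULARY_SIZE, ILLEGAL_TOKENS, CORRUPTION_PERCENTAGE, 3)

# Call sites now pass only the token sequence.
output, target = masker.mask_sequence(TOKENS)

# Forbidden tokens are never corrupted, so they survive unchanged, while any
# corrupted position carries a sentinel id above VOCABULARY_SIZE, which is
# exactly what the test's `token <= VOCABULARY_SIZE` filter relies on.
assert output[2] == 998 and output[5] == 999
uncorrupted = [t for t in output if t <= VOCABULARY_SIZE]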