diff --git a/Project_Model/Tests/spanned_masker_test.py b/Project_Model/Tests/spanned_masker_test.py index e78d04e..3b70491 100644 --- a/Project_Model/Tests/spanned_masker_test.py +++ b/Project_Model/Tests/spanned_masker_test.py @@ -8,30 +8,26 @@ VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json") VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH) SPECIAL_LIST = BPE.default_special_tokens() + class TestSpannedMasker: def test_spanned_masking(self): CORPUS_PATH = Path("Project_Model/Tests/spanner_file/mask.txt") TEXT = CORPUS_PATH.read_text("utf-8") + CORRUPTION_PERCENTAGE = 0.15 + TOLERANCE = 0.05 - TOKENIZER = BPE.TokeNanoCore( - VOCABULARY, - SPECIAL_LIST - ) + TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST) VOCABULARY_SIZE = TOKENIZER.vocabulary_size - MASKER = Transformer.SpannedMasker(0.4,average_span=3) + MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3) TOKENS = TOKENIZER.encode(TEXT) - LEGAL_TOKENS: set[int] = set(TOKENIZER.encode( - "" - )) + LEGAL_TOKENS: set[int] = set(TOKENIZER.encode("")) - SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode( - "".join(SPECIAL_LIST) - )) + SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode("".join(SPECIAL_LIST))) ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS) @@ -40,34 +36,52 @@ class TestSpannedMasker: OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS) - UNCORRUPTED_TOKENS = list(filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)) + UNCORRUPTED_TOKENS = list( + filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT) + ) CORRUPTED_TOKENS = list(filter(lambda token: token <= VOCABULARY_SIZE, TARGET)) TARGET.append(END_FORMATTER) - OUTPUT = list(map(lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, OUTPUT)) - TARGET = list(map(lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, TARGET)) + OUTPUT = list( + map( + lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, + OUTPUT, + ) + ) + TARGET = list( + map( + lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, + TARGET, + ) + ) OUT_TEXT = TOKENIZER.decode(OUTPUT) TAR_TEXT = TOKENIZER.decode(TARGET) + ACTUAL_CORRUPTION_PERCENTAGE = len(CORRUPTED_TOKENS) / len(TOKENS) + print(f"Original text:\n\n{TEXT}") print(f"Inputs:\n\n{OUT_TEXT}") print(f"Targets:\n\n{TAR_TEXT}") + print(f"Target Tokens:\n\n{OUTPUT}") - print("\n".join([ - f"======================", - f"Original length: {len(TOKENS)}", - f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}", - f"Corrupted Chars: {len(CORRUPTED_TOKENS)}", - f"Percentage_corruption: {(len(CORRUPTED_TOKENS)/len(TOKENS))*100}%", - f"======================" - ])) + print( + "\n".join( + [ + f"======================", + f"Original length: {len(TOKENS)}", + f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}", + f"Corrupted Chars: {len(CORRUPTED_TOKENS)}", + f"Percentage_corruption: {(len(CORRUPTED_TOKENS)/len(TOKENS))*100}%", + f"======================", + ] + ) + ) + assert ACTUAL_CORRUPTION_PERCENTAGE > CORRUPTION_PERCENTAGE - TOLERANCE + assert ACTUAL_CORRUPTION_PERCENTAGE < CORRUPTION_PERCENTAGE + TOLERANCE if __name__ == "__main__": TestSpannedMasker().test_spanned_masking() - - -