from functools import reduce
from pathlib import Path

import pytest

import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer

# Shared fixtures: the toy BPE vocabulary and the default special-token list.
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
SPECIAL_LIST = BPE.default_special_tokens()


class TestSpannedMasker:
    """Tests for ``Transformer.SpannedMasker`` span corruption."""

    def test_spanned_masking(self):
        """Mask a sample corpus and check the target tokens and corruption ratio.

        The sample text is encoded with ``BPE.TokeNanoCore`` and passed through
        ``SpannedMasker.mask_sequence``. The resulting input and target
        sequences are printed in decoded form, with ids above the vocabulary
        size rendered as ``"*"``. The test then asserts two things:

        * no id from ``ILLEGAL_TOKENS`` appears in the masker's target;
        * the fraction of corrupted tokens lies strictly inside
          ``CORRUPTION_PERCENTAGE +/- TOLERANCE``.
        """
        CORPUS_PATH = Path("Project_Model/Tests/spanner_file/mask.txt")
        TEXT = CORPUS_PATH.read_text("utf-8")
        CORRUPTION_PERCENTAGE = 0.15
        TOLERANCE = 0.15

        TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
        VOCABULARY_SIZE = TOKENIZER.vocabulary_size
        MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
        TOKENS = TOKENIZER.encode(TEXT)

        # NOTE(review): the encode("") arguments below look like special-token
        # strings that were lost (e.g. stripped as markup). If encode("")
        # returns an empty list, LEGAL_TOKENS is empty and the `[0]` index on
        # END_FORMATTER raises IndexError -- confirm the intended token strings.
        LEGAL_TOKENS: set[int] = set(TOKENIZER.encode(""))
        SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode("".join(SPECIAL_LIST)))
        # Special tokens the masker must never place in its target.
        ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)

        SPECIAL_FORMATTER = TOKENIZER.encode("*")[0]
        END_FORMATTER = TOKENIZER.encode("")[0]

        OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)

        def is_sentinel(token: int) -> bool:
            # Ids strictly greater than VOCABULARY_SIZE are treated as
            # masker-generated placeholders. NOTE(review): if SpannedMasker
            # numbers them starting at VOCABULARY_SIZE itself, this should be
            # `>=` -- confirm against the masker implementation.
            return token > VOCABULARY_SIZE

        UNCORRUPTED_TOKENS = [token for token in OUTPUT if not is_sentinel(token)]
        CORRUPTED_TOKENS = [token for token in TARGET if not is_sentinel(token)]

        TARGET.append(END_FORMATTER)

        # Render placeholder ids as "*" so both sequences can be decoded.
        OUTPUT = [SPECIAL_FORMATTER if is_sentinel(token) else token for token in OUTPUT]
        TARGET = [SPECIAL_FORMATTER if is_sentinel(token) else token for token in TARGET]

        OUT_TEXT = TOKENIZER.decode(OUTPUT)
        TAR_TEXT = TOKENIZER.decode(TARGET)

        ACTUAL_CORRUPTION_PERCENTAGE = len(CORRUPTED_TOKENS) / len(TOKENS)

        print(f"Original text:\n\n{TEXT}")
        print(f"Inputs:\n\n{OUT_TEXT}")
        print(f"Targets:\n\n{TAR_TEXT}")
        # Previously printed OUTPUT under the "Target Tokens" label.
        print(f"Target Tokens:\n\n{TARGET}")
        print(
            "\n".join(
                [
                    "======================",
                    f"Original length: {len(TOKENS)}",
                    f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}",
                    f"Corrupted Chars: {len(CORRUPTED_TOKENS)}",
                    f"Percentage_corruption: {ACTUAL_CORRUPTION_PERCENTAGE * 100}%",
                    "======================",
                ]
            )
        )

        # Exclude the END_FORMATTER appended above; only masker output is checked.
        for token in TARGET[:-1]:
            assert token not in ILLEGAL_TOKENS

        assert ACTUAL_CORRUPTION_PERCENTAGE > CORRUPTION_PERCENTAGE - TOLERANCE
        assert ACTUAL_CORRUPTION_PERCENTAGE < CORRUPTION_PERCENTAGE + TOLERANCE


if __name__ == "__main__":
    TestSpannedMasker().test_spanned_masking()