Added actual test
parent b1e7af0607
commit d3bba9b944
@@ -8,30 +8,26 @@ VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
 VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
 SPECIAL_LIST = BPE.default_special_tokens()
 
 
 class TestSpannedMasker:
 
     def test_spanned_masking(self):
 
         CORPUS_PATH = Path("Project_Model/Tests/spanner_file/mask.txt")
         TEXT = CORPUS_PATH.read_text("utf-8")
+        CORRUPTION_PERCENTAGE = 0.15
+        TOLERANCE = 0.05
 
-        TOKENIZER = BPE.TokeNanoCore(
-            VOCABULARY,
-            SPECIAL_LIST
-        )
+        TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
         VOCABULARY_SIZE = TOKENIZER.vocabulary_size
 
-        MASKER = Transformer.SpannedMasker(0.4,average_span=3)
+        MASKER = Transformer.SpannedMasker(CORRUPTION_PERCENTAGE, 3)
 
         TOKENS = TOKENIZER.encode(TEXT)
 
-        LEGAL_TOKENS: set[int] = set(TOKENIZER.encode(
-            "<SUBJ><OBJ><PRED>"
-        ))
+        LEGAL_TOKENS: set[int] = set(TOKENIZER.encode("<SUBJ><OBJ><PRED>"))
 
-        SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode(
-            "".join(SPECIAL_LIST)
-        ))
+        SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode("".join(SPECIAL_LIST)))
 
         ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
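Note on the contract the test relies on: any token id greater than VOCABULARY_SIZE is treated as a span sentinel inserted by the masker, which is why the second hunk below filters both OUTPUT and TARGET with token <= VOCABULARY_SIZE. The repository's actual Transformer.SpannedMasker is not part of this diff; the sketch below is only an assumed, minimal T5-style illustration of what mask_sequence is expected to return (the name mask_sequence_sketch and all of its internals are hypothetical).

import random


def mask_sequence_sketch(
    tokens: list[int],
    vocabulary_size: int,
    illegal_tokens: set[int],
    corruption_percentage: float = 0.15,
    average_span: int = 3,
) -> tuple[list[int], list[int]]:
    """Hypothetical span-corruption sketch, not the repository's implementation."""
    output: list[int] = []   # corrupted input: each masked span collapses to one sentinel
    target: list[int] = []   # target: each sentinel followed by the tokens it hides
    sentinel = vocabulary_size + 1                             # first id outside the vocabulary
    start_probability = corruption_percentage / average_span   # keeps expected corruption near 15%
    i = 0
    while i < len(tokens):
        token = tokens[i]
        if token not in illegal_tokens and random.random() < start_probability:
            # Start a corrupted span: one sentinel in the input, the hidden tokens in the target.
            span_length = max(1, round(random.expovariate(1 / average_span)))
            output.append(sentinel)
            target.append(sentinel)
            while i < len(tokens) and span_length > 0 and tokens[i] not in illegal_tokens:
                target.append(tokens[i])   # the masked token is only recoverable from the target
                i += 1
                span_length -= 1
            sentinel += 1
        else:
            output.append(token)
            i += 1
    return output, target

Under that contract, UNCORRUPTED_TOKENS counts the tokens still visible in the input and CORRUPTED_TOKENS counts the tokens the masker hid, so the ratio of CORRUPTED_TOKENS to len(TOKENS) is the measured corruption rate that the new assertions check.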
@@ -40,34 +36,52 @@ class TestSpannedMasker:
         OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
 
-        UNCORRUPTED_TOKENS = list(filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT))
+        UNCORRUPTED_TOKENS = list(
+            filter(lambda token: token <= VOCABULARY_SIZE, OUTPUT)
+        )
         CORRUPTED_TOKENS = list(filter(lambda token: token <= VOCABULARY_SIZE, TARGET))
 
         TARGET.append(END_FORMATTER)
 
-        OUTPUT = list(map(lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, OUTPUT))
-        TARGET = list(map(lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token, TARGET))
+        OUTPUT = list(
+            map(
+                lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token,
+                OUTPUT,
+            )
+        )
+        TARGET = list(
+            map(
+                lambda token: SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token,
+                TARGET,
+            )
+        )
 
         OUT_TEXT = TOKENIZER.decode(OUTPUT)
         TAR_TEXT = TOKENIZER.decode(TARGET)
 
+        ACTUAL_CORRUPTION_PERCENTAGE = len(CORRUPTED_TOKENS) / len(TOKENS)
+
         print(f"Original text:\n\n{TEXT}")
         print(f"Inputs:\n\n{OUT_TEXT}")
         print(f"Targets:\n\n{TAR_TEXT}")
+        print(f"Target Tokens:\n\n{OUTPUT}")
 
-        print("\n".join([
-            f"======================",
-            f"Original length: {len(TOKENS)}",
-            f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}",
-            f"Corrupted Chars: {len(CORRUPTED_TOKENS)}",
-            f"Percentage_corruption: {(len(CORRUPTED_TOKENS)/len(TOKENS))*100}%",
-            f"======================"
-        ]))
+        print(
+            "\n".join(
+                [
+                    f"======================",
+                    f"Original length: {len(TOKENS)}",
+                    f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}",
+                    f"Corrupted Chars: {len(CORRUPTED_TOKENS)}",
+                    f"Percentage_corruption: {(len(CORRUPTED_TOKENS)/len(TOKENS))*100}%",
+                    f"======================",
+                ]
+            )
+        )
+
+        assert ACTUAL_CORRUPTION_PERCENTAGE > CORRUPTION_PERCENTAGE - TOLERANCE
+        assert ACTUAL_CORRUPTION_PERCENTAGE < CORRUPTION_PERCENTAGE + TOLERANCE
 
 
 if __name__ == "__main__":
     TestSpannedMasker().test_spanned_masking()
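The two new assertions are what turns the file into an actual test: because span sampling is random, the measured rate is only required to land inside CORRUPTION_PERCENTAGE ± TOLERANCE, i.e. between 10% and 20% corrupted tokens. A small self-contained illustration of the same bound check (within_tolerance is a hypothetical helper, not part of the repository):

import pytest


def within_tolerance(actual: float, expected: float, tolerance: float) -> bool:
    # Mirrors the two asserts above: actual must fall strictly inside (expected - tol, expected + tol).
    return expected - tolerance < actual < expected + tolerance


assert within_tolerance(0.17, 0.15, 0.05)       # e.g. a measured rate of 17% passes
assert not within_tolerance(0.25, 0.15, 0.05)   # 25% falls outside the +/- 5 point band and fails
assert 0.17 == pytest.approx(0.15, abs=0.05)    # near-equivalent pytest form (bounds are inclusive)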