NanoSocrates/Project_Model/Tests/spanned_masker_test.py

from pathlib import Path

import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer

VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
SPECIAL_LIST = BPE.default_special_tokens()

class TestSpannedMasker:
    def test_spanned_masking(self):
        CORPUS_PATH = Path("Project_Model/Tests/spanner_file/mask.txt")
        TEXT = CORPUS_PATH.read_text(encoding="utf-8")
        TOKENIZER = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
        VOCABULARY_SIZE = TOKENIZER.vocabulary_size
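        # Assumed semantics: 0.4 is the target corruption ratio and
        # average_span the mean length of each masked span.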
        MASKER = Transformer.SpannedMasker(0.4, average_span=3)
        TOKENS = TOKENIZER.encode(TEXT)
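        # ILLEGAL_TOKENS is every special token except <SUBJ>, <OBJ> and
        # <PRED>; these are passed to the masker as ids it must not corrupt
        # (assumed semantics of mask_sequence's third argument).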
        LEGAL_TOKENS: set[int] = set(TOKENIZER.encode("<SUBJ><OBJ><PRED>"))
        SPECIAL_TOKENS: set[int] = set(TOKENIZER.encode("".join(SPECIAL_LIST)))
        ILLEGAL_TOKENS: set[int] = SPECIAL_TOKENS.difference(LEGAL_TOKENS)
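        # Display helpers: SPECIAL_FORMATTER (the leading token of "*<SOT>",
        # presumably "*") stands in for mask sentinels when decoding, and
        # END_FORMATTER terminates the printed target sequence.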
        SPECIAL_FORMATTER = TOKENIZER.encode("*<SOT>")[0]
        END_FORMATTER = TOKENIZER.encode("<EOT>")[0]
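        # Mask the sequence: OUTPUT is the corrupted input, TARGET holds the
        # masked-out spans; sentinel ids lie above VOCABULARY_SIZE.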
        OUTPUT, TARGET = MASKER.mask_sequence(TOKENS, VOCABULARY_SIZE, ILLEGAL_TOKENS)
        UNCORRUPTED_TOKENS = [token for token in OUTPUT if token <= VOCABULARY_SIZE]
        CORRUPTED_TOKENS = [token for token in TARGET if token <= VOCABULARY_SIZE]
        TARGET.append(END_FORMATTER)
        OUTPUT = [SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token for token in OUTPUT]
        TARGET = [SPECIAL_FORMATTER if token > VOCABULARY_SIZE else token for token in TARGET]
        OUT_TEXT = TOKENIZER.decode(OUTPUT)
        TAR_TEXT = TOKENIZER.decode(TARGET)
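        # Dump both streams and the corruption statistics for inspection.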
print(f"Original text:\n\n{TEXT}")
print(f"Inputs:\n\n{OUT_TEXT}")
print(f"Targets:\n\n{TAR_TEXT}")
print("\n".join([
f"======================",
f"Original length: {len(TOKENS)}",
f"Uncorrupted Chars: {len(UNCORRUPTED_TOKENS)}",
f"Corrupted Chars: {len(CORRUPTED_TOKENS)}",
f"Percentage_corruption: {(len(CORRUPTED_TOKENS)/len(TOKENS))*100}%",
f"======================"
]))
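        # Sanity check, assuming a 0.4 ratio corrupts part, but never all,
        # of the sequence.
        assert 0 < len(CORRUPTED_TOKENS) < len(TOKENS)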


if __name__ == "__main__":
    TestSpannedMasker().test_spanned_masking()