Added testing for spanned masking
This commit is contained in:
parent
c217f5dec9
commit
d3b1f7da91
73
Project_Model/Tests/spanned_masker_test.py
Normal file
73
Project_Model/Tests/spanned_masker_test.py
Normal file
@ -0,0 +1,73 @@
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import Project_Model.Libs.BPE as BPE
|
||||
import Project_Model.Libs.Transformer as Transformer
|
||||
|
||||
# Shared fixtures for the spanned-masker test.
# NOTE(review): paths are relative to the repo root — running pytest from
# another working directory will fail to find these assets; confirm CI cwd.
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
# Toy BPE vocabulary loaded once at import time.
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
# Default special tokens (e.g. <SOT>, <EOT>, <SUBJ>, <PRED>, <OBJ>).
SPECIAL_LIST = BPE.default_special_tokens()
|
||||
class TestSpannedMasker:
    """Exercise Transformer.SpannedMasker on a toy RDF-triple corpus."""

    def test_spanned_masking(self):
        """Mask a small corpus and sanity-check the corruption output.

        Configures the masker with a 0.4 corruption ratio and an average
        span of 3 tokens, masks the toy corpus, and asserts the masker
        actually split the sequence into corrupted and uncorrupted parts.
        Decoded input/target views are printed for manual inspection.
        """
        corpus_path = Path("Project_Model/Tests/spanner_file/mask.txt")
        text = corpus_path.read_text("utf-8")

        tokenizer = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
        vocabulary_size = tokenizer.vocabulary_size

        masker = Transformer.SpannedMasker(0.4, average_span=3)

        tokens = tokenizer.encode(text)

        # Structural triple markers the masker IS allowed to corrupt.
        legal_tokens: set[int] = set(tokenizer.encode("<SUBJ><OBJ><PRED>"))
        special_tokens: set[int] = set(tokenizer.encode("".join(SPECIAL_LIST)))
        # Every other special token must never be masked.
        illegal_tokens: set[int] = special_tokens.difference(legal_tokens)

        # Printable stand-in for mask-sentinel ids when decoding.
        special_formatter = tokenizer.encode("*<SOT>")[0]
        end_formatter = tokenizer.encode("<EOT>")[0]

        output, target = masker.mask_sequence(tokens, vocabulary_size, illegal_tokens)

        # Ids above vocabulary_size are mask sentinels; the rest are real
        # tokens (assumed from the sentinel-id convention used below —
        # TODO confirm against SpannedMasker's contract).
        uncorrupted_tokens = [t for t in output if t <= vocabulary_size]
        corrupted_tokens = [t for t in target if t <= vocabulary_size]

        # The original test printed results but asserted nothing, so it
        # could never fail. At a 0.4 ratio on this corpus the masker must
        # corrupt something — and must not corrupt everything.
        assert corrupted_tokens, "masker produced no corrupted tokens"
        assert uncorrupted_tokens, "masker corrupted the entire sequence"

        target.append(end_formatter)

        # Replace sentinel ids with a printable marker before decoding.
        output = [special_formatter if t > vocabulary_size else t for t in output]
        target = [special_formatter if t > vocabulary_size else t for t in target]

        out_text = tokenizer.decode(output)
        tar_text = tokenizer.decode(target)

        print(f"Original text:\n\n{text}")
        print(f"Inputs:\n\n{out_text}")
        print(f"Targets:\n\n{tar_text}")

        print("\n".join([
            "======================",
            f"Original length: {len(tokens)}",
            f"Uncorrupted Chars: {len(uncorrupted_tokens)}",
            f"Corrupted Chars: {len(corrupted_tokens)}",
            f"Percentage_corruption: {(len(corrupted_tokens)/len(tokens))*100}%",
            "======================",
        ]))
|
||||
|
||||
|
||||
|
||||
# Allow running this test directly as a script, without pytest.
if __name__ == "__main__":
    TestSpannedMasker().test_spanned_masking()
|
||||
|
||||
|
||||
|
||||
1
Project_Model/Tests/spanner_file/mask.txt
Normal file
1
Project_Model/Tests/spanner_file/mask.txt
Normal file
@ -0,0 +1 @@
|
||||
<SOT><SUBJ>dbp-dbr:How_It_Should_Have_Ended<PRED>dbp-dbp:title<OBJ>dbp-dbr:The_Dark_Knight<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:caption<OBJ>Theatrical release poster<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:director<OBJ>dbp-dbr:Christopher_Nolan<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:distributor<OBJ>Warner Bros. Pictures<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:producer<OBJ>Charles Roven<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:producer<OBJ>Christopher Nolan<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:producer<OBJ>Emma Thomas<EOT><SOT><SUBJ>dbp-dbr:The_Dark_Knight<PRED>dbp-dbp:starring<OBJ>Christian Bale<EOT>
|
||||
Loading…
x
Reference in New Issue
Block a user