import re
from pathlib import Path

import Project_Model.Libs.BPE as BPE
from Project_Model.Libs.BPE.Enums import TokenType

# Cache location passed to `trainBPE` by both tests. Despite the name, this is
# the path of a JSON file, not a directory.
CACHE_DIR_PATH = Path("Project_Model/Tests/trainer_files/cache/pool-cache.json")


class TestTrainBPE:
    """Tests for training a BPE encoder with ``BPE.NanoSocraTrainerPool``.

    All paths are relative, so these tests assume they are run from the
    repository root (the directory that contains ``Project_Model``).
    """

    def test_bpe_train_encoding_simple(self):
        """Encoding ``"abababab"`` with the trained encoder yields ``[258]``."""
        # NOTE(review): the special-token list holds two empty strings; the
        # intended token literals may have been lost — confirm.
        trainer = BPE.NanoSocraTrainerPool(int(32e3), ["", ""])
        text = "abababab"
        # Presumably contains a corpus producing the merges below — verify.
        text_path = Path("Project_Model/Tests/trainer_files/train_simple.txt")

        # Expected merge sequence:
        #   "a", "b"  -> 256
        #   256, 256  -> 257
        #   257, 257  -> 258
        expected = [258]

        bpe_encoder = trainer.trainBPE(text_path, CACHE_DIR_PATH)
        encoded = bpe_encoder.encode(text)

        # One comparison checks both length and every element, and gives
        # pytest a readable diff on failure.
        assert list(encoded) == expected

    def test_bpe_train_encoding_and_decoding(self):
        """Encoding then decoding the training text with ``TokeNanoCore`` round-trips."""
        # NOTE(review): same empty-string special-token list as above — confirm.
        special_list = ["", ""]
        trainer = BPE.NanoSocraTrainerPool(int(32e3), special_list)
        text_path = Path("Project_Model/Tests/trainer_files/train_encode_decode.txt")

        # `with` closes the file even if reading fails; an explicit encoding
        # keeps the text under test independent of the platform locale.
        with open(text_path, encoding="utf-8") as file:
            text = file.read()
        expected = text

        # NOTE(review): this shares CACHE_DIR_PATH with the simple test;
        # confirm trainBPE does not reuse a cache built from another corpus.
        bpe_encoder = trainer.trainBPE(text_path, CACHE_DIR_PATH)
        vocabulary = bpe_encoder.vocabulary

        tokenano = BPE.TokeNanoCore(vocabulary, special_list)
        encoded = tokenano.encode(text)
        decoded = tokenano.decode(encoded)

        # Compared as sequences so the check is equivalent whether `decode`
        # returns a str or a sequence of single characters.
        assert list(decoded) == list(expected)


# Useful to debug weird cases: run a single test directly, without pytest.
if __name__ == "__main__":
    # TestTrainBPE().test_bpe_train_encoding_simple()
    TestTrainBPE().test_bpe_train_encoding_and_decoding()