2025-09-30 13:33:54 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
from Project_Model.Libs.BPE.Enums import TokenType
|
|
|
|
|
import Project_Model.Libs.BPE as BPE
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
|
|
2025-10-02 09:56:05 +02:00
|
|
|
# Cache location passed as the second argument to ``trainBPE`` by the tests in
# this module. NOTE(review): despite the ``DIR`` in the name, the path ends in
# ``.json``, so it presumably names a file rather than a directory - confirm.
# The path is relative, so it is resolved against the current working directory.
CACHE_DIR_PATH = Path("Project_Model/Tests/trainer_files/cache/pool-cache.json")
|
2025-09-30 13:33:54 +02:00
|
|
|
|
|
|
|
|
class TestTrainBPE:
    """Tests for BPE training through ``BPE.NanoSocraTrainerPool``.

    Both tests load fixtures through paths relative to the current working
    directory and hand the module-level ``CACHE_DIR_PATH`` to ``trainBPE``.

    NOTE(review): the two tests share the same cache path; confirm that
    ``trainBPE`` does not pick up state left behind by the other test.
    """

    def test_bpe_train_encoding_simple(self):
        """Train on the simple fixture and check the encoding of "abababab"."""

        TRAINER = BPE.NanoSocraTrainerPool(
            int(32E3),
            ["<SOT>", "<EOT>"],
        )

        TEXT = "abababab"
        TEXT_PATH = Path("Project_Model/Tests/trainer_files/train_simple.txt")

        # Expected merge chain that collapses the whole text into one token:
        # ab = 256
        # 256, 256 = 257
        # 257, 257 = 258
        EXPECTED = [258]

        BPE_ENCODER = TRAINER.trainBPE(
            TEXT_PATH,
            CACHE_DIR_PATH,
        )

        ENCODED = BPE_ENCODER.encode(TEXT)

        # One whole-sequence comparison covers both length and content and
        # lets pytest print a readable diff on failure.
        assert list(ENCODED) == EXPECTED

    def test_bpe_train_encoding_and_decoding(self):
        """Train on a fixture and check that decode(encode(text)) gives back the text."""

        SPECIAL_LIST = ["<ABS>", "<SOTL>"]
        TRAINER = BPE.NanoSocraTrainerPool(
            int(32E3),
            SPECIAL_LIST,
        )

        TEXT_PATH = Path("Project_Model/Tests/trainer_files/train_encode_decode.txt")
        # read_text closes the file even if reading fails; the explicit
        # encoding keeps the test independent of the platform locale.
        TEXT = TEXT_PATH.read_text(encoding="utf-8")

        # A round trip must reproduce the input unchanged.
        EXPECTED = TEXT

        BPE_ENCODER = TRAINER.trainBPE(
            TEXT_PATH,
            CACHE_DIR_PATH,
        )
        VOCABULARY = BPE_ENCODER.vocabulary
        TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)

        ENCODED = TOKENANO.encode(TEXT)
        DECODED = TOKENANO.decode(ENCODED)

        assert len(DECODED) == len(EXPECTED)

        for position, (decoded, expected) in enumerate(zip(DECODED, EXPECTED)):
            assert decoded == expected, f"first mismatch at index {position}"
|
|
|
|
|
|
2025-09-30 13:33:54 +02:00
|
|
|
# Direct entry point: runs a single test without pytest, which is handy when
# stepping through an odd case in a debugger.
if __name__ == "__main__":
    suite = TestTrainBPE()
    # Swap in the line below to debug the simple-encoding test instead.
    # suite.test_bpe_train_encoding_simple()
    suite.test_bpe_train_encoding_and_decoding()
|