42 lines
983 B
Python
42 lines
983 B
Python
from pathlib import Path
|
|
from Project_Model.Libs.BPE.Enums import TokenType
|
|
import Project_Model.Libs.BPE as BPE
|
|
|
|
import re
|
|
|
|
# Cache location handed to the BPE trainer. Despite the "DIR" in the name,
# the value is a path to a single ``.json`` file. It is relative, so it is
# resolved against the current working directory (presumably the repository
# root; TODO confirm how the test suite is launched).
CACHE_DIR_PATH: Path = Path("Project_Model/Tests/trainer_files/cache/pool-cache.json")
|
|
|
|
class TestTrainBPE:
    """Tests for training a BPE encoder with ``BPE.NanoSocraTrainerPool``."""

    def test_bpe_train_encoding_simple(self):
        """Train on a tiny corpus and check that ``"abababab"`` encodes to one token.

        The expected id sequence relies on the merges listed below being
        learned from ``train_simple.txt``. Presumably that file holds repeated
        ``"ab"`` text; TODO confirm against the fixture.
        """
        # Positional args: a size limit of 32 000 (presumably the maximum
        # vocabulary size; verify against NanoSocraTrainerPool) and the
        # special tokens.
        trainer = BPE.NanoSocraTrainerPool(
            32_000,
            ["<SOT>", "<EOT>"],
        )

        text = "abababab"
        # Relative path, resolved against the current working directory.
        text_path = Path("Project_Model/Tests/trainer_files/train_simple.txt")

        # Expected merges:
        #   ab       = 256
        #   256, 256 = 257
        #   257, 257 = 258
        expected = [258]

        bpe_encoder = trainer.trainBPE(text_path, CACHE_DIR_PATH)
        encoded = bpe_encoder.encode(text)

        # A single list comparison checks length and every element at once,
        # and lets pytest print a full diff when they differ.
        assert list(encoded) == expected
|
|
|
|
# Running this module directly (instead of through pytest) executes the test
# once in-process, which is handy for stepping through odd cases in a debugger.
if __name__ == "__main__":
    suite = TestTrainBPE()
    suite.test_bpe_train_encoding_simple()
|