Created files to test BPE training
This commit is contained in:
parent
b09bd4acba
commit
ccacea18d8
42
Project_Model/Tests/bpe_trainer.py
Normal file
42
Project_Model/Tests/bpe_trainer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
from Project_Model.Libs.BPE.Enums import TokenType
|
||||
import Project_Model.Libs.BPE as BPE
|
||||
|
||||
import re
|
||||
|
||||
# Directory handed to ``trainBPE`` as its second argument by the test below.
CACHE_DIR_PATH = Path("Project_Model/Tests/trainer_files/cache")
|
||||
|
||||
class TestTrainBPE:
    """Checks on BPE training driven through ``BPE.NanoSocraTrainer``."""

    def test_bpe_train_encoding_simple(self):
        """Train on the simple fixture file, then encode ``"abababab"``.

        The encoder returned by ``trainBPE`` is expected to collapse the
        whole sample into the single token id 258.
        """
        corpus_path = Path("Project_Model/Tests/trainer_files/train_simple.txt")
        sample = "abababab"

        # Merge chain the expected id relies on:
        #   ab       -> 256
        #   256, 256 -> 257
        #   257, 257 -> 258
        expected_ids = [258]

        # NOTE(review): the first and third positional arguments are presumably
        # a vocabulary-size limit and a chunk/size setting; confirm against
        # NanoSocraTrainer's signature.
        trainer = BPE.NanoSocraTrainer(int(32E3), ["<SOT>", "<EOT>"], 40)
        encoder = trainer.trainBPE(corpus_path, CACHE_DIR_PATH)
        encoded_ids = encoder.encode(sample)

        # Compare lengths first so that zip() cannot hide a size mismatch.
        assert len(encoded_ids) == len(expected_ids)

        for got, want in zip(encoded_ids, expected_ids):
            assert got == want
|
||||
|
||||
# Lets the test be run directly as a script, which helps when debugging odd cases
if __name__ == "__main__":
    suite = TestTrainBPE()
    suite.test_bpe_train_encoding_simple()
|
||||
1
Project_Model/Tests/trainer_files/train_simple.txt
Normal file
1
Project_Model/Tests/trainer_files/train_simple.txt
Normal file
@@ -0,0 +1 @@
|
||||
<SOT>abababab<EOT>
|
||||
Loading…
x
Reference in New Issue
Block a user