diff --git a/Project_Model/Tests/bpe_test.py b/Project_Model/Tests/bpe_test.py
index 7332f65..e6c8f31 100644
--- a/Project_Model/Tests/bpe_test.py
+++ b/Project_Model/Tests/bpe_test.py
@@ -29,7 +29,7 @@ class TestBPE:
     def test_bpe_decoding_simple(self):
 
 
-        INPUT = 258
+        INPUT = [258]
 
         # ab = 256
         # 256, 256 = 257
@@ -47,6 +47,27 @@ class TestBPE:
         for encoded, expected in zip(DECODED, EXPECTED):
             assert encoded == expected
 
+    def test_bpe_decoding_edge_1(self):
+
+
+        INPUT = [258, ord("c")]
+
+        # ab = 256
+        # 256, 256 = 257
+        # 257, 257 = 258
+
+        VOCABULARY = {(ord("a"), ord("b")): 256, (256, 256): 257, (257, 257): 258}
+        EXPECTED = "ababababc"
+
+        BPE_ENCODER = BPE.NanoSocratesBPE(VOCABULARY)
+
+        DECODED = BPE_ENCODER.decode(INPUT)
+
+        assert len(DECODED) == len(EXPECTED)
+
+        for encoded, expected in zip(DECODED, EXPECTED):
+            assert encoded == expected
+
 # Useful to debug weird cases
 if __name__ == "__main__":
     TestBPE().test_bpe_decoding_simple()