from Project_Model.Libs.BPE.Enums import TokenType
import Project_Model.Libs.BPE as BPE
import re


def _build_vocabulary():
    """Return a fresh pair -> token-id table shared by the tests below.

    A new dict is created on every call so that each test hands its own
    vocabulary object to the encoder.

    Entries:
        ("a", "b") -> 256
        (256, 256) -> 257
        (257, 257) -> 258
    """
    return {
        (ord("a"), ord("b")): 256,
        (256, 256): 257,
        (257, 257): 258,
    }


def _assert_same_items(actual, expected):
    """Assert both sequences have equal length and pairwise-equal items."""
    assert len(actual) == len(expected)
    assert all(got == want for got, want in zip(actual, expected))


class TestBPE:
    """Encoding/decoding checks for ``BPE.NanoSocratesBPE``."""

    def test_bpe_encoding_simple(self):
        """Encoding "abababab" must yield the single token 258."""
        text = "abababab"
        wanted_tokens = [258]

        codec = BPE.NanoSocratesBPE(_build_vocabulary())
        produced_tokens = codec.encode(text)

        _assert_same_items(produced_tokens, wanted_tokens)

    def test_bpe_decoding_simple(self):
        """Decoding the single token 258 must yield "abababab"."""
        tokens = [258]
        wanted_text = "abababab"

        codec = BPE.NanoSocratesBPE(_build_vocabulary())
        produced_text = codec.decode(tokens)

        _assert_same_items(produced_text, wanted_text)

    def test_bpe_decoding_edge_1(self):
        """Decoding 258 followed by the code of "c" must yield "ababababc".

        The second input id, ``ord("c")``, has no entry in the vocabulary.
        """
        tokens = [258, ord("c")]
        wanted_text = "ababababc"

        codec = BPE.NanoSocratesBPE(_build_vocabulary())
        produced_text = codec.decode(tokens)

        _assert_same_items(produced_text, wanted_text)


# Useful to debug weird cases
if __name__ == "__main__":
    suite = TestBPE()
    # Swap in suite.test_bpe_decoding_simple() to debug decoding instead.
    suite.test_bpe_encoding_simple()