"""Tests for ``BPE.NanoSocratesSplitter`` (``split_text`` and ``split_tokens``)."""

import re

from Project_Model.Libs.BPE.Enums import TokenType
import Project_Model.Libs.BPE as BPE

PATTERN = "<(TOKEN|SOT|SEP|EOT)>"
SYMBOL_REGEX = re.compile(PATTERN)

# NOTE(review): several TEXT literals and SPECIAL expected strings below are
# empty or contain nothing that PATTERN can match, and two "malformed" tests
# share the same TEXT with different expectations. The special-token markup
# (e.g. "<SOT>") may have been stripped from this copy of the file -- confirm
# the fixtures against the repository before relying on them.


def _assert_chunks_match(received, expected):
    """Assert that two chunk lists are equal, pair by pair.

    Args:
        received: list of ``(value, token_type)`` pairs produced by the
            splitter under test.
        expected: list of ``(value, token_type)`` pairs the test expects.

    Raises:
        AssertionError: if the lengths differ, or any value or token type
            differs at the same position.
    """
    # Check the length first: zip() would silently stop at the shorter list.
    assert len(received) == len(expected)

    for received_chunk, expected_chunk in zip(received, expected):
        # Printed so that `pytest -s` (or the __main__ run) shows each pair.
        print(f"TEST:\n\tCHUNK:\t\t{received_chunk}\n\tEXPECTED:\t\t{expected_chunk}")

        received_value, received_type = received_chunk
        expected_value, expected_type = expected_chunk

        assert received_value == expected_value
        assert received_type == expected_type


class TestSplitter:
    """Checks the chunks yielded by ``NanoSocratesSplitter``."""

    def test_split(self):
        """``split_text`` yields the expected SPECIAL/BPE chunk sequence."""
        text = "Lorem "
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)

        expected_chunks = [
            ("", TokenType.SPECIAL),
            ("Lorem", TokenType.BPE),
            (" ", TokenType.BPE),
            ("", TokenType.SPECIAL),
        ]

        chunks = list(splitter.split_text(text))

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_trailing_text(self):
        """``split_text`` on input with text after the last special token."""
        text = "ipsum dolor"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)

        expected_chunks = [
            ("ipsu", TokenType.BPE),
            ("", TokenType.SPECIAL),
            ("m", TokenType.BPE),
            (" d", TokenType.BPE),
            ("", TokenType.SPECIAL),
            # Deliberately not expected (presumably the trailing text is
            # not emitted by the splitter -- TODO confirm):
            # ("olor", TokenType.BPE)
        ]

        chunks = list(splitter.split_text(text))

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_multi_token(self):
        """``split_text`` with several consecutive SPECIAL chunks expected."""
        text = "ipsum ddsgolor"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)

        expected_chunks = [
            ("ipsu", TokenType.BPE),
            ("", TokenType.SPECIAL),
            ("m", TokenType.BPE),
            (" d", TokenType.BPE),
            ("", TokenType.SPECIAL),
            ("", TokenType.SPECIAL),
            ("", TokenType.SPECIAL),
            ("dsg", TokenType.BPE),
            ("", TokenType.SPECIAL),
        ]

        chunks = list(splitter.split_text(text))

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_malformed_1(self):
        """Malformed input for which a single SPECIAL chunk is expected."""
        text = "lerisque"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)

        expected_chunks = [
            ("", TokenType.SPECIAL),
        ]

        chunks = list(splitter.split_text(text))

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_malformed_2(self):
        """Malformed input for which no chunk at all is expected."""
        text = "lerisque"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)

        expected_chunks = []

        chunks = list(splitter.split_text(text))

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_token_decode_simple(self):
        """``split_tokens`` separates a token-id list into BPE and SPECIAL runs.

        With ``max_bpe_token_id=1473`` the test expects ids 100 and 101 to be
        grouped as one BPE chunk and id 1477 to form a SPECIAL chunk.
        """
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
        token_list = [100, 101, 1477]

        chunks = list(splitter.split_tokens(token_list))
        expected_chunks = [
            ([100, 101], TokenType.BPE),
            ([1477], TokenType.SPECIAL),
        ]

        _assert_chunks_match(chunks, expected_chunks)

    def test_split_token_decode_simple_malformed(self):
        """``split_tokens`` on a list with a BPE id after the last SPECIAL id.

        The expected chunks are the same as for ``[100, 101, 1477]``: the
        trailing ``100`` is expected not to appear in the output.
        """
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
        token_list = [100, 101, 1477, 100]

        chunks = list(splitter.split_tokens(token_list))
        expected_chunks = [
            ([100, 101], TokenType.BPE),
            ([1477], TokenType.SPECIAL),
        ]

        _assert_chunks_match(chunks, expected_chunks)


# Useful to debug weird cases
if __name__ == "__main__":
    TestSplitter().test_split_trailing_text()