NanoSocrates/Project_Model/Tests/splitter_test.py
2025-10-03 01:04:47 +02:00

180 lines
6.0 KiB
Python

from Project_Model.Libs.BPE.Enums import TokenType
import Project_Model.Libs.BPE as BPE
import re
# Special-token markers handed to the splitter under test. The single
# capturing group holds the marker name without the angle brackets.
PATTERN = r"<(TOKEN|SOT|SEP|EOT)>"
SYMBOL_REGEX = re.compile(PATTERN)
class TestSplitter:
    """Tests for ``BPE.NanoSocratesSplitter`` text and token splitting.

    The expectations below pin the splitter's behaviour as exercised here:
    input is split into BPE and SPECIAL chunks, and trailing BPE content after
    the last special token is not expected to be emitted.
    """

    @staticmethod
    def _assert_chunks(received_chunks, expected_chunks):
        """Assert that ``received_chunks`` matches ``expected_chunks`` pairwise.

        Each chunk is unpacked as a ``(payload, token_type)`` pair. The payload
        is a string for ``split_text`` results and token id(s) for
        ``split_tokens`` results.
        """
        assert len(received_chunks) == len(expected_chunks)
        # Lengths are equal (checked above), so zip() cannot silently truncate.
        # Order matters for the debug print: received first, expected second.
        for received_chunk, expected_chunk in zip(received_chunks, expected_chunks):
            print(f"TEST:\n\tCHUNK:\t\t{received_chunk}\n\tEXPECTED:\t\t{expected_chunk}")
            received_payload, received_type = received_chunk
            expected_payload, expected_type = expected_chunk
            assert received_payload == expected_payload
            assert received_type == expected_type

    def test_split(self):
        """Text starting and ending with special tokens."""
        text = "<SOT>Lorem <SEP>"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)
        expected_chunks = [
            ("<SOT>", TokenType.SPECIAL),
            ("Lorem ", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
        ]
        chunks = list(splitter.split_text(text))
        self._assert_chunks(chunks, expected_chunks)

    def test_split_trailing_text(self):
        """Text with BPE content after the last special token."""
        text = "ipsu<SEP>m d<SEP>olor"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)
        expected_chunks = [
            ("ipsu", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
            ("m d", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
            # The trailing "olor" after the last <SEP> is not expected
            # to be emitted as a chunk.
        ]
        chunks = list(splitter.split_text(text))
        self._assert_chunks(chunks, expected_chunks)

    def test_split_multi_token(self):
        """Consecutive special tokens each produce their own chunk."""
        text = "ipsu<SEP>m d<SEP><SEP><SEP>dsg<SEP>olor"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)
        expected_chunks = [
            ("ipsu", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
            ("m d", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
            ("<SEP>", TokenType.SPECIAL),
            ("<SEP>", TokenType.SPECIAL),
            ("dsg", TokenType.BPE),
            ("<SEP>", TokenType.SPECIAL),
            # The trailing "olor" is not expected to be emitted.
        ]
        chunks = list(splitter.split_text(text))
        self._assert_chunks(chunks, expected_chunks)

    def test_split_malformed_1(self):
        """A leading special token followed only by trailing text."""
        text = "<SEP>lerisque"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)
        expected_chunks = [
            ("<SEP>", TokenType.SPECIAL),
        ]
        chunks = list(splitter.split_text(text))
        self._assert_chunks(chunks, expected_chunks)

    def test_split_malformed_2(self):
        """Text with no special token at all yields no chunks."""
        text = "lerisque"
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX)
        expected_chunks = []
        chunks = list(splitter.split_text(text))
        self._assert_chunks(chunks, expected_chunks)

    def test_split_token_decode_simple(self):
        """Token ids are grouped into BPE runs and individual special ids.

        With ``max_bpe_token_id=1473``, ids 100 and 101 are expected as one
        BPE run and 1477 (above that limit) as a SPECIAL id.
        """
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
        token_list = [100, 101, 1477]
        chunks = list(splitter.split_tokens(token_list))
        expected_chunks = [
            ([100, 101], TokenType.BPE),
            (1477, TokenType.SPECIAL),
        ]
        self._assert_chunks(chunks, expected_chunks)

    def test_split_token_decode_simple_malformed(self):
        """A BPE id trailing the last special id is not expected in the output."""
        splitter = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
        token_list = [100, 101, 1477, 100]
        chunks = list(splitter.split_tokens(token_list))
        expected_chunks = [
            ([100, 101], TokenType.BPE),
            (1477, TokenType.SPECIAL),
        ]
        self._assert_chunks(chunks, expected_chunks)
# Run one test directly (outside pytest) when debugging an odd case.
if __name__ == "__main__":
    debug_suite = TestSplitter()
    debug_suite.test_split_trailing_text()