Fixed tests to reflect new version of tokenizer
This commit is contained in:
parent
51f491d033
commit
c74689d01d
@ -18,7 +18,8 @@ class TestSplitter:
|
||||
|
||||
EXPECTED_CHUNKS = [
|
||||
("<SOT>", TokenType.SPECIAL),
|
||||
("Lorem ", TokenType.BPE),
|
||||
("Lorem", TokenType.BPE),
|
||||
(" ", TokenType.BPE),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
]
|
||||
|
||||
@ -43,9 +44,10 @@ class TestSplitter:
|
||||
EXPECTED_CHUNKS = [
|
||||
("ipsu", TokenType.BPE),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
("m d", TokenType.BPE),
|
||||
("m", TokenType.BPE),
|
||||
(" d", TokenType.BPE),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
#("olor", TokenType.BPE)
|
||||
# ("olor", TokenType.BPE)
|
||||
]
|
||||
|
||||
CHUNKS = list(SPLITTER.split_text(TEXT))
|
||||
@ -69,7 +71,8 @@ class TestSplitter:
|
||||
EXPECTED_CHUNKS = [
|
||||
("ipsu", TokenType.BPE),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
("m d", TokenType.BPE),
|
||||
("m", TokenType.BPE),
|
||||
(" d", TokenType.BPE),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
("<SEP>", TokenType.SPECIAL),
|
||||
@ -134,12 +137,12 @@ class TestSplitter:
|
||||
def test_split_token_decode_simple(self):
|
||||
# to test the token split into special and bpe
|
||||
SPLITTER = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
|
||||
token_list = [100,101,1477]
|
||||
token_list = [100, 101, 1477]
|
||||
|
||||
CHUNKS = list(SPLITTER.split_tokens(token_list))
|
||||
EXPECTED_CHUNKS = [
|
||||
([100,101], TokenType.BPE),
|
||||
(1477, TokenType.SPECIAL),
|
||||
([100, 101], TokenType.BPE),
|
||||
([1477], TokenType.SPECIAL),
|
||||
]
|
||||
|
||||
assert len(CHUNKS) == len(EXPECTED_CHUNKS)
|
||||
@ -155,12 +158,12 @@ class TestSplitter:
|
||||
def test_split_token_decode_simple_malformed(self):
|
||||
# to test the token split into special and bpe
|
||||
SPLITTER = BPE.NanoSocratesSplitter(SYMBOL_REGEX, max_bpe_token_id=1473)
|
||||
token_list = [100,101,1477,100]
|
||||
token_list = [100, 101, 1477, 100]
|
||||
|
||||
CHUNKS = list(SPLITTER.split_tokens(token_list))
|
||||
EXPECTED_CHUNKS = [
|
||||
([100,101], TokenType.BPE),
|
||||
(1477, TokenType.SPECIAL),
|
||||
([100, 101], TokenType.BPE),
|
||||
([1477], TokenType.SPECIAL),
|
||||
]
|
||||
|
||||
assert len(CHUNKS) == len(EXPECTED_CHUNKS)
|
||||
@ -174,7 +177,6 @@ class TestSplitter:
|
||||
assert RECEIVED_TOKEN_TYPE == EXPECTED_TOKEN_TYPE
|
||||
|
||||
|
||||
|
||||
# Useful to debug weird cases
|
||||
if __name__ == "__main__":
|
||||
TestSplitter().test_split_trailing_text()
|
||||
TestSplitter().test_split_trailing_text()
|
||||
|
||||
@ -13,7 +13,7 @@ class TestTokeNano:
|
||||
VOCABULARY = {(ord("a"), ord("b")): 256, (256, 256): 257, (257, 257): 258}
|
||||
# EXPECTED = [258]
|
||||
|
||||
TOKE_NANO = TokeNanoCore(VOCABULARY)
|
||||
TOKE_NANO = TokeNanoCore(VOCABULARY, ["<SOT>", "<EOT>"])
|
||||
|
||||
ENCODED = TOKE_NANO.encode(TEXT)
|
||||
DECODED = TOKE_NANO.decode(ENCODED)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user