Fixed bug in UTF-8 conversion

This commit is contained in:
Christian Risi 2025-09-30 23:58:31 +02:00
parent ccacea18d8
commit 89a0a1f4bb
2 changed files with 23 additions and 4 deletions

View File

@ -15,6 +15,7 @@ class NanoSocraTrainer:
chunk_size: int,
merge_treshold: int = 0,
max_iterations: int = 0,
print_after_iterations: int = 1
) -> None:
# Bytes
BYTE_RESERVED_TOKENS = 256
@ -26,6 +27,7 @@ class NanoSocraTrainer:
self.__chunk_size = chunk_size
self.__merge_treshold = merge_treshold
self.__special_token_regex = special_regex_maker(special_vocabulary)
self.__print_after_iterations = print_after_iterations
def trainBPE(
self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
@ -61,7 +63,9 @@ class NanoSocraTrainer:
FILE = open(out_path, "w")
for _, _, output in self.__round_train(input_path, BPE, cached):
last_memory = None
for _, memory, output in self.__round_train(input_path, BPE, cached):
last_memory = memory
FILE.write(output)
FILE.close()
@ -71,6 +75,21 @@ class NanoSocraTrainer:
NEW_VOC_SIZE = BPE.vocabulary_size
if current_iteration % self.__print_after_iterations == 0:
DELIMITER = "==============="
DEBUG = "\n".join([
DELIMITER,
f"ITERATION: {current_iteration}",
DELIMITER,
f"\tVocabulary size: {BPE.vocabulary_size}\n",
f"\tFrequencies:\n{last_memory.frequencies}\n",
f"\tvocabulary:\n{BPE.vocabulary}",
DELIMITER,
""
])
print(DEBUG)
if LAST_VOC_SIZE == NEW_VOC_SIZE:
exit = True
continue
@ -137,7 +156,7 @@ class NanoSocraTrainer:
def __make_list_ids(self, corpus: str, cached: bool):
if not cached:
return list(map(ord, corpus))
return list(corpus.encode("utf-8"))
REDUCED_CORPUS_LEN = len(corpus) -1

View File

@ -90,7 +90,7 @@ class NanoSocratesBPE(Encoder):
def encode(self, piece: str) -> list[int]:
current_piece = list(map(ord, piece))
current_piece = list(piece.encode("utf-8"))
new_piece = self.__round_encode(current_piece)
while len(current_piece) != len(new_piece):
@ -128,7 +128,7 @@ class NanoSocratesBPE(Encoder):
return NEW_PIECE
# TODO: decode
# TODO: Remake decode to take a list of token IDs
def decode(self, token_id: int) -> str:
token_stack: list[int] = [token_id]