Fixed bug for utf-8 conversion
This commit is contained in:
parent
ccacea18d8
commit
89a0a1f4bb
@ -15,6 +15,7 @@ class NanoSocraTrainer:
|
||||
chunk_size: int,
|
||||
merge_treshold: int = 0,
|
||||
max_iterations: int = 0,
|
||||
print_after_iterations: int = 1
|
||||
) -> None:
|
||||
# Bytes
|
||||
BYTE_RESERVED_TOKENS = 256
|
||||
@ -26,6 +27,7 @@ class NanoSocraTrainer:
|
||||
self.__chunk_size = chunk_size
|
||||
self.__merge_treshold = merge_treshold
|
||||
self.__special_token_regex = special_regex_maker(special_vocabulary)
|
||||
self.__print_after_iterations = print_after_iterations
|
||||
|
||||
def trainBPE(
|
||||
self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
|
||||
@ -61,7 +63,9 @@ class NanoSocraTrainer:
|
||||
|
||||
FILE = open(out_path, "w")
|
||||
|
||||
for _, _, output in self.__round_train(input_path, BPE, cached):
|
||||
last_memory = None
|
||||
for _, memory, output in self.__round_train(input_path, BPE, cached):
|
||||
last_memory = memory
|
||||
FILE.write(output)
|
||||
|
||||
FILE.close()
|
||||
@ -71,6 +75,21 @@ class NanoSocraTrainer:
|
||||
|
||||
NEW_VOC_SIZE = BPE.vocabulary_size
|
||||
|
||||
if current_iteration % self.__print_after_iterations == 0:
|
||||
DELIMITER = "==============="
|
||||
|
||||
DEBUG = "\n".join([
|
||||
DELIMITER,
|
||||
f"ITERATION: {current_iteration}",
|
||||
DELIMITER,
|
||||
f"\tVocabulary size: {BPE.vocabulary_size}\n",
|
||||
f"\tFrequencies:\n{last_memory.frequencies}\n",
|
||||
f"\tvocabulary:\n{BPE.vocabulary}",
|
||||
DELIMITER,
|
||||
""
|
||||
])
|
||||
print(DEBUG)
|
||||
|
||||
if LAST_VOC_SIZE == NEW_VOC_SIZE:
|
||||
exit = True
|
||||
continue
|
||||
@ -137,7 +156,7 @@ class NanoSocraTrainer:
|
||||
def __make_list_ids(self, corpus: str, cached: bool):
|
||||
|
||||
if not cached:
|
||||
return list(map(ord, corpus))
|
||||
return list(corpus.encode("utf-8"))
|
||||
|
||||
REDUCED_CORPUS_LEN = len(corpus) -1
|
||||
|
||||
|
||||
@ -90,7 +90,7 @@ class NanoSocratesBPE(Encoder):
|
||||
|
||||
def encode(self, piece: str) -> list[int]:
|
||||
|
||||
current_piece = list(map(ord, piece))
|
||||
current_piece = list(piece.encode("utf-8"))
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
while len(current_piece) != len(new_piece):
|
||||
@ -128,7 +128,7 @@ class NanoSocratesBPE(Encoder):
|
||||
|
||||
return NEW_PIECE
|
||||
|
||||
# TODO: decode
|
||||
# TODO: Remake decode to take a list of token IDs
|
||||
def decode(self, token_id: int) -> str:
|
||||
|
||||
token_stack: list[int] = [token_id]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user