Added a way to recover iteration work

This commit is contained in:
Christian Risi 2025-10-01 12:21:42 +02:00
parent dbf1d99408
commit 66bcf6e55f

View File

@ -1,9 +1,22 @@
from collections import deque from collections import deque
import datetime
from pathlib import Path from pathlib import Path
import re import re
from ..Classes import NanoSocratesBPE, NanoSocratesChunker, NanoSocratesSplitter, NanoSocratesBatchMemoryBPE from ..Classes import (
NanoSocratesBPE,
NanoSocratesChunker,
NanoSocratesSplitter,
NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType from ..Enums import TokenType
from ..Utils import special_regex_maker, iterator_with_checks from ..Utils import (
special_regex_maker,
iterator_with_checks,
save_nanos_vocabulary,
load_nanos_vocabulary,
save_json,
load_json,
)
class NanoSocraTrainer: class NanoSocraTrainer:
@ -15,7 +28,7 @@ class NanoSocraTrainer:
chunk_size: int, chunk_size: int,
merge_treshold: int = 0, merge_treshold: int = 0,
max_iterations: int = 0, max_iterations: int = 0,
print_after_iterations: int = 1 print_after_iterations: int = 1,
) -> None: ) -> None:
# Bytes # Bytes
BYTE_RESERVED_TOKENS = 256 BYTE_RESERVED_TOKENS = 256
@ -30,7 +43,11 @@ class NanoSocraTrainer:
self.__print_after_iterations = print_after_iterations self.__print_after_iterations = print_after_iterations
def trainBPE( def trainBPE(
self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None self,
path: Path,
cache_dir: Path,
bpe: NanoSocratesBPE | None = None,
resume_from_iter: int = 0,
) -> NanoSocratesBPE: ) -> NanoSocratesBPE:
if not path.is_file(): if not path.is_file():
@ -49,45 +66,76 @@ class NanoSocraTrainer:
exit = False exit = False
cached = False cached = False
current_iteration = 0 current_iteration = 0
input_path = path
PATH_GEN = self.__switch_paths(path, cache_dir) NEXT_ITERATION = resume_from_iter + 1 if resume_from_iter != 0 else 0
input_path = next(PATH_GEN) PATH_GEN = self.__switch_paths(path, cache_dir, NEXT_ITERATION)
MEMORY_PATH_GEN = self.__switch_memory(cache_dir, resume_from_iter)
if resume_from_iter != 0:
cached = True
current_iteration = resume_from_iter
input_path = next(PATH_GEN)
# UGLY: fixes a bug immediately, unfortunately
_, _ = next(MEMORY_PATH_GEN)
_, voc_cache_path = next(MEMORY_PATH_GEN)
vocabulary = load_nanos_vocabulary(voc_cache_path)
BPE = NanoSocratesBPE(vocabulary)
while not exit: while not exit:
out_path = next(PATH_GEN) out_path = next(PATH_GEN)
internal_cache_path, vocabulary_cache = next(MEMORY_PATH_GEN)
current_iteration = self.__increment_counter(current_iteration) current_iteration = self.__increment_counter(current_iteration)
LAST_VOC_SIZE = BPE.vocabulary_size LAST_VOC_SIZE = BPE.vocabulary_size
FILE = open(out_path, "w") FILE = open(out_path, "w")
last_memory = None last_memory = None
for _, memory, output in self.__round_train(input_path, BPE, cached): for _, memory, output in self.__round_train(input_path, BPE, cached):
last_memory = memory last_memory = memory
FILE.write(output) FILE.write(output)
FILE.close() FILE.close()
internal_cache = {
"finished_iter": current_iteration,
"read_from": f"{input_path}",
"wrote_to": f"{out_path}",
"at": datetime.datetime.now(datetime.timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S.%f"
)[:-3],
}
VOCABULARY = BPE.vocabulary
save_json(internal_cache, internal_cache_path)
save_nanos_vocabulary(VOCABULARY, vocabulary_cache)
cached = True cached = True
input_path = out_path input_path = out_path
NEW_VOC_SIZE = BPE.vocabulary_size NEW_VOC_SIZE = BPE.vocabulary_size
if current_iteration % self.__print_after_iterations == 0: if current_iteration % self.__print_after_iterations == 0:
DELIMITER = "===============" DELIMITER = "==============="
DEBUG = "\n".join([ DEBUG = "\n".join(
DELIMITER, [
f"ITERATION: {current_iteration}", DELIMITER,
DELIMITER, f"ITERATION: {current_iteration}",
f"\tVocabulary size: {BPE.vocabulary_size}\n", DELIMITER,
f"\tFrequencies:\n{last_memory.frequencies}\n", f"\tVocabulary size: {BPE.vocabulary_size}\n",
f"\tvocabulary:\n{BPE.vocabulary}", f"\tFrequencies:\n{last_memory.frequencies}\n", # type: ignore (pretty sure it's not None)
DELIMITER, f"\tvocabulary:\n{BPE.vocabulary}",
"" DELIMITER,
]) "",
]
)
print(DEBUG) print(DEBUG)
if LAST_VOC_SIZE == NEW_VOC_SIZE: if LAST_VOC_SIZE == NEW_VOC_SIZE:
@ -104,12 +152,7 @@ class NanoSocraTrainer:
return BPE return BPE
def __round_train( def __round_train(self, path: Path, bpe: NanoSocratesBPE, cached: bool):
self,
path: Path,
bpe: NanoSocratesBPE,
cached: bool
):
CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex) CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
SPLITTER = NanoSocratesSplitter(self.__special_token_regex) SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
@ -121,9 +164,7 @@ class NanoSocraTrainer:
for chunk, last_chunk in CHUNKER_GENERATOR: for chunk, last_chunk in CHUNKER_GENERATOR:
PIECE_GENERATOR = iterator_with_checks( PIECE_GENERATOR = iterator_with_checks(SPLITTER.split_text(chunk))
SPLITTER.split_text(chunk)
)
for piece, last_piece in PIECE_GENERATOR: for piece, last_piece in PIECE_GENERATOR:
@ -158,26 +199,50 @@ class NanoSocraTrainer:
if not cached: if not cached:
return list(corpus.encode("utf-8")) return list(corpus.encode("utf-8"))
REDUCED_CORPUS_LEN = len(corpus) -1 REDUCED_CORPUS_LEN = len(corpus) - 1
# Skip these chars "[" "]" # Skip these chars "[" "]"
INTS = corpus[1:REDUCED_CORPUS_LEN] INTS = corpus[1:REDUCED_CORPUS_LEN]
INT_LIST = list(map(int,INTS.split(","))) INT_LIST = list(map(int, INTS.split(",")))
return INT_LIST return INT_LIST
def __switch_paths(self, path: Path, cache_path: Path, initial_iteration: int):
    """Yield the two temporary corpus files in endless alternation.

    Both files live directly inside ``cache_path`` and are named
    ``corpus-tmp1.txt`` and ``corpus-tmp2.txt``.

    Args:
        path: Accepted for interface compatibility; this generator does
            not read it.
        cache_path: Directory in which the two temporary files are placed.
        initial_iteration: Selects the starting file. When it is odd the
            sequence starts with ``corpus-tmp2.txt``, otherwise it starts
            with ``corpus-tmp1.txt``.

    Yields:
        The path of the temporary corpus file for the current step,
        flipping to the other file after every step. Never terminates.
    """
    first_file = cache_path / "corpus-tmp1.txt"
    second_file = cache_path / "corpus-tmp2.txt"

    # Only the parity of the starting iteration matters.
    use_first = initial_iteration % 2 != 1

    while True:
        yield first_file if use_first else second_file
        use_first = not use_first
def __switch_memory(self, cache_path: Path, initial_iteration: int):
    """Yield the two checkpoint file pairs in endless alternation.

    Every item is a 2-tuple of paths inside ``cache_path``: first the
    ``internal-tmpN.json`` file, then the matching ``voc-tmpN.json``
    file, where ``N`` is 1 or 2.

    Args:
        cache_path: Directory in which the checkpoint files are placed.
        initial_iteration: Selects the starting pair. When it is odd the
            sequence starts with pair 1, otherwise it starts with pair 2.

    Yields:
        ``(internal_path, vocabulary_path)`` for the current step,
        flipping to the other pair after every step. Never terminates.
    """
    pairs = (
        (cache_path / "internal-tmp1.json", cache_path / "voc-tmp1.json"),
        (cache_path / "internal-tmp2.json", cache_path / "voc-tmp2.json"),
    )

    # Only the parity of the starting iteration matters: odd -> pair 1.
    current = 0 if initial_iteration % 2 == 1 else 1

    while True:
        yield pairs[current]
        current = 1 - current