Added a way to recover iteration work
This commit is contained in:
parent
dbf1d99408
commit
66bcf6e55f
@ -1,9 +1,22 @@
|
||||
from collections import deque
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
from ..Classes import NanoSocratesBPE, NanoSocratesChunker, NanoSocratesSplitter, NanoSocratesBatchMemoryBPE
|
||||
from ..Classes import (
|
||||
NanoSocratesBPE,
|
||||
NanoSocratesChunker,
|
||||
NanoSocratesSplitter,
|
||||
NanoSocratesBatchMemoryBPE,
|
||||
)
|
||||
from ..Enums import TokenType
|
||||
from ..Utils import special_regex_maker, iterator_with_checks
|
||||
from ..Utils import (
|
||||
special_regex_maker,
|
||||
iterator_with_checks,
|
||||
save_nanos_vocabulary,
|
||||
load_nanos_vocabulary,
|
||||
save_json,
|
||||
load_json,
|
||||
)
|
||||
|
||||
|
||||
class NanoSocraTrainer:
|
||||
@ -15,7 +28,7 @@ class NanoSocraTrainer:
|
||||
chunk_size: int,
|
||||
merge_treshold: int = 0,
|
||||
max_iterations: int = 0,
|
||||
print_after_iterations: int = 1
|
||||
print_after_iterations: int = 1,
|
||||
) -> None:
|
||||
# Bytes
|
||||
BYTE_RESERVED_TOKENS = 256
|
||||
@ -30,7 +43,11 @@ class NanoSocraTrainer:
|
||||
self.__print_after_iterations = print_after_iterations
|
||||
|
||||
def trainBPE(
|
||||
self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
|
||||
self,
|
||||
path: Path,
|
||||
cache_dir: Path,
|
||||
bpe: NanoSocratesBPE | None = None,
|
||||
resume_from_iter: int = 0,
|
||||
) -> NanoSocratesBPE:
|
||||
|
||||
if not path.is_file():
|
||||
@ -49,45 +66,76 @@ class NanoSocraTrainer:
|
||||
exit = False
|
||||
cached = False
|
||||
current_iteration = 0
|
||||
input_path = path
|
||||
|
||||
PATH_GEN = self.__switch_paths(path, cache_dir)
|
||||
NEXT_ITERATION = resume_from_iter + 1 if resume_from_iter != 0 else 0
|
||||
|
||||
PATH_GEN = self.__switch_paths(path, cache_dir, NEXT_ITERATION)
|
||||
MEMORY_PATH_GEN = self.__switch_memory(cache_dir, resume_from_iter)
|
||||
|
||||
if resume_from_iter != 0:
|
||||
cached = True
|
||||
current_iteration = resume_from_iter
|
||||
input_path = next(PATH_GEN)
|
||||
# UGLY: fixes a bug immediately, unfortunately
|
||||
_, _ = next(MEMORY_PATH_GEN)
|
||||
_, voc_cache_path = next(MEMORY_PATH_GEN)
|
||||
vocabulary = load_nanos_vocabulary(voc_cache_path)
|
||||
BPE = NanoSocratesBPE(vocabulary)
|
||||
|
||||
while not exit:
|
||||
|
||||
|
||||
out_path = next(PATH_GEN)
|
||||
internal_cache_path, vocabulary_cache = next(MEMORY_PATH_GEN)
|
||||
|
||||
current_iteration = self.__increment_counter(current_iteration)
|
||||
LAST_VOC_SIZE = BPE.vocabulary_size
|
||||
|
||||
FILE = open(out_path, "w")
|
||||
|
||||
last_memory = None
|
||||
|
||||
for _, memory, output in self.__round_train(input_path, BPE, cached):
|
||||
last_memory = memory
|
||||
FILE.write(output)
|
||||
|
||||
FILE.close()
|
||||
|
||||
internal_cache = {
|
||||
"finished_iter": current_iteration,
|
||||
"read_from": f"{input_path}",
|
||||
"wrote_to": f"{out_path}",
|
||||
"at": datetime.datetime.now(datetime.timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M:%S.%f"
|
||||
)[:-3],
|
||||
}
|
||||
|
||||
VOCABULARY = BPE.vocabulary
|
||||
|
||||
save_json(internal_cache, internal_cache_path)
|
||||
save_nanos_vocabulary(VOCABULARY, vocabulary_cache)
|
||||
|
||||
cached = True
|
||||
input_path = out_path
|
||||
|
||||
NEW_VOC_SIZE = BPE.vocabulary_size
|
||||
|
||||
if current_iteration % self.__print_after_iterations == 0:
|
||||
|
||||
DELIMITER = "==============="
|
||||
|
||||
DEBUG = "\n".join([
|
||||
DEBUG = "\n".join(
|
||||
[
|
||||
DELIMITER,
|
||||
f"ITERATION: {current_iteration}",
|
||||
DELIMITER,
|
||||
f"\tVocabulary size: {BPE.vocabulary_size}\n",
|
||||
f"\tFrequencies:\n{last_memory.frequencies}\n",
|
||||
f"\tFrequencies:\n{last_memory.frequencies}\n", # type: ignore (pretty sure it's not None)
|
||||
f"\tvocabulary:\n{BPE.vocabulary}",
|
||||
DELIMITER,
|
||||
""
|
||||
])
|
||||
"",
|
||||
]
|
||||
)
|
||||
print(DEBUG)
|
||||
|
||||
if LAST_VOC_SIZE == NEW_VOC_SIZE:
|
||||
@ -104,12 +152,7 @@ class NanoSocraTrainer:
|
||||
|
||||
return BPE
|
||||
|
||||
def __round_train(
|
||||
self,
|
||||
path: Path,
|
||||
bpe: NanoSocratesBPE,
|
||||
cached: bool
|
||||
):
|
||||
def __round_train(self, path: Path, bpe: NanoSocratesBPE, cached: bool):
|
||||
|
||||
CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
|
||||
SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
|
||||
@ -121,9 +164,7 @@ class NanoSocraTrainer:
|
||||
|
||||
for chunk, last_chunk in CHUNKER_GENERATOR:
|
||||
|
||||
PIECE_GENERATOR = iterator_with_checks(
|
||||
SPLITTER.split_text(chunk)
|
||||
)
|
||||
PIECE_GENERATOR = iterator_with_checks(SPLITTER.split_text(chunk))
|
||||
|
||||
for piece, last_piece in PIECE_GENERATOR:
|
||||
|
||||
@ -165,19 +206,43 @@ class NanoSocraTrainer:
|
||||
INT_LIST = list(map(int, INTS.split(",")))
|
||||
return INT_LIST
|
||||
|
||||
def __switch_paths(self, path: Path, cache_path: Path, initial_iteration: int):
    """Yield ``path`` once, then alternate forever between two corpus cache files.

    The generator never terminates on its own; callers pull values with
    ``next()``. It only builds path objects: nothing is created, opened
    or checked on disk here.

    Args:
        path: Produced unchanged as the very first value.
        cache_path: Directory under which the two alternating file names
            ``corpus-tmp1.txt`` and ``corpus-tmp2.txt`` are built.
        initial_iteration: Only its parity is used. An even value makes
            ``corpus-tmp1.txt`` the first cache file produced, an odd
            value makes ``corpus-tmp2.txt`` the first one.

    Yields:
        ``path`` first, then the two cache file paths in strict
        alternation.
    """
    yield path

    corpus_files = (
        cache_path / "corpus-tmp1.txt",
        cache_path / "corpus-tmp2.txt",
    )

    # Even start -> slot 0 (corpus-tmp1), odd start -> slot 1 (corpus-tmp2).
    slot = 1 if initial_iteration % 2 == 1 else 0

    while True:
        yield corpus_files[slot]
        # Flip between the two slots so consecutive values never coincide.
        slot = 1 - slot
|
||||
|
||||
def __switch_memory(self, cache_path: Path, initial_iteration: int):
    """Alternate forever between two (internal-state, vocabulary) cache path pairs.

    The generator never terminates on its own; callers pull values with
    ``next()``. It only builds path objects: nothing is created, opened
    or checked on disk here.

    Args:
        cache_path: Directory under which the cache file names are built.
        initial_iteration: Only its parity is used. An odd value makes the
            ``tmp1`` pair the first one produced, an even value makes the
            ``tmp2`` pair the first one.

    Yields:
        A 2-tuple ``(internal_json_path, vocabulary_json_path)``: either
        ``(internal-tmp1.json, voc-tmp1.json)`` or
        ``(internal-tmp2.json, voc-tmp2.json)``, in strict alternation.
    """
    memory_slots = (
        (cache_path / "internal-tmp1.json", cache_path / "voc-tmp1.json"),
        (cache_path / "internal-tmp2.json", cache_path / "voc-tmp2.json"),
    )

    # Odd start -> slot 0 (tmp1 pair), even start -> slot 1 (tmp2 pair).
    slot = 0 if initial_iteration % 2 == 1 else 1

    while True:
        yield memory_slots[slot]
        # Flip between the two slots so consecutive pairs never coincide.
        slot = 1 - slot
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user