2025-09-30 13:33:40 +02:00
|
|
|
from collections import deque
|
2025-10-01 12:21:42 +02:00
|
|
|
import datetime
|
2025-09-30 13:33:40 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
import re
|
2025-10-01 12:21:42 +02:00
|
|
|
from ..Classes import (
|
|
|
|
|
NanoSocratesBPE,
|
|
|
|
|
NanoSocratesChunker,
|
|
|
|
|
NanoSocratesSplitter,
|
|
|
|
|
NanoSocratesBatchMemoryBPE,
|
|
|
|
|
)
|
2025-09-30 13:33:40 +02:00
|
|
|
from ..Enums import TokenType
|
2025-10-01 12:21:42 +02:00
|
|
|
from ..Utils import (
|
|
|
|
|
special_regex_maker,
|
|
|
|
|
iterator_with_checks,
|
|
|
|
|
save_nanos_vocabulary,
|
|
|
|
|
load_nanos_vocabulary,
|
|
|
|
|
save_json,
|
|
|
|
|
load_json,
|
|
|
|
|
)
|
2025-09-30 13:33:40 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class NanoSocraTrainer:
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
max_vocabulary: int,
|
|
|
|
|
special_vocabulary: list[str],
|
|
|
|
|
chunk_size: int,
|
|
|
|
|
merge_treshold: int = 0,
|
|
|
|
|
max_iterations: int = 0,
|
2025-10-01 12:21:42 +02:00
|
|
|
print_after_iterations: int = 1,
|
2025-09-30 13:33:40 +02:00
|
|
|
) -> None:
|
|
|
|
|
# Bytes
|
|
|
|
|
BYTE_RESERVED_TOKENS = 256
|
|
|
|
|
SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
|
|
|
|
|
RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
|
|
|
|
|
|
|
|
|
|
self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
|
|
|
|
|
self.__max_iterations = max_iterations
|
|
|
|
|
self.__chunk_size = chunk_size
|
|
|
|
|
self.__merge_treshold = merge_treshold
|
|
|
|
|
self.__special_token_regex = special_regex_maker(special_vocabulary)
|
2025-09-30 23:58:31 +02:00
|
|
|
self.__print_after_iterations = print_after_iterations
|
2025-09-30 13:33:40 +02:00
|
|
|
|
|
|
|
|
def trainBPE(
|
2025-10-01 12:21:42 +02:00
|
|
|
self,
|
|
|
|
|
path: Path,
|
|
|
|
|
cache_dir: Path,
|
|
|
|
|
bpe: NanoSocratesBPE | None = None,
|
|
|
|
|
resume_from_iter: int = 0,
|
2025-09-30 13:33:40 +02:00
|
|
|
) -> NanoSocratesBPE:
|
|
|
|
|
|
|
|
|
|
if not path.is_file():
|
|
|
|
|
raise FileNotFoundError()
|
|
|
|
|
|
|
|
|
|
if not cache_dir.is_dir():
|
|
|
|
|
raise NotADirectoryError()
|
|
|
|
|
|
|
|
|
|
if bpe is None:
|
|
|
|
|
bpe = NanoSocratesBPE()
|
|
|
|
|
BPE = bpe
|
|
|
|
|
|
|
|
|
|
if BPE.vocabulary_size > self.__max_vocabulary:
|
|
|
|
|
return BPE
|
|
|
|
|
|
|
|
|
|
exit = False
|
|
|
|
|
cached = False
|
|
|
|
|
current_iteration = 0
|
2025-10-01 12:21:42 +02:00
|
|
|
input_path = path
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
NEXT_ITERATION = resume_from_iter + 1 if resume_from_iter != 0 else 0
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
PATH_GEN = self.__switch_paths(path, cache_dir, NEXT_ITERATION)
|
|
|
|
|
MEMORY_PATH_GEN = self.__switch_memory(cache_dir, resume_from_iter)
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
if resume_from_iter != 0:
|
|
|
|
|
cached = True
|
|
|
|
|
current_iteration = resume_from_iter
|
|
|
|
|
input_path = next(PATH_GEN)
|
|
|
|
|
# UGLY: fixes a bug immediately, unfortunately
|
|
|
|
|
_, _ = next(MEMORY_PATH_GEN)
|
|
|
|
|
_, voc_cache_path = next(MEMORY_PATH_GEN)
|
|
|
|
|
vocabulary = load_nanos_vocabulary(voc_cache_path)
|
|
|
|
|
BPE = NanoSocratesBPE(vocabulary)
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
while not exit:
|
2025-09-30 13:33:40 +02:00
|
|
|
|
|
|
|
|
out_path = next(PATH_GEN)
|
2025-10-01 12:21:42 +02:00
|
|
|
internal_cache_path, vocabulary_cache = next(MEMORY_PATH_GEN)
|
|
|
|
|
|
2025-09-30 13:33:40 +02:00
|
|
|
current_iteration = self.__increment_counter(current_iteration)
|
|
|
|
|
LAST_VOC_SIZE = BPE.vocabulary_size
|
|
|
|
|
|
|
|
|
|
FILE = open(out_path, "w")
|
|
|
|
|
|
2025-09-30 23:58:31 +02:00
|
|
|
last_memory = None
|
2025-10-01 12:21:42 +02:00
|
|
|
|
2025-09-30 23:58:31 +02:00
|
|
|
for _, memory, output in self.__round_train(input_path, BPE, cached):
|
|
|
|
|
last_memory = memory
|
2025-09-30 13:33:40 +02:00
|
|
|
FILE.write(output)
|
|
|
|
|
|
|
|
|
|
FILE.close()
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
internal_cache = {
|
|
|
|
|
"finished_iter": current_iteration,
|
|
|
|
|
"read_from": f"{input_path}",
|
|
|
|
|
"wrote_to": f"{out_path}",
|
|
|
|
|
"at": datetime.datetime.now(datetime.timezone.utc).strftime(
|
|
|
|
|
"%Y-%m-%d %H:%M:%S.%f"
|
|
|
|
|
)[:-3],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VOCABULARY = BPE.vocabulary
|
|
|
|
|
|
|
|
|
|
save_json(internal_cache, internal_cache_path)
|
|
|
|
|
save_nanos_vocabulary(VOCABULARY, vocabulary_cache)
|
|
|
|
|
|
2025-09-30 13:33:40 +02:00
|
|
|
cached = True
|
|
|
|
|
input_path = out_path
|
|
|
|
|
|
|
|
|
|
NEW_VOC_SIZE = BPE.vocabulary_size
|
|
|
|
|
|
2025-09-30 23:58:31 +02:00
|
|
|
if current_iteration % self.__print_after_iterations == 0:
|
2025-10-01 12:21:42 +02:00
|
|
|
|
2025-09-30 23:58:31 +02:00
|
|
|
DELIMITER = "==============="
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
DEBUG = "\n".join(
|
|
|
|
|
[
|
|
|
|
|
DELIMITER,
|
|
|
|
|
f"ITERATION: {current_iteration}",
|
|
|
|
|
DELIMITER,
|
|
|
|
|
f"\tVocabulary size: {BPE.vocabulary_size}\n",
|
|
|
|
|
f"\tFrequencies:\n{last_memory.frequencies}\n", # type: ignore (pretty sure it's not None)
|
|
|
|
|
f"\tvocabulary:\n{BPE.vocabulary}",
|
|
|
|
|
DELIMITER,
|
|
|
|
|
"",
|
|
|
|
|
]
|
|
|
|
|
)
|
2025-09-30 23:58:31 +02:00
|
|
|
print(DEBUG)
|
|
|
|
|
|
2025-09-30 13:33:40 +02:00
|
|
|
if LAST_VOC_SIZE == NEW_VOC_SIZE:
|
|
|
|
|
exit = True
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if current_iteration == self.__max_iterations:
|
|
|
|
|
exit = True
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if BPE.vocabulary_size == self.__max_vocabulary:
|
|
|
|
|
exit = True
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
return BPE
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
    def __round_train(self, path: Path, bpe: NanoSocratesBPE, cached: bool):
        """One full training pass over the corpus at ``path``.

        Streams the corpus chunk by chunk, splits each chunk into pieces
        around the special tokens, and feeds only the BPE-typed pieces to
        ``bpe.fit``. Yields ``(bpe, memory, text)`` per piece, where
        ``text`` is what the caller writes to the round's output corpus.
        ``cached`` tells __make_list_ids whether pieces are raw text
        (first round) or repr'd id lists from a previous round.
        """

        CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)

        BPE = bpe
        # Fresh batch memory per round; presumably pairs below
        # __merge_treshold are ignored by fit — confirm in
        # NanoSocratesBatchMemoryBPE.
        memory = NanoSocratesBatchMemoryBPE({}, self.__merge_treshold)

        # iterator_with_checks pairs each item with an is-last flag
        # (assumed from usage below — TODO confirm in ..Utils).
        CHUNKER_GENERATOR = iterator_with_checks(CHUNKER.chunk(path))

        for chunk, last_chunk in CHUNKER_GENERATOR:

            PIECE_GENERATOR = iterator_with_checks(SPLITTER.split_text(chunk))

            for piece, last_piece in PIECE_GENERATOR:

                # True only for the very last piece of the whole corpus;
                # signals fit() to flush its accumulated batch state.
                LAST_BATCH = last_chunk and last_piece
                PIECE, TOKEN_TYPE = piece

                if TOKEN_TYPE != TokenType.BPE:
                    # Special tokens pass through verbatim; fit() is still
                    # called with no data so LAST_BATCH can flush.
                    # NOTE(review): `out` is discarded on this branch — if
                    # the corpus ends with a special token, flushed output
                    # may be lost; confirm against BPE.fit's contract.
                    _, _, out = BPE.fit([], memory, LAST_BATCH)
                    yield (BPE, memory, PIECE)
                    continue

                PIECE_DATA = self.__make_list_ids(PIECE, cached)

                _, _, out = BPE.fit(PIECE_DATA, memory, LAST_BATCH)

                # Written back as the repr of `out`, which
                # __make_list_ids parses again on the next (cached) round.
                OUT_STRING = f"{out}"

                yield (BPE, memory, OUT_STRING)
|
|
|
|
|
|
|
|
|
|
def __increment_counter(self, counter: int):
|
|
|
|
|
|
|
|
|
|
# What if overflows???
|
|
|
|
|
try:
|
|
|
|
|
counter += 1
|
|
|
|
|
except:
|
|
|
|
|
print("Integer overflow")
|
|
|
|
|
counter = 1
|
|
|
|
|
|
|
|
|
|
return counter
|
|
|
|
|
|
|
|
|
|
def __make_list_ids(self, corpus: str, cached: bool):
|
|
|
|
|
|
|
|
|
|
if not cached:
|
2025-09-30 23:58:31 +02:00
|
|
|
return list(corpus.encode("utf-8"))
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
REDUCED_CORPUS_LEN = len(corpus) - 1
|
2025-09-30 13:33:40 +02:00
|
|
|
|
|
|
|
|
# Skip these cars "[" "]"
|
|
|
|
|
INTS = corpus[1:REDUCED_CORPUS_LEN]
|
2025-10-01 12:21:42 +02:00
|
|
|
INT_LIST = list(map(int, INTS.split(",")))
|
2025-09-30 13:33:40 +02:00
|
|
|
return INT_LIST
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
def __switch_paths(self, path: Path, cache_path: Path, initial_iteration: int):
|
2025-09-30 13:33:40 +02:00
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
CORPUS_TMP_1 = cache_path / "corpus-tmp1.txt"
|
|
|
|
|
CORPUS_TMP_2 = cache_path / "corpus-tmp2.txt"
|
2025-09-30 13:33:40 +02:00
|
|
|
|
|
|
|
|
switch = True
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
if initial_iteration % 2 == 1:
|
|
|
|
|
switch = False
|
|
|
|
|
|
|
|
|
|
del initial_iteration
|
|
|
|
|
|
2025-09-30 13:33:40 +02:00
|
|
|
while True:
|
|
|
|
|
if switch:
|
2025-10-01 12:21:42 +02:00
|
|
|
yield CORPUS_TMP_1
|
2025-09-30 13:33:40 +02:00
|
|
|
else:
|
2025-10-01 12:21:42 +02:00
|
|
|
yield CORPUS_TMP_2
|
2025-09-30 13:33:40 +02:00
|
|
|
switch = not switch
|
|
|
|
|
|
2025-10-01 12:21:42 +02:00
|
|
|
def __switch_memory(self, cache_path: Path, initial_iteration: int):
|
|
|
|
|
|
|
|
|
|
INTERNAL_TMP_1 = cache_path / "internal-tmp1.json"
|
|
|
|
|
INTERNAL_TMP_2 = cache_path / "internal-tmp2.json"
|
|
|
|
|
|
|
|
|
|
VOCAB_TMP_1 = cache_path / "voc-tmp1.json"
|
|
|
|
|
VOCAB_TMP_2 = cache_path / "voc-tmp2.json"
|
|
|
|
|
|
|
|
|
|
switch = False
|
|
|
|
|
|
|
|
|
|
if initial_iteration % 2 == 1:
|
|
|
|
|
switch = True
|
|
|
|
|
|
|
|
|
|
del initial_iteration
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
if switch:
|
|
|
|
|
yield (INTERNAL_TMP_1, VOCAB_TMP_1)
|
|
|
|
|
else:
|
|
|
|
|
yield (INTERNAL_TMP_2, VOCAB_TMP_2)
|
|
|
|
|
switch = not switch
|