NanoSocrates/Project_Model/Libs/BPE/Classes/NanoSocraTrainer.py

from pathlib import Path

from ..Classes import NanoSocratesBPE, NanoSocratesChunker, NanoSocratesSplitter, NanoSocratesBatchMemoryBPE
from ..Enums import TokenType
from ..Utils import special_regex_maker, iterator_with_checks
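
# Helper contracts assumed from their use below (not verified against Utils):
# - special_regex_maker(tokens) builds a regex matching any of the special tokens.
# - iterator_with_checks(it) wraps an iterator and yields (item, is_last) pairs,
#   e.g. two items "a", "b" come out as ("a", False) then ("b", True).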

class NanoSocraTrainer:

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        chunk_size: int,
        merge_treshold: int = 0,
        max_iterations: int = 0,
    ) -> None:
        # Reserve one id per raw byte and one per special token; only the
        # remainder of the budget is available for learned merges.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__chunk_size = chunk_size
        self.__merge_treshold = merge_treshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)
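
        # Example budget (hypothetical numbers): max_vocabulary=1024 with
        # 2 special tokens would leave 1024 - (256 + 2) = 766 ids for merges.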

    def trainBPE(
        self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
    ) -> NanoSocratesBPE:

        if not path.is_file():
            raise FileNotFoundError(path)

        if not cache_dir.is_dir():
            raise NotADirectoryError(cache_dir)

        if bpe is None:
            bpe = NanoSocratesBPE()

        BPE = bpe

        if BPE.vocabulary_size > self.__max_vocabulary:
            return BPE

        done = False
        cached = False
        current_iteration = 0
        PATH_GEN = self.__switch_paths(path, cache_dir)
        input_path = next(PATH_GEN)

        while not done:
            out_path = next(PATH_GEN)
            current_iteration = self.__increment_counter(current_iteration)
            LAST_VOC_SIZE = BPE.vocabulary_size

            with open(out_path, "w") as out_file:
                for _, _, output in self.__round_train(input_path, BPE, cached):
                    out_file.write(output)

            # From the second round on, the input file holds cached token ids
            # rather than raw text.
            cached = True
            input_path = out_path
            NEW_VOC_SIZE = BPE.vocabulary_size

            # Stop when a round learns no new merges, the iteration budget is
            # spent, or the vocabulary budget is reached.
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                done = True
                continue

            if current_iteration == self.__max_iterations:
                done = True
                continue

            if BPE.vocabulary_size >= self.__max_vocabulary:
                done = True
                continue

        return BPE
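
    # One training round: stream the input, feed BPE-eligible pieces to the
    # model, and yield each piece's output so the caller can write it out and
    # the next round can re-read it as token ids instead of raw text.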

    def __round_train(
        self,
        path: Path,
        bpe: NanoSocratesBPE,
        cached: bool,
    ):
        CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)

        BPE = bpe
        memory = NanoSocratesBatchMemoryBPE({}, self.__merge_treshold)

        CHUNKER_GENERATOR = iterator_with_checks(CHUNKER.chunk(path))

        for chunk, last_chunk in CHUNKER_GENERATOR:
            PIECE_GENERATOR = iterator_with_checks(
                SPLITTER.split_text(chunk)
            )

            for piece, last_piece in PIECE_GENERATOR:
                LAST_BATCH = last_chunk and last_piece
                PIECE, TOKEN_TYPE = piece

                if TOKEN_TYPE != TokenType.BPE:
                    # Special tokens are never merged: flush the batch memory
                    # and pass the piece through verbatim.
                    BPE.fit([], memory, LAST_BATCH)
                    yield (BPE, memory, PIECE)
                    continue

                PIECE_DATA = self.__make_list_ids(PIECE, cached)
                _, _, out = BPE.fit(PIECE_DATA, memory, LAST_BATCH)
                OUT_STRING = f"{out}"

                yield (BPE, memory, OUT_STRING)

    def __increment_counter(self, counter: int) -> int:
        # Python integers are arbitrary precision, so this cannot overflow;
        # the increment needs no guard.
        return counter + 1

    def __make_list_ids(self, corpus: str, cached: bool) -> list[int]:
        if not cached:
            # First round: raw text, so use the code points as initial ids
            return list(map(ord, corpus))

        # Later rounds: the piece is the string form of an id list.
        # Skip the surrounding "[" and "]" characters before splitting.
        REDUCED_CORPUS_LEN = len(corpus) - 1
        INTS = corpus[1:REDUCED_CORPUS_LEN]
        INT_LIST = list(map(int, INTS.split(",")))
        return INT_LIST
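
    # Example (hypothetical values): uncached, "hi" -> [104, 105];
    # cached, "[104, 105]" -> [104, 105].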

    def __switch_paths(self, path: Path, cache_path: Path):
        # The first round reads the original corpus; afterwards, ping-pong
        # between two temporary files in the cache directory.
        yield path

        TMP_1 = cache_path / "tmp1.txt"
        TMP_2 = cache_path / "tmp2.txt"
        switch = True

        while True:
            if switch:
                yield TMP_1
            else:
                yield TMP_2
            switch = not switch
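

# Minimal usage sketch (hypothetical paths and parameter values; the module
# uses relative imports, so it must be imported as part of its package):
#
#   from pathlib import Path
#
#   trainer = NanoSocraTrainer(
#       max_vocabulary=1024,
#       special_vocabulary=["<|startoftext|>", "<|endoftext|>"],
#       chunk_size=1 << 20,
#       merge_treshold=2,
#       max_iterations=10,
#   )
#   bpe = trainer.trainBPE(Path("corpus.txt"), Path("cache"))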