diff --git a/Project_Model/Libs/BPE/Classes/NanoSocraTrainer.py b/Project_Model/Libs/BPE/Classes/NanoSocraTrainer.py
new file mode 100644
index 0000000..1d6d429
--- /dev/null
+++ b/Project_Model/Libs/BPE/Classes/NanoSocraTrainer.py
@@ -0,0 +1,150 @@
+from pathlib import Path
+from ..Classes import NanoSocratesBPE, NanoSocratesChunker, NanoSocratesSplitter, NanoSocratesBatchMemoryBPE
+from ..Enums import TokenType
+from ..Utils import special_regex_maker, iterator_with_checks
+
+
+class NanoSocraTrainer:
+
+    def __init__(
+        self,
+        max_vocabulary: int,
+        special_vocabulary: list[str],
+        chunk_size: int,
+        merge_threshold: int = 0,
+        max_iterations: int = 0,
+    ) -> None:
+        # 256 single-byte tokens plus one id per special token are reserved.
+        BYTE_RESERVED_TOKENS = 256
+        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
+        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
+
+        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
+        self.__max_iterations = max_iterations
+        self.__chunk_size = chunk_size
+        self.__merge_threshold = merge_threshold
+        self.__special_token_regex = special_regex_maker(special_vocabulary)
+
+    def trainBPE(
+        self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
+    ) -> NanoSocratesBPE:
+
+        if not path.is_file():
+            raise FileNotFoundError(path)
+
+        if not cache_dir.is_dir():
+            raise NotADirectoryError(cache_dir)
+
+        if bpe is None:
+            bpe = NanoSocratesBPE()
+        BPE = bpe
+
+        if BPE.vocabulary_size >= self.__max_vocabulary:
+            return BPE
+
+        cached = False
+        current_iteration = 0
+
+        PATH_GEN = self.__switch_paths(path, cache_dir)
+        input_path = next(PATH_GEN)
+
+        while True:
+
+            out_path = next(PATH_GEN)
+            # Python ints are arbitrary precision, so this cannot overflow.
+            current_iteration += 1
+            LAST_VOC_SIZE = BPE.vocabulary_size
+
+            with open(out_path, "w") as file:
+                for _, _, output in self.__round_train(input_path, BPE, cached):
+                    file.write(output)
+
+            cached = True
+            input_path = out_path
+
+            NEW_VOC_SIZE = BPE.vocabulary_size
+
+            # Stop once a full pass adds no new merges (convergence).
+            if LAST_VOC_SIZE == NEW_VOC_SIZE:
+                break
+
+            # Stop once the iteration budget is spent (0 means unlimited).
+            if self.__max_iterations and current_iteration >= self.__max_iterations:
+                break
+
+            # Stop once the vocabulary is full.
+            if BPE.vocabulary_size >= self.__max_vocabulary:
+                break
+
+        return BPE
+
+    def __round_train(
+        self,
+        path: Path,
+        bpe: NanoSocratesBPE,
+        cached: bool,
+    ):
+
+        CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
+        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
+
+        BPE = bpe
+        memory = NanoSocratesBatchMemoryBPE({}, self.__merge_threshold)
+
+        CHUNKER_GENERATOR = iterator_with_checks(CHUNKER.chunk(path))
+
+        for chunk, last_chunk in CHUNKER_GENERATOR:
+
+            PIECE_GENERATOR = iterator_with_checks(
+                SPLITTER.split_text(chunk)
+            )
+
+            for piece, last_piece in PIECE_GENERATOR:
+
+                LAST_BATCH = last_chunk and last_piece
+                PIECE, TOKEN_TYPE = piece
+
+                if TOKEN_TYPE != TokenType.BPE:
+                    # Special tokens are never merged: flush pending pair
+                    # counts with an empty batch and pass the piece through.
+                    BPE.fit([], memory, LAST_BATCH)
+                    yield (BPE, memory, PIECE)
+                    continue
+
+                PIECE_DATA = self.__make_list_ids(PIECE, cached)
+
+                _, _, out = BPE.fit(PIECE_DATA, memory, LAST_BATCH)
+
+                yield (BPE, memory, str(out))
+
+    def __make_list_ids(self, corpus: str, cached: bool):
+
+        # First pass: the corpus is raw text, so map each character to
+        # its code point.
+        if not cached:
+            return list(map(ord, corpus))
+
+        # Later passes read the cache, where ids are rendered as a list
+        # literal such as "[104, 105]". Skip the "[" and "]" and parse.
+        INTS = corpus[1:-1]
+        if not INTS:
+            return []
+        return list(map(int, INTS.split(",")))
+
+    def __switch_paths(self, path: Path, cache_path: Path):
+
+        # The first round reads the original corpus; afterwards, ping-pong
+        # between two cache files so each round reads the previous output.
+        yield path
+
+        TMP_1 = cache_path / "tmp1.txt"
+        TMP_2 = cache_path / "tmp2.txt"
+
+        switch = True
+
+        while True:
+            if switch:
+                yield TMP_1
+            else:
+                yield TMP_2
+            switch = not switch
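
A minimal usage sketch of the new trainer, for reviewers. Assumptions: the package is importable as Project_Model.Libs.BPE.Classes.NanoSocraTrainer, and every constructor value, special token, and file path below is illustrative rather than taken from the repository.

    from pathlib import Path

    from Project_Model.Libs.BPE.Classes.NanoSocraTrainer import NanoSocraTrainer

    # Hypothetical settings; real values depend on the target model.
    trainer = NanoSocraTrainer(
        max_vocabulary=1024,
        special_vocabulary=["<pad>", "<eos>"],
        chunk_size=4096,
        merge_threshold=2,
        max_iterations=10,
    )

    # corpus.txt is the training text; cache/ must already exist, since
    # trainBPE raises NotADirectoryError otherwise, and it will hold the
    # two ping-pong temp files written between rounds.
    bpe = trainer.trainBPE(Path("corpus.txt"), Path("cache"))
    print(bpe.vocabulary_size)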