from collections import deque
import datetime
from pathlib import Path
import re

from ..Classes import (
    NanoSocratesBPE,
    NanoSocratesChunker,
    NanoSocratesSplitter,
    NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import (
    special_regex_maker,
    iterator_with_checks,
    save_nanos_vocabulary,
    load_nanos_vocabulary,
    save_json,
    load_json,
)


class NanoSocraTrainer:
    """Grows a NanoSocratesBPE vocabulary over a corpus in repeated passes,
    caching intermediate state on disk so training can be resumed."""

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        chunk_size: int,
        merge_treshold: int = 0,
        max_iterations: int = 0,
        print_after_iterations: int = 1,
    ) -> None:
        # Every byte value gets a reserved id, plus one id per special token.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__chunk_size = chunk_size
        self.__merge_treshold = merge_treshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)
        self.__print_after_iterations = print_after_iterations

    def trainBPE(
        self,
        path: Path,
        cache_dir: Path,
        bpe: NanoSocratesBPE | None = None,
        resume_from_iter: int = 0,
    ) -> NanoSocratesBPE:
        """Train (or resume training of) a NanoSocratesBPE model on the corpus
        at `path`, writing per-iteration caches into `cache_dir`."""
        if not path.is_file():
            raise FileNotFoundError(path)

        if not cache_dir.is_dir():
            raise NotADirectoryError(cache_dir)

        if bpe is None:
            bpe = NanoSocratesBPE()

        BPE = bpe

        if BPE.vocabulary_size > self.__max_vocabulary:
            return BPE

        exit = False
        cached = False
        current_iteration = 0
        input_path = path

        # Parity decides which ping-pong corpus file the run starts from.
        NEXT_ITERATION = resume_from_iter + 1 if resume_from_iter != 0 else 0
        PATH_GEN = self.__switch_paths(path, cache_dir, NEXT_ITERATION)
        MEMORY_PATH_GEN = self.__switch_memory(cache_dir, resume_from_iter)

        if resume_from_iter != 0:
            cached = True
            current_iteration = resume_from_iter
            input_path = next(PATH_GEN)

            # UGLY: fixes a bug immediately, unfortunately
            _, _ = next(MEMORY_PATH_GEN)
            _, voc_cache_path = next(MEMORY_PATH_GEN)

            vocabulary = load_nanos_vocabulary(voc_cache_path)
            BPE = NanoSocratesBPE(vocabulary)

        while not exit:
            out_path = next(PATH_GEN)
            internal_cache_path, vocabulary_cache = next(MEMORY_PATH_GEN)

            current_iteration = self.__increment_counter(current_iteration)
            LAST_VOC_SIZE = BPE.vocabulary_size

            # Re-encode the whole corpus with the current vocabulary while
            # collecting merge statistics for the next vocabulary update.
            last_memory = None
            with open(out_path, "w") as FILE:
                for _, memory, output in self.__round_train(input_path, BPE, cached):
                    last_memory = memory
                    FILE.write(output)

            internal_cache = {
                "finished_iter": current_iteration,
                "read_from": f"{input_path}",
                "wrote_to": f"{out_path}",
                "at": datetime.datetime.now(datetime.timezone.utc).strftime(
                    "%Y-%m-%d %H:%M:%S.%f"
                )[:-3],
            }

            VOCABULARY = BPE.vocabulary

            save_json(internal_cache, internal_cache_path)
            save_nanos_vocabulary(VOCABULARY, vocabulary_cache)

            cached = True
            input_path = out_path
            NEW_VOC_SIZE = BPE.vocabulary_size

            if current_iteration % self.__print_after_iterations == 0:
                DELIMITER = "==============="
                DEBUG = "\n".join(
                    [
                        DELIMITER,
                        f"ITERATION: {current_iteration}",
                        DELIMITER,
                        f"\tVocabulary size: {BPE.vocabulary_size}\n",
                        f"\tFrequencies:\n{last_memory.frequencies}\n",  # type: ignore (pretty sure it's not None)
                        f"\tvocabulary:\n{BPE.vocabulary}",
                        DELIMITER,
                        "",
                    ]
                )
                print(DEBUG)

            # Stop when no new merge was learned, the iteration budget is
            # spent, or the vocabulary reached its target size.
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                exit = True
                continue

            if current_iteration == self.__max_iterations:
                exit = True
                continue

            if BPE.vocabulary_size == self.__max_vocabulary:
                exit = True
                continue

        return BPE

    def __round_train(self, path: Path, bpe: NanoSocratesBPE, cached: bool):
        """Stream one pass over the corpus: chunk it, split out special tokens,
        and feed BPE-eligible pieces to the fitter, yielding (bpe, memory, output)."""
        CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
        BPE = bpe
        memory = NanoSocratesBatchMemoryBPE({}, self.__merge_treshold)

        CHUNKER_GENERATOR = iterator_with_checks(CHUNKER.chunk(path))

        for chunk, last_chunk in CHUNKER_GENERATOR:
            PIECE_GENERATOR = iterator_with_checks(SPLITTER.split_text(chunk))

            for piece, last_piece in PIECE_GENERATOR:
                LAST_BATCH = last_chunk and last_piece
                PIECE, TOKEN_TYPE = piece

                # Special tokens pass through untouched; still call the fitter
                # so it can flush its state on the last batch.
                if TOKEN_TYPE != TokenType.BPE:
                    _, _, out = BPE.fit([], memory, LAST_BATCH)
                    yield (BPE, memory, PIECE)
                    continue

                PIECE_DATA = self.__make_list_ids(PIECE, cached)
                _, _, out = BPE.fit(PIECE_DATA, memory, LAST_BATCH)
                OUT_STRING = f"{out}"

                yield (BPE, memory, OUT_STRING)

    def __increment_counter(self, counter: int):
        # NOTE: Python ints are arbitrary precision, so this cannot actually
        # overflow; the guard is kept defensively.
        try:
            counter += 1
        except OverflowError:
            print("Integer overflow")
            counter = 1

        return counter

    def __make_list_ids(self, corpus: str, cached: bool):
        """First pass: raw UTF-8 bytes. Later passes: parse the token-id list
        serialized by the previous pass (e.g. "[1, 2, 3]")."""
        if not cached:
            return list(corpus.encode("utf-8"))

        REDUCED_CORPUS_LEN = len(corpus) - 1

        # Skip the surrounding "[" and "]" characters
        INTS = corpus[1:REDUCED_CORPUS_LEN]
        INT_LIST = list(map(int, INTS.split(",")))

        return INT_LIST

    def __switch_paths(self, path: Path, cache_path: Path, initial_iteration: int):
        """Alternate forever between the two temporary corpus files (ping-pong
        buffering), starting from the parity of `initial_iteration`."""
        CORPUS_TMP_1 = cache_path / "corpus-tmp1.txt"
        CORPUS_TMP_2 = cache_path / "corpus-tmp2.txt"

        switch = True

        if initial_iteration % 2 == 1:
            switch = False

        del initial_iteration

        while True:
            if switch:
                yield CORPUS_TMP_1
            else:
                yield CORPUS_TMP_2

            switch = not switch

    def __switch_memory(self, cache_path: Path, initial_iteration: int):
        """Alternate forever between the two pairs of internal/vocabulary cache
        files, starting from the parity of `initial_iteration`."""
        INTERNAL_TMP_1 = cache_path / "internal-tmp1.json"
        INTERNAL_TMP_2 = cache_path / "internal-tmp2.json"
        VOCAB_TMP_1 = cache_path / "voc-tmp1.json"
        VOCAB_TMP_2 = cache_path / "voc-tmp2.json"

        switch = False

        if initial_iteration % 2 == 1:
            switch = True

        del initial_iteration

        while True:
            if switch:
                yield (INTERNAL_TMP_1, VOCAB_TMP_1)
            else:
                yield (INTERNAL_TMP_2, VOCAB_TMP_2)

            switch = not switch
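

# ---------------------------------------------------------------------------
# Usage sketch (not part of the class above): a minimal, illustrative example
# of how the trainer can be driven. The corpus path, cache directory, and
# special vocabulary below are placeholder values, not ones defined by this
# project; only the constructor and trainBPE signatures come from the code.
#
#   trainer = NanoSocraTrainer(
#       max_vocabulary=1024,
#       special_vocabulary=["<PAD>", "<EOS>"],
#       chunk_size=4096,
#       merge_treshold=2,
#       max_iterations=100,
#       print_after_iterations=10,
#   )
#
#   # Fresh run: reads the raw corpus and writes ping-pong caches per pass.
#   bpe = trainer.trainBPE(Path("corpus.txt"), Path("cache"))
#
#   # Resuming: pass the iteration recorded in internal-tmp*.json so the
#   # trainer picks up the cached corpus and vocabulary files.
#   # bpe = trainer.trainBPE(Path("corpus.txt"), Path("cache"), resume_from_iter=4)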