from pathlib import Path

from ..Classes import (
    NanoSocratesBPE,
    NanoSocratesSplitter,
    NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import special_regex_maker


class NanoSocraTraineRam:
    """Trains a NanoSocratesBPE tokenizer while holding the whole corpus in RAM."""

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        merge_threshold: int = 0,
        max_iterations: int = 0,
        print_after_iterations: int = 1,
    ) -> None:
        # Every possible byte value gets a reserved token, plus one token per
        # special-vocabulary entry; only the remainder is learnable via merges.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__merge_threshold = merge_threshold  # stored but currently unused by trainBPE
        self.__special_token_regex = special_regex_maker(special_vocabulary)
        self.__print_after_iterations = print_after_iterations

    def trainBPE(
        self,
        path: Path,
        bpe: NanoSocratesBPE | None = None,
    ) -> NanoSocratesBPE:
        """Runs BPE training rounds on `path` until the vocabulary stops growing,
        the iteration cap is hit, or the vocabulary budget is exhausted."""

        if not path.is_file():
            raise FileNotFoundError(path)

        if bpe is None:
            bpe = NanoSocratesBPE()

        BPE = bpe

        if BPE.vocabulary_size >= self.__max_vocabulary:
            return BPE

        current_iteration = 0
        data = self.__gather_data_from_file(path)

        while True:
            current_iteration = self.__increment_counter(current_iteration)
            LAST_VOC_SIZE = BPE.vocabulary_size
            # trainBPE relies on fit() updating BPE in place, so the returned
            # tokenizer reference is deliberately ignored here.
            _, data, last_memory = self.__round_train(BPE, data)
            NEW_VOC_SIZE = BPE.vocabulary_size

            if current_iteration % self.__print_after_iterations == 0:
                DELIMITER = "==============="
                DEBUG = "\n".join(
                    [
                        DELIMITER,
                        f"ITERATION: {current_iteration}",
                        DELIMITER,
                        f"\tVocabulary size: {BPE.vocabulary_size}\n",
                        f"\tFrequencies:\n{last_memory.frequencies}\n",
                        f"\tVocabulary:\n{BPE.vocabulary}",
                        DELIMITER,
                        "",
                    ]
                )
                print(DEBUG)

            # Stop when a round produces no new merges...
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                break
            # ...or the iteration cap is reached (0 means no cap)...
            if self.__max_iterations and current_iteration >= self.__max_iterations:
                break
            # ...or the learnable vocabulary budget is exhausted.
            if BPE.vocabulary_size >= self.__max_vocabulary:
                break

        return BPE

    def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
        """Feeds every corpus piece through one fit() pass, flagging the final
        piece so the tokenizer can close the batch and apply its merges."""
        DATA_LEN = len(data)
        memory = NanoSocratesBatchMemoryBPE({}, 0)

        for index, piece in enumerate(data):
            last_batch = index == DATA_LEN - 1
            bpe, memory, output = bpe.fit(piece, memory, last_batch)
            data[index] = output

        return (bpe, data, memory)

    def __gather_data_from_file(self, path: Path) -> list[list[int]]:
        """Splits the file on special tokens and keeps only the pieces the
        splitter marks as BPE-trainable, encoded as UTF-8 byte-id lists."""
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
        DATA: list[list[int]] = []

        with open(path, "r", encoding="utf-8") as file:
            file_string = file.read()

        for piece, token_type in SPLITTER.split_text(file_string):
            if token_type != TokenType.BPE:
                continue
            DATA.append(self.__make_list_ids(piece))

        return DATA

    def __increment_counter(self, counter: int) -> int:
        # Python integers have arbitrary precision, so the overflow guard that
        # used to wrap this increment was dead code.
        return counter + 1

    def __make_list_ids(self, corpus: str) -> list[int]:
        return list(corpus.encode("utf-8"))
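
# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative, not part of the
# trainer: "corpus.txt" and the special-token strings are hypothetical, and
# because this module uses relative imports it must be launched as part of
# its package, e.g. `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    TRAINER = NanoSocraTraineRam(
        max_vocabulary=1024,  # total size, including the 256 byte tokens and the specials
        special_vocabulary=["<pad>", "<eos>"],  # hypothetical special tokens
        max_iterations=100,  # 0 would mean "no iteration cap"
        print_after_iterations=10,
    )
    BPE = TRAINER.trainBPE(Path("corpus.txt"))  # hypothetical corpus path
    print(f"Final vocabulary size: {BPE.vocabulary_size}")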