from collections import deque
import datetime
import itertools
from multiprocessing import Pool
import os
from pathlib import Path
import re
import time

from ..Classes import (
    NanoSocratesBPE,
    NanoSocratesChunker,
    NanoSocratesSplitter,
    NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import (
    special_regex_maker,
    iterator_with_checks,
    save_nanos_vocabulary,
    load_nanos_vocabulary,
    save_json,
    load_json,
)


def split(a, n):
    # Split a sequence into n contiguous chunks of (nearly) equal size
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))


def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
    # Worker: run one fitting pass of a BPE copy over its chunk of data,
    # collecting pair frequencies and keeping only pieces that can still be merged
    bpe, data = object
    NEW_DATA: list[list[int]] = []
    memory = NanoSocratesBatchMemoryBPE({}, 0)

    while len(data) > 0:
        piece = data.pop()
        bpe, memory, output = bpe.fit(piece, memory, False)

        if len(output) < 2:
            continue

        # We are sure of its type
        NEW_DATA.append(piece)  # type: ignore

    return (bpe, NEW_DATA, memory)


def split_encode(object: tuple[NanoSocratesBPE, list[list[int]]]):
    # Worker: encode a chunk with the current vocabulary and keep only pieces
    # that still encode to two or more tokens
    bpe, data = object
    NEW_DATA: list[list[int]] = []

    for index, piece in enumerate(data):
        output = bpe.encode_intermediate(piece)

        if len(output) < 2:
            continue

        # We are sure of its type
        NEW_DATA.append(data[index])  # type: ignore

    return NEW_DATA


class NanoSocraTrainerPool:
    """Trains a NanoSocratesBPE vocabulary over a corpus using a multiprocessing pool."""

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        merge_treshold: int = 0,
        max_iterations: int = 0,
        print_after_iterations: int = 1,
    ) -> None:
        # Bytes
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__merge_treshold = merge_treshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)
        self.__print_after_iterations = print_after_iterations

    # TODO: add a resume function
    def trainBPE(
        self,
        path: Path,
        cache_file: Path,
        bpe: NanoSocratesBPE | None = None,
    ) -> NanoSocratesBPE:

        if not path.is_file():
            raise FileNotFoundError(path)

        if not cache_file.is_file():
            file = cache_file.open("w")
            file.close()

        if bpe is None:
            bpe = NanoSocratesBPE()

        BPE = bpe

        if BPE.vocabulary_size > self.__max_vocabulary:
            return BPE

        exit = False
        current_iteration = 0

        data = self.__gather_data_from_file(path)
        data = self.__encode_from_cache(BPE, data)

        while not exit:
            current_iteration = self.__increment_counter(current_iteration)
            LAST_VOC_SIZE = BPE.vocabulary_size
            last_memory = None

            start = time.time_ns()
            _, data, last_memory = self.__round_train(BPE, data)
            end = time.time_ns()

            NEW_VOC_SIZE = BPE.vocabulary_size
            VOCABULARY = BPE.vocabulary

            save_nanos_vocabulary(VOCABULARY, cache_file)

            if current_iteration % self.__print_after_iterations == 0:
                DELIMITER = "==============="
                DEBUG = "\n".join(
                    [
                        DELIMITER,
                        f"ITERATION: {current_iteration}",
                        DELIMITER,
                        f"\tVocabulary size: {BPE.vocabulary_size - 256}\n",
                        f"\tTime elapsed: {(end - start) / 1E9}s",
                        DELIMITER,
                        "",
                    ]
                )
                print(DEBUG)

            # Stop when the vocabulary stopped growing, the iteration cap is
            # reached, or the target vocabulary size is hit
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                exit = True
                continue

            if current_iteration == self.__max_iterations:
                exit = True
                continue

            if BPE.vocabulary_size == self.__max_vocabulary:
                exit = True
                continue

        return BPE

    def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
        # One parallel training round: fit a BPE copy on each chunk, merge the
        # per-chunk pair frequencies, then derive the new token(s) on the shared instance
        NEW_DATA: list[list[int]] = []
        MEMORY = NanoSocratesBatchMemoryBPE({}, self.__merge_treshold)
        fit_funct = split_fit

        CPU_COUNT = os.process_cpu_count()
        if CPU_COUNT is None:
            raise Exception("Could not determine the number of usable CPUs")

        VOCABULARY = bpe.vocabulary
        data_chunks = split(data, CPU_COUNT)
        JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]

        JOB_RESULTS: list[
            tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]
        ]

        with Pool() as pool:
            JOB_RESULTS = pool.map(fit_funct, JOBS)

        for i, res in enumerate(JOB_RESULTS):
            _, job_output, job_memory = res
            NEW_DATA.extend(job_output)

            # Accumulate this worker's pair frequencies into the shared memory
            for key, value in job_memory.frequencies.items():
                MEMORY.frequencies[key] = MEMORY.frequencies.get(key, 0) + value

            del job_output
            del job_memory
            print(f"Joined {i + 1} out of {CPU_COUNT}")

        # Get new token
        bpe.fit([], MEMORY, True)
        print(f"Sentences from {len(data)} to {len(NEW_DATA)}")
        return (bpe, NEW_DATA, MEMORY)

    def __gather_data_from_file(self, path: Path) -> list[list[int]]:
        # Read the corpus, split out special tokens, and keep only BPE pieces
        # as lists of UTF-8 byte ids
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
        DATA: list[list[int]] = []

        with open(path, "r", encoding="utf-8") as file:
            file_string = file.read()

        for piece, type in SPLITTER.split_text(file_string):
            if type != TokenType.BPE:
                continue

            int_list = self.__make_list_ids(piece)
            DATA.append(int_list)

        return DATA

    def __encode_from_cache(self, bpe: NanoSocratesBPE, data: list[list[int]]):
        # Pre-encode the corpus in parallel with the cached vocabulary, dropping
        # pieces that already collapse to fewer than two tokens
        NEW_DATA: list[list[int]] = []

        CPU_COUNT = os.process_cpu_count()
        if CPU_COUNT is None:
            raise Exception("Could not determine the number of usable CPUs")

        VOCABULARY = bpe.vocabulary
        data_chunks = split(data, CPU_COUNT)
        JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]

        JOB_RESULTS: list[list[list[int]]]

        with Pool() as pool:
            JOB_RESULTS = pool.map(split_encode, JOBS)

        for i, res in enumerate(JOB_RESULTS):
            job_output = res
            NEW_DATA.extend(job_output)
            del job_output
            print(f"Joined {i + 1} out of {CPU_COUNT}")

        print(f"Sentences from {len(data)} to {len(NEW_DATA)}")
        return NEW_DATA

    def __increment_counter(self, counter: int):
        # Python integers have arbitrary precision, so overflow is not a concern here
        return counter + 1

    def __make_list_ids(self, corpus: str):
        # UTF-8 byte ids for a piece of text
        return list(corpus.encode("utf-8"))
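

# Example usage (a minimal sketch: the vocabulary size, special tokens, and file
# paths below are illustrative placeholders, not values defined by this module):
#
#   from pathlib import Path
#
#   trainer = NanoSocraTrainerPool(
#       max_vocabulary=4096,
#       special_vocabulary=["<|start|>", "<|end|>"],
#       merge_treshold=2,
#       max_iterations=100,
#       print_after_iterations=10,
#   )
#   bpe = trainer.trainBPE(Path("corpus.txt"), Path("vocabulary_cache.json"))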