Added multithreaded training

This commit is contained in:
Christian Risi 2025-10-02 01:30:24 +02:00
parent b80b4e4112
commit 63baf29805
2 changed files with 309 additions and 0 deletions

View File

@ -0,0 +1,219 @@
from collections import deque
import datetime
import itertools
from multiprocessing import Pool
import os
from pathlib import Path
import re
from ..Classes import (
NanoSocratesBPE,
NanoSocratesChunker,
NanoSocratesSplitter,
NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import (
special_regex_maker,
iterator_with_checks,
save_nanos_vocabulary,
load_nanos_vocabulary,
save_json,
load_json,
)
def split(a, n):
k, m = divmod(len(a), n)
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
bpe, data = object
NEW_DATA: list[list[int]]= []
memory = NanoSocratesBatchMemoryBPE({}, 0)
while len(data) > 0:
piece = data.pop()
bpe, memory, output = bpe.fit(piece, memory, False)
if len(output) < 2:
continue
# We are sure of its type
NEW_DATA.append(output) # type: ignore
return (bpe, NEW_DATA, memory)
class NanoSocraTrainerPool:
def __init__(
self,
max_vocabulary: int,
special_vocabulary: list[str],
merge_treshold: int = 0,
max_iterations: int = 0,
print_after_iterations: int = 1,
) -> None:
# Bytes
BYTE_RESERVED_TOKENS = 256
SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
self.__max_iterations = max_iterations
self.__merge_treshold = merge_treshold
self.__special_token_regex = special_regex_maker(special_vocabulary)
self.__print_after_iterations = print_after_iterations
# TODO: add a resume function
def trainBPE(
self,
path: Path,
cache_file: Path,
bpe: NanoSocratesBPE | None = None,
) -> NanoSocratesBPE:
if not path.is_file():
raise FileNotFoundError()
if not cache_file.is_file():
file = cache_file.open("w")
file.close()
if bpe is None:
bpe = NanoSocratesBPE()
BPE = bpe
if BPE.vocabulary_size > self.__max_vocabulary:
return BPE
exit = False
current_iteration = 0
data = self.__gather_data_from_file(path)
while not exit:
current_iteration = self.__increment_counter(current_iteration)
LAST_VOC_SIZE = BPE.vocabulary_size
last_memory = None
_, data, last_memory = self.__round_train(BPE, data)
NEW_VOC_SIZE = BPE.vocabulary_size
VOCABULARY = BPE.vocabulary
save_nanos_vocabulary(VOCABULARY, cache_file)
if current_iteration % self.__print_after_iterations == 0:
DELIMITER = "==============="
DEBUG = "\n".join(
[
DELIMITER,
f"ITERATION: {current_iteration}",
DELIMITER,
f"\tVocabulary size: {BPE.vocabulary_size}\n",
f"\tvocabulary:\n{BPE.vocabulary}",
DELIMITER,
"",
]
)
print(DEBUG)
if LAST_VOC_SIZE == NEW_VOC_SIZE:
exit = True
continue
if current_iteration == self.__max_iterations:
exit = True
continue
if BPE.vocabulary_size == self.__max_vocabulary:
exit = True
continue
return BPE
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
NEW_DATA : list[list[int]] = []
MEMORY = NanoSocratesBatchMemoryBPE({}, 0)
fit_funct = split_fit
CPU_COUNT = os.process_cpu_count()
if CPU_COUNT is None:
raise Exception()
VOCABULARY = bpe.vocabulary
data_chunks = split(data, CPU_COUNT)
JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]
JOB_RESULTS: list[tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]]
with Pool() as pool:
JOB_RESULTS = pool.map(fit_funct, JOBS)
for i, res in zip(range(0, CPU_COUNT), JOB_RESULTS):
_, job_output, job_memory = res
NEW_DATA.extend(job_output)
for key, value in job_memory.frequencies.items():
MEMORY.frequencies[key] = value
del job_output
del job_memory
print(f"Joined {i + 1} out of {CPU_COUNT}")
# Get new token
bpe.fit([], MEMORY, True)
print(f"Sentences from {len(data)} to {len(NEW_DATA)}")
return (bpe, NEW_DATA, MEMORY)
def __gather_data_from_file(self, path: Path) -> list[list[int]]:
SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
DATA: list[list[int]] = []
FILE = open(path, "r", encoding="utf-8")
file_string = FILE.read()
FILE.close()
for piece, type in SPLITTER.split_text(file_string):
if type != TokenType.BPE:
continue
int_list = self.__make_list_ids(piece)
DATA.append(int_list)
return DATA
def __increment_counter(self, counter: int):
# What if overflows???
try:
counter += 1
except:
print("Integer overflow")
counter = 1
return counter
def __make_list_ids(self, corpus: str):
return list(corpus.encode("utf-8"))

View File

@ -0,0 +1,90 @@
import argparse
import json
from pathlib import Path
import sys
# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
DEFAULT_MAX_ITERATIONS = 0
TOKEN_LIST = [token.value for token in SpecialToken]
class ProgramArgs:
def __init__(
self,
input_file: str,
output_file: str,
cache_file: str,
max_vocabulary: int,
max_iterations: int,
merge_treshold: int,
debug_after: int,
) -> None:
self.input_file = input_file
self.output_file = output_file
self.cache_file = cache_file
self.max_vocabulary = max_vocabulary
self.max_iterations = max_iterations
self.merge_treshold = merge_treshold
self.debug_after = debug_after
def get_args(args: list[str]) -> ProgramArgs:
PARSER = argparse.ArgumentParser()
PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str)
PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str)
PARSER.add_argument("--cache-file", "--cache", "-c", required=True, type=str)
PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
PARSER.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)
parsed_args, _ = PARSER.parse_known_args(args)
return ProgramArgs(
parsed_args.input_file,
parsed_args.output_file,
parsed_args.cache_file,
parsed_args.max_vocabulary,
parsed_args.max_iterations,
parsed_args.merge_treshold,
parsed_args.debug_after,
) # type ignore
def train(args: ProgramArgs):
TRAINER = BPE.NanoSocraTrainerPool(
args.max_vocabulary,
TOKEN_LIST,
args.merge_treshold,
args.max_iterations,
args.debug_after
)
DATASET_PATH = Path(args.input_file)
VOCABULARY_PATH = Path(args.output_file)
CACHE_PATH = Path(args.cache_file)
print(f"Training BPE")
BPE_ENCODER = TRAINER.trainBPE(
DATASET_PATH,
CACHE_PATH
)
VOCABULARY = BPE_ENCODER.vocabulary
print(f"Saving Vocabulary in {VOCABULARY_PATH}")
BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH)
if __name__ == "__main__":
ARGS = get_args(sys.argv)
train(ARGS)