Added multithreaded training
parent b80b4e4112
commit 63baf29805
Project_Model/Libs/BPE/Classes/NanoSocraTrainerPool.py (Normal file, 219 lines added)
@@ -0,0 +1,219 @@
from collections import deque
import datetime
import itertools
from multiprocessing import Pool
import os
from pathlib import Path
import re
from ..Classes import (
    NanoSocratesBPE,
    NanoSocratesChunker,
    NanoSocratesSplitter,
    NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import (
    special_regex_maker,
    iterator_with_checks,
    save_nanos_vocabulary,
    load_nanos_vocabulary,
    save_json,
    load_json,
)

def split(a, n):
    # Split list a into n nearly-equal contiguous chunks.
    k, m = divmod(len(a), n)
    return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))


def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
    # Worker entry point: runs one BPE fitting pass over its own chunk of sentences.
    bpe, data = object

    NEW_DATA: list[list[int]] = []

    memory = NanoSocratesBatchMemoryBPE({}, 0)

    while len(data) > 0:

        piece = data.pop()

        bpe, memory, output = bpe.fit(piece, memory, False)

        # Sentences reduced to fewer than two tokens carry no further merge information.
        if len(output) < 2:
            continue

        # We are sure of its type
        NEW_DATA.append(output)  # type: ignore

    return (bpe, NEW_DATA, memory)

class NanoSocraTrainerPool:

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        merge_treshold: int = 0,
        max_iterations: int = 0,
        print_after_iterations: int = 1,
    ) -> None:
        # Byte-level tokens and special tokens are pre-reserved in the vocabulary.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__merge_treshold = merge_treshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)
        self.__print_after_iterations = print_after_iterations

    # TODO: add a resume function
    def trainBPE(
        self,
        path: Path,
        cache_file: Path,
        bpe: NanoSocratesBPE | None = None,
    ) -> NanoSocratesBPE:

        if not path.is_file():
            raise FileNotFoundError()

        # Create the cache file if it does not exist yet.
        if not cache_file.is_file():
            file = cache_file.open("w")
            file.close()

        if bpe is None:
            bpe = NanoSocratesBPE()
        BPE = bpe

        if BPE.vocabulary_size > self.__max_vocabulary:
            return BPE

        exit = False
        current_iteration = 0
        data = self.__gather_data_from_file(path)

        while not exit:

            current_iteration = self.__increment_counter(current_iteration)

            LAST_VOC_SIZE = BPE.vocabulary_size

            last_memory = None

            _, data, last_memory = self.__round_train(BPE, data)

            NEW_VOC_SIZE = BPE.vocabulary_size

            VOCABULARY = BPE.vocabulary

            # Checkpoint the vocabulary after every round.
            save_nanos_vocabulary(VOCABULARY, cache_file)

            if current_iteration % self.__print_after_iterations == 0:

                DELIMITER = "==============="

                DEBUG = "\n".join(
                    [
                        DELIMITER,
                        f"ITERATION: {current_iteration}",
                        DELIMITER,
                        f"\tVocabulary size: {BPE.vocabulary_size}\n",
                        f"\tvocabulary:\n{BPE.vocabulary}",
                        DELIMITER,
                        "",
                    ]
                )
                print(DEBUG)

            # Stop when the vocabulary stops growing, the iteration budget is
            # exhausted, or the target vocabulary size is reached.
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                exit = True
                continue

            if current_iteration == self.__max_iterations:
                exit = True
                continue

            if BPE.vocabulary_size == self.__max_vocabulary:
                exit = True
                continue

        return BPE

    def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):

        NEW_DATA: list[list[int]] = []

        MEMORY = NanoSocratesBatchMemoryBPE({}, 0)

        fit_funct = split_fit
        # os.process_cpu_count() requires Python 3.13+.
        CPU_COUNT = os.process_cpu_count()

        if CPU_COUNT is None:
            raise Exception()

        VOCABULARY = bpe.vocabulary

        # Fan out: give every worker its own BPE copy and a slice of the corpus.
        data_chunks = split(data, CPU_COUNT)
        JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]

        JOB_RESULTS: list[tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]]

        with Pool() as pool:
            JOB_RESULTS = pool.map(fit_funct, JOBS)

        # Join: gather the compressed sentences and pair frequencies of every worker.
        for i, res in zip(range(0, CPU_COUNT), JOB_RESULTS):
            _, job_output, job_memory = res
            NEW_DATA.extend(job_output)

            # Accumulate pair frequencies across workers instead of overwriting,
            # so counts from earlier chunks are not lost.
            for key, value in job_memory.frequencies.items():
                MEMORY.frequencies[key] = MEMORY.frequencies.get(key, 0) + value

            del job_output
            del job_memory

            print(f"Joined {i + 1} out of {CPU_COUNT}")

        # Get new token from the merged frequencies.
        bpe.fit([], MEMORY, True)

        print(f"Sentences from {len(data)} to {len(NEW_DATA)}")

        return (bpe, NEW_DATA, MEMORY)

    def __gather_data_from_file(self, path: Path) -> list[list[int]]:

        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)

        DATA: list[list[int]] = []

        FILE = open(path, "r", encoding="utf-8")
        file_string = FILE.read()
        FILE.close()

        # Keep only the BPE-trainable pieces; special tokens are skipped.
        for piece, type in SPLITTER.split_text(file_string):

            if type != TokenType.BPE:
                continue

            int_list = self.__make_list_ids(piece)
            DATA.append(int_list)

        return DATA

    def __increment_counter(self, counter: int):

        # Python integers do not overflow, so this guard is purely defensive.
        try:
            counter += 1
        except:
            print("Integer overflow")
            counter = 1

        return counter

    def __make_list_ids(self, corpus: str):
        # Raw UTF-8 bytes are the initial token ids.
        return list(corpus.encode("utf-8"))
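For reference, a minimal standalone sketch (not part of this commit) of how the split helper above partitions the corpus across workers, using plain Python lists:

# Standalone sketch of the chunking arithmetic used by __round_train.
def split(a, n):
    k, m = divmod(len(a), n)
    return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))

sentences = [[b] for b in range(10)]   # stand-in for the list[list[int]] corpus
chunks = list(split(sentences, 3))
print([len(c) for c in chunks])        # [4, 3, 3]: the remainder goes to the first chunks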
Scripts/Training/bpe_trainer_pool.py (Normal file, 90 lines added)
@@ -0,0 +1,90 @@
import argparse
import json
from pathlib import Path
import sys
# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken

DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
DEFAULT_MAX_ITERATIONS = 0
TOKEN_LIST = [token.value for token in SpecialToken]


class ProgramArgs:

    def __init__(
        self,
        input_file: str,
        output_file: str,
        cache_file: str,
        max_vocabulary: int,
        max_iterations: int,
        merge_treshold: int,
        debug_after: int,
    ) -> None:
        self.input_file = input_file
        self.output_file = output_file
        self.cache_file = cache_file
        self.max_vocabulary = max_vocabulary
        self.max_iterations = max_iterations
        self.merge_treshold = merge_treshold
        self.debug_after = debug_after


def get_args(args: list[str]) -> ProgramArgs:

    PARSER = argparse.ArgumentParser()
    PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str)
    PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str)
    PARSER.add_argument("--cache-file", "--cache", "-c", required=True, type=str)
    PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
    PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
    PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
    PARSER.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)

    parsed_args, _ = PARSER.parse_known_args(args)

    return ProgramArgs(
        parsed_args.input_file,
        parsed_args.output_file,
        parsed_args.cache_file,
        parsed_args.max_vocabulary,
        parsed_args.max_iterations,
        parsed_args.merge_treshold,
        parsed_args.debug_after,
    )  # type: ignore


def train(args: ProgramArgs):

    TRAINER = BPE.NanoSocraTrainerPool(
        args.max_vocabulary,
        TOKEN_LIST,
        args.merge_treshold,
        args.max_iterations,
        args.debug_after,
    )

    DATASET_PATH = Path(args.input_file)
    VOCABULARY_PATH = Path(args.output_file)
    CACHE_PATH = Path(args.cache_file)

    print("Training BPE")

    BPE_ENCODER = TRAINER.trainBPE(
        DATASET_PATH,
        CACHE_PATH,
    )

    VOCABULARY = BPE_ENCODER.vocabulary

    print(f"Saving Vocabulary in {VOCABULARY_PATH}")

    BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH)


if __name__ == "__main__":
    ARGS = get_args(sys.argv[1:])
    train(ARGS)
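For reference, a rough usage sketch of the new script (not part of this commit); the file paths are placeholders, and it assumes the repository root is on the import path so that Project_Model and Scripts resolve:

# Hypothetical invocation; all paths below are placeholders.
from Scripts.Training.bpe_trainer_pool import get_args, train

args = get_args([
    "--input-file", "Data/corpus_clean.txt",      # placeholder training corpus
    "--output-file", "Data/bpe_vocabulary.json",  # placeholder vocabulary output
    "--cache-file", "Data/bpe_cache.json",        # placeholder per-round checkpoint
    "--max-vocabulary", "32000",
    "--debug-after", "5",
])
train(args)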