Added multithreaded training
This commit is contained in:
parent
b80b4e4112
commit
63baf29805
219
Project_Model/Libs/BPE/Classes/NanoSocraTrainerPool.py
Normal file
219
Project_Model/Libs/BPE/Classes/NanoSocraTrainerPool.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from collections import deque
|
||||||
|
import datetime
|
||||||
|
import itertools
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
from ..Classes import (
|
||||||
|
NanoSocratesBPE,
|
||||||
|
NanoSocratesChunker,
|
||||||
|
NanoSocratesSplitter,
|
||||||
|
NanoSocratesBatchMemoryBPE,
|
||||||
|
)
|
||||||
|
from ..Enums import TokenType
|
||||||
|
from ..Utils import (
|
||||||
|
special_regex_maker,
|
||||||
|
iterator_with_checks,
|
||||||
|
save_nanos_vocabulary,
|
||||||
|
load_nanos_vocabulary,
|
||||||
|
save_json,
|
||||||
|
load_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
def split(a, n):
|
||||||
|
k, m = divmod(len(a), n)
|
||||||
|
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
|
||||||
|
|
||||||
|
def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
|
||||||
|
|
||||||
|
bpe, data = object
|
||||||
|
|
||||||
|
NEW_DATA: list[list[int]]= []
|
||||||
|
|
||||||
|
memory = NanoSocratesBatchMemoryBPE({}, 0)
|
||||||
|
|
||||||
|
while len(data) > 0:
|
||||||
|
|
||||||
|
piece = data.pop()
|
||||||
|
|
||||||
|
bpe, memory, output = bpe.fit(piece, memory, False)
|
||||||
|
|
||||||
|
if len(output) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We are sure of its type
|
||||||
|
NEW_DATA.append(output) # type: ignore
|
||||||
|
|
||||||
|
return (bpe, NEW_DATA, memory)
|
||||||
|
|
||||||
|
|
||||||
|
class NanoSocraTrainerPool:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_vocabulary: int,
|
||||||
|
special_vocabulary: list[str],
|
||||||
|
merge_treshold: int = 0,
|
||||||
|
max_iterations: int = 0,
|
||||||
|
print_after_iterations: int = 1,
|
||||||
|
) -> None:
|
||||||
|
# Bytes
|
||||||
|
BYTE_RESERVED_TOKENS = 256
|
||||||
|
SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
|
||||||
|
RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
|
||||||
|
|
||||||
|
self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
|
||||||
|
self.__max_iterations = max_iterations
|
||||||
|
self.__merge_treshold = merge_treshold
|
||||||
|
self.__special_token_regex = special_regex_maker(special_vocabulary)
|
||||||
|
self.__print_after_iterations = print_after_iterations
|
||||||
|
|
||||||
|
# TODO: add a resume function
|
||||||
|
def trainBPE(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
cache_file: Path,
|
||||||
|
bpe: NanoSocratesBPE | None = None,
|
||||||
|
) -> NanoSocratesBPE:
|
||||||
|
|
||||||
|
if not path.is_file():
|
||||||
|
raise FileNotFoundError()
|
||||||
|
|
||||||
|
if not cache_file.is_file():
|
||||||
|
file = cache_file.open("w")
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
if bpe is None:
|
||||||
|
bpe = NanoSocratesBPE()
|
||||||
|
BPE = bpe
|
||||||
|
|
||||||
|
if BPE.vocabulary_size > self.__max_vocabulary:
|
||||||
|
return BPE
|
||||||
|
|
||||||
|
exit = False
|
||||||
|
current_iteration = 0
|
||||||
|
data = self.__gather_data_from_file(path)
|
||||||
|
|
||||||
|
while not exit:
|
||||||
|
|
||||||
|
current_iteration = self.__increment_counter(current_iteration)
|
||||||
|
|
||||||
|
LAST_VOC_SIZE = BPE.vocabulary_size
|
||||||
|
|
||||||
|
last_memory = None
|
||||||
|
|
||||||
|
_, data, last_memory = self.__round_train(BPE, data)
|
||||||
|
|
||||||
|
NEW_VOC_SIZE = BPE.vocabulary_size
|
||||||
|
|
||||||
|
VOCABULARY = BPE.vocabulary
|
||||||
|
|
||||||
|
save_nanos_vocabulary(VOCABULARY, cache_file)
|
||||||
|
|
||||||
|
if current_iteration % self.__print_after_iterations == 0:
|
||||||
|
|
||||||
|
DELIMITER = "==============="
|
||||||
|
|
||||||
|
DEBUG = "\n".join(
|
||||||
|
[
|
||||||
|
DELIMITER,
|
||||||
|
f"ITERATION: {current_iteration}",
|
||||||
|
DELIMITER,
|
||||||
|
f"\tVocabulary size: {BPE.vocabulary_size}\n",
|
||||||
|
f"\tvocabulary:\n{BPE.vocabulary}",
|
||||||
|
DELIMITER,
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(DEBUG)
|
||||||
|
|
||||||
|
if LAST_VOC_SIZE == NEW_VOC_SIZE:
|
||||||
|
exit = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_iteration == self.__max_iterations:
|
||||||
|
exit = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if BPE.vocabulary_size == self.__max_vocabulary:
|
||||||
|
exit = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
return BPE
|
||||||
|
|
||||||
|
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
|
||||||
|
|
||||||
|
NEW_DATA : list[list[int]] = []
|
||||||
|
|
||||||
|
MEMORY = NanoSocratesBatchMemoryBPE({}, 0)
|
||||||
|
|
||||||
|
fit_funct = split_fit
|
||||||
|
CPU_COUNT = os.process_cpu_count()
|
||||||
|
|
||||||
|
if CPU_COUNT is None:
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
VOCABULARY = bpe.vocabulary
|
||||||
|
|
||||||
|
data_chunks = split(data, CPU_COUNT)
|
||||||
|
JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]
|
||||||
|
|
||||||
|
JOB_RESULTS: list[tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]]
|
||||||
|
|
||||||
|
with Pool() as pool:
|
||||||
|
JOB_RESULTS = pool.map(fit_funct, JOBS)
|
||||||
|
|
||||||
|
for i, res in zip(range(0, CPU_COUNT), JOB_RESULTS):
|
||||||
|
_, job_output, job_memory = res
|
||||||
|
NEW_DATA.extend(job_output)
|
||||||
|
|
||||||
|
for key, value in job_memory.frequencies.items():
|
||||||
|
MEMORY.frequencies[key] = value
|
||||||
|
|
||||||
|
del job_output
|
||||||
|
del job_memory
|
||||||
|
|
||||||
|
print(f"Joined {i + 1} out of {CPU_COUNT}")
|
||||||
|
|
||||||
|
|
||||||
|
# Get new token
|
||||||
|
bpe.fit([], MEMORY, True)
|
||||||
|
|
||||||
|
print(f"Sentences from {len(data)} to {len(NEW_DATA)}")
|
||||||
|
|
||||||
|
return (bpe, NEW_DATA, MEMORY)
|
||||||
|
|
||||||
|
def __gather_data_from_file(self, path: Path) -> list[list[int]]:
|
||||||
|
|
||||||
|
SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
|
||||||
|
|
||||||
|
DATA: list[list[int]] = []
|
||||||
|
|
||||||
|
FILE = open(path, "r", encoding="utf-8")
|
||||||
|
file_string = FILE.read()
|
||||||
|
FILE.close()
|
||||||
|
|
||||||
|
for piece, type in SPLITTER.split_text(file_string):
|
||||||
|
|
||||||
|
if type != TokenType.BPE:
|
||||||
|
continue
|
||||||
|
|
||||||
|
int_list = self.__make_list_ids(piece)
|
||||||
|
DATA.append(int_list)
|
||||||
|
|
||||||
|
return DATA
|
||||||
|
|
||||||
|
def __increment_counter(self, counter: int):
|
||||||
|
|
||||||
|
# What if overflows???
|
||||||
|
try:
|
||||||
|
counter += 1
|
||||||
|
except:
|
||||||
|
print("Integer overflow")
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
return counter
|
||||||
|
|
||||||
|
def __make_list_ids(self, corpus: str):
|
||||||
|
return list(corpus.encode("utf-8"))
|
||||||
90
Scripts/Training/bpe_trainer_pool.py
Normal file
90
Scripts/Training/bpe_trainer_pool.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
# TODO: make relative imports
|
||||||
|
import Project_Model.Libs.BPE as BPE
|
||||||
|
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
|
||||||
|
|
||||||
|
DEFAULT_DEBUG_AFTER_ITER = 1
|
||||||
|
DEFAULT_MAX_VOCABULARY = int(32E3)
|
||||||
|
DEFAULT_MERGE_TRESHOLD = 1
|
||||||
|
DEFAULT_MAX_ITERATIONS = 0
|
||||||
|
TOKEN_LIST = [token.value for token in SpecialToken]
|
||||||
|
|
||||||
|
|
||||||
|
class ProgramArgs:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_file: str,
|
||||||
|
output_file: str,
|
||||||
|
cache_file: str,
|
||||||
|
max_vocabulary: int,
|
||||||
|
max_iterations: int,
|
||||||
|
merge_treshold: int,
|
||||||
|
debug_after: int,
|
||||||
|
) -> None:
|
||||||
|
self.input_file = input_file
|
||||||
|
self.output_file = output_file
|
||||||
|
self.cache_file = cache_file
|
||||||
|
self.max_vocabulary = max_vocabulary
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
self.merge_treshold = merge_treshold
|
||||||
|
self.debug_after = debug_after
|
||||||
|
|
||||||
|
|
||||||
|
def get_args(args: list[str]) -> ProgramArgs:
|
||||||
|
|
||||||
|
PARSER = argparse.ArgumentParser()
|
||||||
|
PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str)
|
||||||
|
PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str)
|
||||||
|
PARSER.add_argument("--cache-file", "--cache", "-c", required=True, type=str)
|
||||||
|
PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
|
||||||
|
PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
|
||||||
|
PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
|
||||||
|
PARSER.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)
|
||||||
|
|
||||||
|
parsed_args, _ = PARSER.parse_known_args(args)
|
||||||
|
|
||||||
|
return ProgramArgs(
|
||||||
|
parsed_args.input_file,
|
||||||
|
parsed_args.output_file,
|
||||||
|
parsed_args.cache_file,
|
||||||
|
parsed_args.max_vocabulary,
|
||||||
|
parsed_args.max_iterations,
|
||||||
|
parsed_args.merge_treshold,
|
||||||
|
parsed_args.debug_after,
|
||||||
|
) # type ignore
|
||||||
|
|
||||||
|
|
||||||
|
def train(args: ProgramArgs):
|
||||||
|
|
||||||
|
TRAINER = BPE.NanoSocraTrainerPool(
|
||||||
|
args.max_vocabulary,
|
||||||
|
TOKEN_LIST,
|
||||||
|
args.merge_treshold,
|
||||||
|
args.max_iterations,
|
||||||
|
args.debug_after
|
||||||
|
)
|
||||||
|
|
||||||
|
DATASET_PATH = Path(args.input_file)
|
||||||
|
VOCABULARY_PATH = Path(args.output_file)
|
||||||
|
CACHE_PATH = Path(args.cache_file)
|
||||||
|
|
||||||
|
print(f"Training BPE")
|
||||||
|
|
||||||
|
BPE_ENCODER = TRAINER.trainBPE(
|
||||||
|
DATASET_PATH,
|
||||||
|
CACHE_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
VOCABULARY = BPE_ENCODER.vocabulary
|
||||||
|
|
||||||
|
print(f"Saving Vocabulary in {VOCABULARY_PATH}")
|
||||||
|
|
||||||
|
BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ARGS = get_args(sys.argv)
|
||||||
|
train(ARGS)
|
||||||
Loading…
x
Reference in New Issue
Block a user