"""Command-line script that trains a BPE vocabulary and saves it as JSON.

The script parses its options with ``get_args``, builds a
``BPE.NanoSocraTrainer`` from them, runs ``trainBPE`` on the input dataset and
writes the resulting encoder's ``vocabulary`` to the output file.
"""

import argparse
import json
import sys
from pathlib import Path

# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken

DEFAULT_CHUNK_SIZE = int(18e4)
DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
DEFAULT_MAX_ITERATIONS = 0
# Values of every SpecialToken member; handed to the trainer as its token list.
TOKEN_LIST = [token.value for token in SpecialToken]


class ProgramArgs:
    """Plain container for the options accepted by this script.

    The three path attributes are kept as strings and converted to ``Path``
    objects by ``train``. The numeric attributes are forwarded unchanged to
    ``BPE.NanoSocraTrainer``; their exact semantics are defined by the trainer.
    """

    def __init__(
        self,
        input_file: str,
        cache_dir: str,
        output_file: str,
        max_vocabulary: int,
        max_iterations: int,
        merge_treshold: int,
        chunk_size: int,
        debug_after: int,
    ) -> None:
        self.input_file = input_file
        self.cache_dir = cache_dir
        self.output_file = output_file
        self.max_vocabulary = max_vocabulary
        self.max_iterations = max_iterations
        self.merge_treshold = merge_treshold
        self.chunk_size = chunk_size
        self.debug_after = debug_after


def get_args(args: list[str]) -> ProgramArgs:
    """Parse command-line style arguments into a ``ProgramArgs``.

    Args:
        args: Argument strings, e.g. ``sys.argv[1:]``. Unrecognised entries
            are ignored rather than rejected, so passing the full ``sys.argv``
            (program name included) also works.

    Returns:
        A ``ProgramArgs`` holding the parsed values, with defaults applied for
        any optional flag that was not supplied.
    """
    parser = argparse.ArgumentParser(
        description="Train a BPE vocabulary and save it as JSON."
    )
    parser.add_argument(
        "--input-file", "--input", "-i", required=True, type=str,
        help="path of the dataset handed to the trainer",
    )
    parser.add_argument(
        "--cache-dir", "--cache", "-c", required=True, type=str,
        help="cache directory handed to the trainer",
    )
    parser.add_argument(
        "--output-file", "--output", "-o", required=True, type=str,
        help="path of the JSON file the vocabulary is written to",
    )
    parser.add_argument(
        "--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int,
        help="max vocabulary setting forwarded to the trainer (default: %(default)s)",
    )
    # NOTE(review): what the default of 0 means is decided by the trainer — confirm.
    parser.add_argument(
        "--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int,
        help="max iterations setting forwarded to the trainer (default: %(default)s)",
    )
    parser.add_argument(
        "--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int,
        help="merge threshold setting forwarded to the trainer (default: %(default)s)",
    )
    parser.add_argument(
        "--chunk-size", default=DEFAULT_CHUNK_SIZE, type=int,
        help="chunk size setting forwarded to the trainer (default: %(default)s)",
    )
    parser.add_argument(
        "--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int,
        help="'debug after' setting forwarded to the trainer (default: %(default)s)",
    )

    # parse_known_args (not parse_args) is kept on purpose: existing callers
    # may pass the whole sys.argv, whose first entry is not an option.
    parsed, _unknown = parser.parse_known_args(args)

    return ProgramArgs(
        input_file=parsed.input_file,
        cache_dir=parsed.cache_dir,
        output_file=parsed.output_file,
        max_vocabulary=parsed.max_vocabulary,
        max_iterations=parsed.max_iterations,
        merge_treshold=parsed.merge_treshold,
        chunk_size=parsed.chunk_size,
        debug_after=parsed.debug_after,
    )


def train(args: ProgramArgs) -> None:
    """Train a BPE encoder and write its vocabulary to ``args.output_file``.

    Args:
        args: Options controlling the trainer and the input/cache/output paths.

    Raises:
        OSError: If the output directory cannot be created or the vocabulary
            file cannot be written.
    """
    trainer = BPE.NanoSocraTrainer(
        args.max_vocabulary,
        TOKEN_LIST,
        args.chunk_size,
        args.merge_treshold,
        args.max_iterations,
        args.debug_after,
    )

    dataset_path = Path(args.input_file)
    cache_dir = Path(args.cache_dir)
    vocabulary_path = Path(args.output_file)

    # Create the destination directory up front so an unwritable or missing
    # location is reported before training starts, not after it has finished.
    vocabulary_path.parent.mkdir(parents=True, exist_ok=True)

    print("Training BPE")

    encoder = trainer.trainBPE(dataset_path, cache_dir)

    vocabulary_json = json.dumps(encoder.vocabulary)

    print(f"Saving Vocabulary in {vocabulary_path}")

    # write_text opens and closes the file itself, so the handle is released
    # even if the write fails. json.dumps escapes non-ASCII by default, so
    # pinning UTF-8 makes the output independent of the platform locale.
    vocabulary_path.write_text(vocabulary_json, encoding="utf-8")


if __name__ == "__main__":
    # Skip the program name; only the real options are handed to the parser.
    ARGS = get_args(sys.argv[1:])
    train(ARGS)