"""Command-line script that trains a BPE vocabulary and saves it to disk.

The script parses its options, optionally resumes from a cached vocabulary,
runs ``BPE.NanoSocraTrainerPool.trainBPE`` and writes the resulting
vocabulary with ``BPE.save_nanos_vocabulary``.
"""

import argparse
import json
from pathlib import Path
import sys

# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken

# Defaults used when the matching command-line option is omitted.
# NOTE: the "TRESHOLD" spelling is kept on purpose; it is part of the
# public names and CLI flags of this script.
DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
DEFAULT_MAX_ITERATIONS = 0

# Values of every SpecialToken member, handed to the trainer.
TOKEN_LIST = [token.value for token in SpecialToken]


class ProgramArgs:
    """Plain container for the parsed command-line options.

    Attributes mirror the constructor parameters one-to-one.
    """

    def __init__(
        self,
        input_file: str,
        output_file: str,
        cache_file: str,
        max_vocabulary: int,
        max_iterations: int,
        merge_treshold: int,
        debug_after: int,
    ) -> None:
        self.input_file = input_file
        self.output_file = output_file
        self.cache_file = cache_file
        self.max_vocabulary = max_vocabulary
        self.max_iterations = max_iterations
        self.merge_treshold = merge_treshold
        self.debug_after = debug_after


def get_args(args: list[str]) -> ProgramArgs:
    """Parse ``args`` into a :class:`ProgramArgs`.

    Args:
        args: Command-line tokens to parse.

    Returns:
        The parsed options; optional ones fall back to the ``DEFAULT_*``
        module constants.

    Note:
        Parsing uses ``parse_known_args``, so tokens that match no option
        are silently ignored instead of aborting. This is also why callers
        that pass the full ``sys.argv`` (program name included) still work.
        As usual with argparse, a missing required option makes the parser
        print a usage message and exit.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--input-file", "--input", "-i", required=True, type=str)
    parser.add_argument("--output-file", "--output", "-o", required=True, type=str)
    parser.add_argument("--cache-file", "--cache", "-c", required=True, type=str)
    parser.add_argument(
        "--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int
    )
    parser.add_argument(
        "--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int
    )
    parser.add_argument(
        "--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int
    )
    parser.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)

    # Unknown tokens are deliberately discarded (see docstring).
    parsed_args, _ = parser.parse_known_args(args)

    return ProgramArgs(
        input_file=parsed_args.input_file,
        output_file=parsed_args.output_file,
        cache_file=parsed_args.cache_file,
        max_vocabulary=parsed_args.max_vocabulary,
        max_iterations=parsed_args.max_iterations,
        merge_treshold=parsed_args.merge_treshold,
        debug_after=parsed_args.debug_after,
    )


def train(args: ProgramArgs) -> None:
    """Train a BPE encoder on the input file and save its vocabulary.

    If ``args.cache_file`` points to an existing file, the vocabulary stored
    there is loaded and used to seed the starting encoder; otherwise training
    starts from a default ``BPE.NanoSocratesBPE()``. The cache path is also
    passed to the trainer (presumably for checkpointing; confirm against
    ``NanoSocraTrainerPool.trainBPE``).

    Args:
        args: Parsed program options.
    """
    trainer = BPE.NanoSocraTrainerPool(
        args.max_vocabulary,
        TOKEN_LIST,
        args.merge_treshold,
        args.max_iterations,
        args.debug_after,
    )

    dataset_path = Path(args.input_file)
    vocabulary_path = Path(args.output_file)
    cache_path = Path(args.cache_file)

    # Build exactly one starting encoder: seeded from the cache when it
    # exists, empty otherwise.
    if cache_path.is_file():
        cached_vocabulary = BPE.load_nanos_vocabulary(cache_path)
        start_bpe = BPE.NanoSocratesBPE(cached_vocabulary)
    else:
        start_bpe = BPE.NanoSocratesBPE()

    print("Training BPE")

    bpe_encoder = trainer.trainBPE(dataset_path, cache_path, start_bpe)
    vocabulary = bpe_encoder.vocabulary

    print(f"Saving Vocabulary in {vocabulary_path}")

    BPE.save_nanos_vocabulary(vocabulary, vocabulary_path)


if __name__ == "__main__":
    # Skip argv[0] (the program name): it is not an option to parse.
    ARGS = get_args(sys.argv[1:])
    train(ARGS)