"""Command-line entry point: train a BPE vocabulary with NanoSocraTrainer and save it."""

import argparse
import json
from pathlib import Path
import sys

# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken

# Defaults for the CLI options parsed in get_args(); each value is handed
# unchanged to BPE.NanoSocraTrainer.
DEFAULT_CHUNK_SIZE = int(18e4)
DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
# NOTE(review): presumably 0 means "no iteration limit" — confirm in NanoSocraTrainer.
DEFAULT_MAX_ITERATIONS = 0

# Special tokens given to the trainer: the value of every SpecialToken member.
TOKEN_LIST = [token.value for token in SpecialToken]


class ProgramArgs:
    """Plain container for the options parsed by get_args().

    Attributes:
        input_file: Path of the dataset to train on.
        cache_dir: Directory handed to the trainer as its cache location.
        output_file: Path the trained vocabulary is saved to.
        resume_at: Iteration to resume training from (passed as ``resume_from_iter``).
        max_vocabulary: Vocabulary-size argument for the trainer.
        max_iterations: Iteration-limit argument for the trainer.
        merge_treshold: Merge-threshold argument for the trainer (name keeps
            the original spelling for compatibility).
        chunk_size: Chunk-size argument for the trainer.
        debug_after: Debug-interval argument for the trainer.
    """

    def __init__(
        self,
        input_file: str,
        cache_dir: str,
        output_file: str,
        resume_at: int,
        max_vocabulary: int,
        max_iterations: int,
        merge_treshold: int,
        chunk_size: int,
        debug_after: int,
    ) -> None:
        self.input_file = input_file
        self.cache_dir = cache_dir
        self.output_file = output_file
        self.resume_at = resume_at
        self.max_vocabulary = max_vocabulary
        self.max_iterations = max_iterations
        self.merge_treshold = merge_treshold
        self.chunk_size = chunk_size
        self.debug_after = debug_after


def get_args(args: list[str]) -> ProgramArgs:
    """Parse command-line options into a ProgramArgs.

    Args:
        args: Argument strings to parse. These should NOT include the program
            name (pass ``sys.argv[1:]``, not ``sys.argv``).

    Returns:
        A ProgramArgs populated from the parsed options and their defaults.

    Unrecognized arguments are not fatal. They are reported on stderr and
    otherwise ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", "--input", "-i", required=True, type=str)
    parser.add_argument("--cache-dir", "--cache", "-c", required=True, type=str)
    parser.add_argument("--output-file", "--output", "-o", required=True, type=str)
    parser.add_argument("--resume-at", "--resume", "-r", default=0, type=int)
    parser.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
    parser.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
    parser.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
    parser.add_argument("--chunk-size", default=DEFAULT_CHUNK_SIZE, type=int)
    parser.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)

    parsed_args, unknown = parser.parse_known_args(args)
    if unknown:
        # parse_known_args would otherwise drop misspelled options without a
        # trace, silently leaving the run on default values.
        print(
            f"Warning: ignoring unrecognized arguments: {' '.join(unknown)}",
            file=sys.stderr,
        )

    return ProgramArgs(
        parsed_args.input_file,
        parsed_args.cache_dir,
        parsed_args.output_file,
        parsed_args.resume_at,
        parsed_args.max_vocabulary,
        parsed_args.max_iterations,
        parsed_args.merge_treshold,
        parsed_args.chunk_size,
        parsed_args.debug_after,
    )


def train(args: ProgramArgs) -> None:
    """Train a BPE encoder as configured by *args* and save its vocabulary.

    Builds a NanoSocraTrainer from the parsed options, then trains it on
    ``args.input_file`` with ``args.cache_dir`` as the cache directory,
    resuming from iteration ``args.resume_at``. Finally, the resulting
    encoder's vocabulary is written to ``args.output_file``.
    """
    trainer = BPE.NanoSocraTrainer(
        args.max_vocabulary,
        TOKEN_LIST,
        args.chunk_size,
        args.merge_treshold,
        args.max_iterations,
        args.debug_after,
    )

    dataset_path = Path(args.input_file)
    cache_dir = Path(args.cache_dir)
    vocabulary_path = Path(args.output_file)

    print("Training BPE")
    bpe_encoder = trainer.trainBPE(
        dataset_path,
        cache_dir,
        resume_from_iter=args.resume_at,
    )

    vocabulary = bpe_encoder.vocabulary
    print(f"Saving Vocabulary in {vocabulary_path}")
    BPE.save_nanos_vocabulary(vocabulary, vocabulary_path)


def main() -> None:
    """Script entry point: parse the process arguments and run training."""
    # Skip argv[0]: it is the program name, not an option. If it were parsed,
    # a name that looks like a flag (argv[0] is "-c" under `python -c`) would
    # be taken as the --cache-dir short option.
    train(get_args(sys.argv[1:]))


if __name__ == "__main__":
    main()