diff --git a/Scripts/Training/bpe_trainer.py b/Scripts/Training/bpe_trainer.py index 904bfbf..bc8916e 100644 --- a/Scripts/Training/bpe_trainer.py +++ b/Scripts/Training/bpe_trainer.py @@ -21,6 +21,7 @@ class ProgramArgs: input_file: str, cache_dir: str, output_file: str, + resume_at: int, max_vocabulary: int, max_iterations: int, merge_treshold: int, @@ -30,6 +31,7 @@ class ProgramArgs: self.input_file = input_file self.cache_dir = cache_dir self.output_file = output_file + self.resume_at = resume_at self.max_vocabulary = max_vocabulary self.max_iterations = max_iterations self.merge_treshold = merge_treshold @@ -43,6 +45,7 @@ def get_args(args: list[str]) -> ProgramArgs: PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str) PARSER.add_argument("--cache-dir", "--cache", "-c", required=True, type=str) PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str) + PARSER.add_argument("--resume-at", "--resume", "-r", default=0, type=int) PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int) PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int) PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int) @@ -55,6 +58,7 @@ def get_args(args: list[str]) -> ProgramArgs: parsed_args.input_file, parsed_args.cache_dir, parsed_args.output_file, + parsed_args.resume_at, parsed_args.max_vocabulary, parsed_args.max_iterations, parsed_args.merge_treshold, @@ -82,25 +86,15 @@ def train(args: ProgramArgs): BPE_ENCODER = TRAINER.trainBPE( DATASET_PATH, - CACHE_DIR + CACHE_DIR, + resume_from_iter=args.resume_at ) VOCABULARY = BPE_ENCODER.vocabulary - JSON_VOCABULARY: dict[str, int]= {} - - for key, item in VOCABULARY.items(): - TUPLE_STR = f"{key}" - JSON_VOCABULARY[TUPLE_STR] = item - - VOCABULARY_JSON = json.dumps(JSON_VOCABULARY) - print(f"Saving Vocabulary in {VOCABULARY_PATH}") - FILE = open(VOCABULARY_PATH, "w") - FILE.write(VOCABULARY_JSON) - FILE.close() - + BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH) if __name__ == "__main__": ARGS = get_args(sys.argv)