Added --resume-at flag so training can resume from a given iteration

This commit is contained in:
Christian Risi 2025-10-01 12:22:09 +02:00
parent 66bcf6e55f
commit b3d444979f

View File

@@ -21,6 +21,7 @@ class ProgramArgs:
input_file: str, input_file: str,
cache_dir: str, cache_dir: str,
output_file: str, output_file: str,
resume_at: int,
max_vocabulary: int, max_vocabulary: int,
max_iterations: int, max_iterations: int,
merge_treshold: int, merge_treshold: int,
@@ -30,6 +31,7 @@ class ProgramArgs:
self.input_file = input_file self.input_file = input_file
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.output_file = output_file self.output_file = output_file
self.resume_at = resume_at
self.max_vocabulary = max_vocabulary self.max_vocabulary = max_vocabulary
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.merge_treshold = merge_treshold self.merge_treshold = merge_treshold
@@ -43,6 +45,7 @@ def get_args(args: list[str]) -> ProgramArgs:
PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str) PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str)
PARSER.add_argument("--cache-dir", "--cache", "-c", required=True, type=str) PARSER.add_argument("--cache-dir", "--cache", "-c", required=True, type=str)
PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str) PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str)
PARSER.add_argument("--resume-at", "--resume", "-r", default=0, type=int)
PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int) PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int) PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int) PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
@@ -55,6 +58,7 @@ def get_args(args: list[str]) -> ProgramArgs:
parsed_args.input_file, parsed_args.input_file,
parsed_args.cache_dir, parsed_args.cache_dir,
parsed_args.output_file, parsed_args.output_file,
parsed_args.resume_at,
parsed_args.max_vocabulary, parsed_args.max_vocabulary,
parsed_args.max_iterations, parsed_args.max_iterations,
parsed_args.merge_treshold, parsed_args.merge_treshold,
@@ -82,25 +86,15 @@ def train(args: ProgramArgs):
BPE_ENCODER = TRAINER.trainBPE( BPE_ENCODER = TRAINER.trainBPE(
DATASET_PATH, DATASET_PATH,
CACHE_DIR CACHE_DIR,
resume_from_iter=args.resume_at
) )
VOCABULARY = BPE_ENCODER.vocabulary VOCABULARY = BPE_ENCODER.vocabulary
JSON_VOCABULARY: dict[str, int]= {}
for key, item in VOCABULARY.items():
TUPLE_STR = f"{key}"
JSON_VOCABULARY[TUPLE_STR] = item
VOCABULARY_JSON = json.dumps(JSON_VOCABULARY)
print(f"Saving Vocabulary in {VOCABULARY_PATH}") print(f"Saving Vocabulary in {VOCABULARY_PATH}")
FILE = open(VOCABULARY_PATH, "w") BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH)
FILE.write(VOCABULARY_JSON)
FILE.close()
if __name__ == "__main__": if __name__ == "__main__":
ARGS = get_args(sys.argv) ARGS = get_args(sys.argv)