diff --git a/Scripts/Training/bpe_trainer_pool.py b/Scripts/Training/bpe_trainer_pool.py index 5c7ab6e..966816d 100644 --- a/Scripts/Training/bpe_trainer_pool.py +++ b/Scripts/Training/bpe_trainer_pool.py @@ -72,11 +72,17 @@ def train(args: ProgramArgs): VOCABULARY_PATH = Path(args.output_file) CACHE_PATH = Path(args.cache_file) + start_bpe = BPE.NanoSocratesBPE() + if CACHE_PATH.is_file(): + voc = BPE.load_nanos_vocabulary(CACHE_PATH) + start_bpe = BPE.NanoSocratesBPE(voc) + print(f"Training BPE") BPE_ENCODER = TRAINER.trainBPE( DATASET_PATH, - CACHE_PATH + CACHE_PATH, + start_bpe ) VOCABULARY = BPE_ENCODER.vocabulary