From e0f8a36aa5b8275110d6217f9b49a49f0a5ba0e3 Mon Sep 17 00:00:00 2001
From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com>
Date: Sun, 12 Oct 2025 13:53:07 +0200
Subject: [PATCH] Added support for fast resuming

---
 .../nanosocrates-train-experiment-2.py       | 38 ++++++++++++++++---
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py
index d845fe0..256d180 100644
--- a/Playgrounds/nanosocrates-train-experiment-2.py
+++ b/Playgrounds/nanosocrates-train-experiment-2.py
@@ -21,11 +21,18 @@ torch.set_default_device(DEVICE)
 
 
 # Get paths
+CHECKPOINT_DIR = "Assets/Dataset/Tmp"
 VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
 TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
 VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
 TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
-CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")
+CHECKPOINT_PATH = Path(f"{CHECKPOINT_DIR}/NanoSocrates.zip")
+
+NANO_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/nano_optim.zip")
+ENC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/enc_optim.zip")
+DEC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/dec_optim.zip")
+LAST_EPOCH_PATH = Path(f"{CHECKPOINT_DIR}/last_epoch.txt")
+
 
 
 # BPE Init
@@ -50,7 +57,7 @@ MINI_BATCH_SIZE = 80
 VALIDATION_STEPS = 5
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
-CURRENT_EPOCH = 0
+CURRENT_EPOCH = 0 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
 VERBOSE = True
 LEARNING_RATE = 1.5
 
@@ -78,7 +85,6 @@ TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKE
 
 
 # Model
-
 NANOSOCRATES = Transformer.TrainingModel(
     TOKEN_SPACE_SIZE,
     EMBEDDED_SIZE,
@@ -103,12 +109,25 @@ decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
 encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
 decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
-nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
+
+if NANO_OPTIM_PATH.is_file():
+    optim_dict = torch.load(NANO_OPTIM_PATH)
+    nano_optim.load_state_dict(optim_dict)
+
+if ENC_OPTIM_PATH.is_file():
+    optim_dict = torch.load(ENC_OPTIM_PATH)
+    encoder_only_optim.load_state_dict(optim_dict)
+
+if DEC_OPTIM_PATH.is_file():
+    optim_dict = torch.load(DEC_OPTIM_PATH)
+    decoder_only_optim.load_state_dict(optim_dict)
+
+nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH)
 encoder_only_scheduler = Transformer.WarmupLR(
-    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
 )
 decoder_only_scheduler = Transformer.WarmupLR(
-    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
 )
 
 current_epoch = CURRENT_EPOCH
@@ -380,6 +399,13 @@ while current_epoch < MAX_EPOCHS:
         if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
             print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
             torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
+            torch.save(nano_optim.state_dict(), NANO_OPTIM_PATH)
+            torch.save(encoder_only_optim.state_dict(), ENC_OPTIM_PATH)
+            torch.save(decoder_only_optim.state_dict(), DEC_OPTIM_PATH)
+            FILE = open(LAST_EPOCH_PATH, "w", encoding="utf-8")
+            FILE.write(f"{current_epoch}")
+            FILE.close()
+
 
             if patience == PATIENCE:
                 exit(0)
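-- 
Note on the resume flow (not part of the patch): the change persists the three
optimizer state dicts and the last finished epoch next to the model checkpoint,
then on the next run reloads whichever of those files exist, seeds
CURRENT_EPOCH from last_epoch.txt, and rebuilds the WarmupLR schedulers with
last_epoch=CURRENT_EPOCH so the learning-rate schedule continues rather than
restarting from warmup. The sketch below shows the same save/restore round trip
in isolation; it is not the project's code, the toy model and file names
(toy_optim.pt, toy_last_epoch.txt) are hypothetical, and Transformer.WarmupLR
is left out in favour of a plain AdamW optimizer.

    from pathlib import Path

    import torch

    CKPT_DIR = Path("Assets/Dataset/Tmp")          # same directory the patch uses
    OPTIM_PATH = CKPT_DIR / "toy_optim.pt"         # hypothetical file name
    EPOCH_PATH = CKPT_DIR / "toy_last_epoch.txt"   # hypothetical file name

    model = torch.nn.Linear(4, 4)                  # stand-in for the real model
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Resume: restore optimizer state and the epoch counter if a previous run
    # left them behind, otherwise start from epoch 0.
    start_epoch = int(EPOCH_PATH.read_text()) if EPOCH_PATH.is_file() else 0
    if OPTIM_PATH.is_file():
        optim.load_state_dict(torch.load(OPTIM_PATH))

    for epoch in range(start_epoch, start_epoch + 5):
        # ... forward / backward / optim.step() would go here ...

        # Checkpoint: persist optimizer state and the finished epoch so the
        # next run can pick up where this one stopped.
        CKPT_DIR.mkdir(parents=True, exist_ok=True)
        torch.save(optim.state_dict(), OPTIM_PATH)
        EPOCH_PATH.write_text(str(epoch))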