Added support for fast resuming

Checkpoints now also save the three optimizer state dicts and the last
completed epoch, and the schedulers are rebuilt with last_epoch=CURRENT_EPOCH,
so an interrupted training run resumes where it stopped instead of restarting
from epoch 0.

parent 37a2501a79
commit e0f8a36aa5
@@ -21,11 +21,18 @@ torch.set_default_device(DEVICE)
 
 
 # Get paths
+CHECKPOINT_DIR = "Assets/Dataset/Tmp"
 VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
 TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
 VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
 TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
-CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")
+CHECKPOINT_PATH = Path(f"{CHECKPOINT_DIR}/NanoSocrates.zip")
+
+NANO_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/nano_optim.zip")
+ENC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/enc_optim.zip")
+DEC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/dec_optim.zip")
+LAST_EPOCH_PATH = Path(f"{CHECKPOINT_DIR}/last_epoch.txt")
+
 
 
 # BPE Init
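
Note: all resume files are derived from CHECKPOINT_DIR. A small guard like the
following (a sketch, not part of this commit) would keep the first torch.save
from failing if the directory does not exist yet:

    from pathlib import Path

    CHECKPOINT_DIR = "Assets/Dataset/Tmp"

    # Create the checkpoint directory up front; harmless if it already exists.
    Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
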
@@ -50,7 +57,7 @@ MINI_BATCH_SIZE = 80
 VALIDATION_STEPS = 5
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
-CURRENT_EPOCH = 0
+CURRENT_EPOCH = 0 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
 VERBOSE = True
 LEARNING_RATE = 1.5
 
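
Note: CURRENT_EPOCH is read back from last_epoch.txt, which the checkpoint
code below writes. A slightly more defensive variant (a sketch with a
hypothetical helper, not what this commit does) would tolerate a missing,
empty, or partially written marker file:

    def read_last_epoch(path):
        # Fall back to 0 when the marker file is absent or not a valid int.
        try:
            return int(path.read_text().strip())
        except (FileNotFoundError, ValueError):
            return 0

    CURRENT_EPOCH = read_last_epoch(LAST_EPOCH_PATH)
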
@@ -78,7 +85,6 @@ TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKE
 
-
 # Model
 
 NANOSOCRATES = Transformer.TrainingModel(
     TOKEN_SPACE_SIZE,
     EMBEDDED_SIZE,
@@ -103,12 +109,25 @@ decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
 encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
 decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
-nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
+
+if NANO_OPTIM_PATH.is_file():
+    optim_dict = torch.load(NANO_OPTIM_PATH)
+    nano_optim.load_state_dict(optim_dict)
+
+if ENC_OPTIM_PATH.is_file():
+    optim_dict = torch.load(ENC_OPTIM_PATH)
+    encoder_only_optim.load_state_dict(optim_dict)
+
+if DEC_OPTIM_PATH.is_file():
+    optim_dict = torch.load(DEC_OPTIM_PATH)
+    decoder_only_optim.load_state_dict(optim_dict)
+
+nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH)
 encoder_only_scheduler = Transformer.WarmupLR(
-    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
 )
 decoder_only_scheduler = Transformer.WarmupLR(
-    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
 )
 
 current_epoch = CURRENT_EPOCH
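
Note: this hunk restores optimizer state and re-seeds the warmup schedulers via
last_epoch. An equivalent pattern, commonly used with PyTorch schedulers, is to
round-trip the scheduler's own state_dict alongside the optimizer's. The sketch
below is not what this commit does; it assumes Transformer.WarmupLR follows
torch.optim.lr_scheduler semantics, reuses NANO_OPTIM_PATH/CHECKPOINT_DIR from
above, and introduces a hypothetical SCHED_PATH and two hypothetical helpers:

    import torch

    SCHED_PATH = Path(f"{CHECKPOINT_DIR}/nano_sched.zip")  # hypothetical path

    def save_resume_state(optimizer, scheduler):
        # Persist both state dicts so a restart continues with identical
        # optimizer moments and learning-rate schedule position.
        torch.save(optimizer.state_dict(), NANO_OPTIM_PATH)
        torch.save(scheduler.state_dict(), SCHED_PATH)

    def load_resume_state(optimizer, scheduler):
        if NANO_OPTIM_PATH.is_file():
            optimizer.load_state_dict(torch.load(NANO_OPTIM_PATH))
        if SCHED_PATH.is_file():
            scheduler.load_state_dict(torch.load(SCHED_PATH))
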
@@ -380,6 +399,13 @@ while current_epoch < MAX_EPOCHS:
     if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
         print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
         torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
+        torch.save(nano_optim.state_dict(), NANO_OPTIM_PATH)
+        torch.save(encoder_only_optim.state_dict(), ENC_OPTIM_PATH)
+        torch.save(decoder_only_optim.state_dict(), DEC_OPTIM_PATH)
+        FILE = open(LAST_EPOCH_PATH, "w", encoding="utf-8")
+        FILE.write(f"{current_epoch}")
+        FILE.close()
+
 
     if patience == PATIENCE:
         exit(0)
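
Note: the epoch marker is written with an explicit open/write/close. A tidier
equivalent (a sketch with the same assumed behaviour, given that
LAST_EPOCH_PATH is a pathlib.Path) closes the handle even if write() raises:

    # Context-manager form:
    with open(LAST_EPOCH_PATH, "w", encoding="utf-8") as fh:
        fh.write(f"{current_epoch}")

    # Or directly via pathlib:
    LAST_EPOCH_PATH.write_text(f"{current_epoch}", encoding="utf-8")
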