fixed patience not quitting

This commit is contained in:
Christian Risi 2025-10-12 01:41:34 +02:00
parent 79438e3d30
commit ab3d68bc13

View File

@ -331,7 +331,7 @@ while current_epoch < MAX_EPOCHS:
if counter > 1: if counter > 1:
patience += 1 patience += 1
if counter == 0: if counter == 0:
patience = max(0, patience - 1) patience = max(0, patience - 1)
@ -359,7 +359,7 @@ while current_epoch < MAX_EPOCHS:
f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n", f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
f"{SEPARATOR}\n", f"{SEPARATOR}\n",
f"Validation Losses:\n", f"Validation Losses:\n",
f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n", f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction_loss: {dec_avg_loss}\n",
f"{SEPARATOR}\n", f"{SEPARATOR}\n",
] ]
) )
@ -374,3 +374,6 @@ while current_epoch < MAX_EPOCHS:
if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE: if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
print(f"Saving model at {CHECKPOINT_PATH.as_posix()}") print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH) torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
if patience == PATIENCE:
exit(0)