diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py index 71844f4..2a070a9 100644 --- a/Playgrounds/nanosocrates-train-experiment-2.py +++ b/Playgrounds/nanosocrates-train-experiment-2.py @@ -331,6 +331,9 @@ while current_epoch < MAX_EPOCHS: if counter > 1: patience += 1 + + if counter == 0: + patience = max(0, patience - 1) txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)