Fixed Validation loss

This commit is contained in:
Christian Risi 2025-10-12 00:57:24 +02:00
parent 71d602e36e
commit 4281f8724b

View File

@@ -47,11 +47,11 @@ MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300) PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(1e3) WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 300 MINI_BATCH_SIZE = 300
VALIDATION_STEPS = 25 VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4 PATIENCE = 4
CURRENT_EPOCH = 0 CURRENT_EPOCH = 0
VERBOSE = False VERBOSE = True
LEARNING_RATE = 1.5 LEARNING_RATE = 1.5
SOS_TOKEN = TOKENANO.encode("<SOS>")[0] SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
@@ -333,15 +333,15 @@ while current_epoch < MAX_EPOCHS:
patience += 1 patience += 1
txt_avg_loss = sum(text_batch_losses) / len(text_batch_losses) txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
enc_avg_train_loss = float("inf") enc_avg_train_loss = float("inf")
dec_avg_loss = float("inf") dec_avg_train_loss = float("inf")
if current_epoch > PRETRAIN_EPOCHS: if current_epoch > PRETRAIN_EPOCHS:
try: try:
enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses) enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
dec_avg_loss = sum(decoder_batch_losses) / len(decoder_batch_losses) dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
except: except:
pass pass
@@ -353,7 +353,7 @@ while current_epoch < MAX_EPOCHS:
f"{SEPARATOR}\n", f"{SEPARATOR}\n",
f"Train Losses:\n", f"Train Losses:\n",
f"\tAvg Losses:\n", f"\tAvg Losses:\n",
f"\t\tavg_txt: {txt_avg_loss} - avg_enc: {enc_avg_loss} - avg_dec: {dec_avg_loss}\n", f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
f"{SEPARATOR}\n", f"{SEPARATOR}\n",
f"Validation Losses:\n", f"Validation Losses:\n",
f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n", f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",