Fixed Validation loss
This commit is contained in:
parent
71d602e36e
commit
4281f8724b
@ -47,11 +47,11 @@ MAX_EPOCHS = int(3e3)
|
|||||||
PRETRAIN_EPOCHS = int(300)
|
PRETRAIN_EPOCHS = int(300)
|
||||||
WARMUP_EPOCHS = int(1e3)
|
WARMUP_EPOCHS = int(1e3)
|
||||||
MINI_BATCH_SIZE = 300
|
MINI_BATCH_SIZE = 300
|
||||||
VALIDATION_STEPS = 25
|
VALIDATION_STEPS = 5
|
||||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
||||||
PATIENCE = 4
|
PATIENCE = 4
|
||||||
CURRENT_EPOCH = 0
|
CURRENT_EPOCH = 0
|
||||||
VERBOSE = False
|
VERBOSE = True
|
||||||
LEARNING_RATE = 1.5
|
LEARNING_RATE = 1.5
|
||||||
|
|
||||||
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
|
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
|
||||||
@ -333,15 +333,15 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
patience += 1
|
patience += 1
|
||||||
|
|
||||||
|
|
||||||
txt_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
|
txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
|
||||||
|
|
||||||
enc_avg_train_loss = float("inf")
|
enc_avg_train_loss = float("inf")
|
||||||
dec_avg_loss = float("inf")
|
dec_avg_train_loss = float("inf")
|
||||||
|
|
||||||
if current_epoch > PRETRAIN_EPOCHS:
|
if current_epoch > PRETRAIN_EPOCHS:
|
||||||
try:
|
try:
|
||||||
enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
|
enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
|
||||||
dec_avg_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
|
dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -353,7 +353,7 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
f"{SEPARATOR}\n",
|
f"{SEPARATOR}\n",
|
||||||
f"Train Losses:\n",
|
f"Train Losses:\n",
|
||||||
f"\tAvg Losses:\n",
|
f"\tAvg Losses:\n",
|
||||||
f"\t\tavg_txt: {txt_avg_loss} - avg_enc: {enc_avg_loss} - avg_dec: {dec_avg_loss}\n",
|
f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
|
||||||
f"{SEPARATOR}\n",
|
f"{SEPARATOR}\n",
|
||||||
f"Validation Losses:\n",
|
f"Validation Losses:\n",
|
||||||
f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
|
f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user