Fixed a patience bug and added label smoothing
This commit is contained in:
parent
0b256001fe
commit
83693f1d4e
@ -59,7 +59,8 @@ CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
|
||||
PATIENCE = 4
|
||||
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
||||
VERBOSE = True
|
||||
LEARNING_RATE = 1.5
|
||||
LEARNING_RATE = 0.05
|
||||
LABEL_SMOOTHING = 0.01
|
||||
|
||||
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
|
||||
|
||||
@ -103,9 +104,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
|
||||
|
||||
|
||||
# Training constants
|
||||
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
||||
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
||||
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
||||
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
|
||||
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
|
||||
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
|
||||
@ -352,16 +353,22 @@ while current_epoch < MAX_EPOCHS:
|
||||
if VERBOSE:
|
||||
print("txt average is higher than lowest")
|
||||
counter += 1
|
||||
else:
|
||||
average_loss_validation["txt"] = txt_avg_loss
|
||||
|
||||
if enc_avg_loss > average_loss_validation["encoder_only"]:
|
||||
if VERBOSE:
|
||||
print("masking average is higher than lowest")
|
||||
counter += 1
|
||||
else:
|
||||
average_loss_validation["encoder_only"] = enc_avg_loss
|
||||
|
||||
if dec_avg_loss > average_loss_validation["decoder_only"]:
|
||||
if VERBOSE:
|
||||
print("decoding only average is higher than lowest")
|
||||
counter += 1
|
||||
else:
|
||||
average_loss_validation["decoder_only"] = dec_avg_loss
|
||||
|
||||
if counter > 1:
|
||||
patience += 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user