Fixed a patience bug and added label smoothing
This commit is contained in:
parent
0b256001fe
commit
83693f1d4e
@ -59,7 +59,8 @@ CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
|
|||||||
PATIENCE = 4
|
PATIENCE = 4
|
||||||
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
||||||
VERBOSE = True
|
VERBOSE = True
|
||||||
LEARNING_RATE = 1.5
|
LEARNING_RATE = 0.05
|
||||||
|
LABEL_SMOOTHING = 0.01
|
||||||
|
|
||||||
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
|
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
|
||||||
|
|
||||||
@ -103,9 +104,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
|
|||||||
|
|
||||||
|
|
||||||
# Training constants
|
# Training constants
|
||||||
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||||
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||||
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
|
||||||
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
|
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
|
||||||
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
|
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
|
||||||
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
|
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
|
||||||
@ -352,16 +353,22 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
if VERBOSE:
|
if VERBOSE:
|
||||||
print("txt average is higher than lowest")
|
print("txt average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
else:
|
||||||
|
average_loss_validation["txt"] = txt_avg_loss
|
||||||
|
|
||||||
if enc_avg_loss > average_loss_validation["encoder_only"]:
|
if enc_avg_loss > average_loss_validation["encoder_only"]:
|
||||||
if VERBOSE:
|
if VERBOSE:
|
||||||
print("masking average is higher than lowest")
|
print("masking average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
else:
|
||||||
|
average_loss_validation["encoder_only"] = enc_avg_loss
|
||||||
|
|
||||||
if dec_avg_loss > average_loss_validation["decoder_only"]:
|
if dec_avg_loss > average_loss_validation["decoder_only"]:
|
||||||
if VERBOSE:
|
if VERBOSE:
|
||||||
print("decoding only average is higher than lowest")
|
print("decoding only average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
else:
|
||||||
|
average_loss_validation["decoder_only"] = dec_avg_loss
|
||||||
|
|
||||||
if counter > 1:
|
if counter > 1:
|
||||||
patience += 1
|
patience += 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user