Fixed a patience bug and added label smoothing

commit 83693f1d4e
parent 0b256001fe
Author: Christian Risi
Date:   2025-10-14 11:03:15 +02:00


@@ -59,7 +59,8 @@ CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
 PATIENCE = 4
 CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
 VERBOSE = True
-LEARNING_RATE = 1.5
+LEARNING_RATE = 0.05
+LABEL_SMOOTHING = 0.01
 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
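
The learning rate drops from 1.5, which is far outside the usual stable range for AdamW, to 0.05. The new LABEL_SMOOTHING constant feeds the loss changes below. As a minimal sketch of what it does to the training target (the vocabulary size V = 8 is a made-up value, not from the repo):

```python
# Sketch of PyTorch's label_smoothing target mixing: the one-hot target
# is blended with a uniform distribution over all classes.
V = 8            # hypothetical vocabulary size
smoothing = 0.01 # LABEL_SMOOTHING from the commit

# Each class receives smoothing / V of the mass; the true class keeps the rest.
true_class_prob = 1.0 - smoothing + smoothing / V
other_class_prob = smoothing / V

print(f"true class target:  {true_class_prob:.5f}")   # 0.99125
print(f"other class target: {other_class_prob:.5f}")  # 0.00125
print(f"sums to one: {true_class_prob + (V - 1) * other_class_prob:.5f}")
```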
@@ -103,9 +104,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 # Training constants
-nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
+encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
+decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
 nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
 encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
 decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
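
A minimal, self-contained sketch of the two CrossEntropyLoss options the commit combines; the pad id, logits, and targets here are invented values, not from the repo:

```python
import torch

PAD = 0                                # hypothetical pad token id
logits = torch.randn(5, 10)            # 5 positions, vocabulary of 10
targets = torch.tensor([3, 7, PAD, 1, PAD])

plain = torch.nn.CrossEntropyLoss(ignore_index=PAD)
smoothed = torch.nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.01)

# Positions whose target equals PAD contribute to neither loss; the two
# values differ only because the smoothed targets are no longer one-hot.
print("plain:   ", plain(logits, targets).item())
print("smoothed:", smoothed(logits, targets).item())
```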
@@ -352,16 +353,22 @@ while current_epoch < MAX_EPOCHS:
         if VERBOSE:
             print("txt average is higher than lowest")
         counter += 1
+    else:
+        average_loss_validation["txt"] = txt_avg_loss

     if enc_avg_loss > average_loss_validation["encoder_only"]:
         if VERBOSE:
             print("masking average is higher than lowest")
         counter += 1
+    else:
+        average_loss_validation["encoder_only"] = enc_avg_loss

     if dec_avg_loss > average_loss_validation["decoder_only"]:
         if VERBOSE:
             print("decoding only average is higher than lowest")
         counter += 1
+    else:
+        average_loss_validation["decoder_only"] = dec_avg_loss

     if counter > 1:
         patience += 1
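
The patience fix: the added else branches record a head's new minimum whenever its validation loss improves, so the "higher than lowest" comparisons run against the true best value instead of a stale one. A minimal sketch of the resulting bookkeeping, with hypothetical names and loss values:

```python
# Early-stopping bookkeeping, mirroring the structure of the fixed block.
PATIENCE = 4

best = {"txt": float("inf"), "encoder_only": float("inf"), "decoder_only": float("inf")}
patience = 0

def check_validation(losses: dict[str, float]) -> bool:
    """Return True when training should stop."""
    global patience
    counter = 0
    for head, loss in losses.items():
        if loss > best[head]:
            counter += 1        # this head regressed vs. its best loss
        else:
            best[head] = loss   # the fix: remember the new minimum
    if counter > 1:             # a majority of the three heads regressed
        patience += 1
    return patience >= PATIENCE

# Example: fake validation losses over two validation rounds.
for losses in [{"txt": 1.0, "encoder_only": 1.2, "decoder_only": 0.9},
               {"txt": 0.8, "encoder_only": 1.3, "decoder_only": 1.0}]:
    print(check_validation(losses))
```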