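Added more logging, reduced the training hyperparameters for a shorter run, and fixed the patience check to compare the encoder-only and decoder-only losses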

This commit is contained in:
Christian Risi 2025-10-14 10:41:28 +02:00
parent 4968d79403
commit 7585f556f8

View File

@@ -47,15 +47,15 @@ REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256 EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4 FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8 ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256 SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4 NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = int(3e3) MAX_EPOCHS = int(100)
PRETRAIN_EPOCHS = int(300) PRETRAIN_EPOCHS = int(20)
WARMUP_EPOCHS = int(1e3) WARMUP_EPOCHS = int(30)
MINI_BATCH_SIZE = 80 MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5 VALIDATION_STEPS = 2
CHECKPOINT_STEPS = VALIDATION_STEPS * 4 CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
PATIENCE = 4 PATIENCE = 4
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text()) CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
VERBOSE = True VERBOSE = True
@@ -341,24 +341,39 @@ while current_epoch < MAX_EPOCHS:
average_loss_validation["txt"] = txt_avg_loss average_loss_validation["txt"] = txt_avg_loss
else: else:
patience += 1 patience += 1
if VERBOSE:
print(f"losing a patience, current irritation: {patience}")
else: else:
counter = 0 counter = 0
if txt_avg_loss > average_loss_validation["txt"]: if txt_avg_loss > average_loss_validation["txt"]:
if VERBOSE:
print("txt average is higher than lowest")
counter += 1 counter += 1
if txt_avg_loss > average_loss_validation["encoder_only"]: if enc_avg_loss > average_loss_validation["encoder_only"]:
if VERBOSE:
print("masking average is higher than lowest")
counter += 1 counter += 1
if txt_avg_loss > average_loss_validation["decoder_only"]: if dec_avg_loss > average_loss_validation["decoder_only"]:
if VERBOSE:
print("decoding only average is higher than lowest")
counter += 1 counter += 1
if counter > 1: if counter > 1:
patience += 1 patience += 1
if VERBOSE:
print(f"losing a patience, current irritation: {patience}")
if counter == 0: if counter == 0:
patience = max(0, patience - 1) patience = max(0, patience - 1)
if VERBOSE:
print(f"all good, gaining a patience, current irritation: {patience}")
txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses) txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)