Added more logging

This commit is contained in:
Christian Risi 2025-10-14 10:41:28 +02:00
parent 4968d79403
commit 7585f556f8

View File

@ -47,15 +47,15 @@ REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(1e3)
NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = int(100)
PRETRAIN_EPOCHS = int(20)
WARMUP_EPOCHS = int(30)
MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
VALIDATION_STEPS = 2
CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
PATIENCE = 4
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
VERBOSE = True
@ -341,24 +341,39 @@ while current_epoch < MAX_EPOCHS:
average_loss_validation["txt"] = txt_avg_loss
else:
patience += 1
if VERBOSE:
print(f"losing a patience, current irritation: {patience}")
else:
counter = 0
if txt_avg_loss > average_loss_validation["txt"]:
if VERBOSE:
print("txt average is higher than lowest")
counter += 1
if txt_avg_loss > average_loss_validation["encoder_only"]:
if enc_avg_loss > average_loss_validation["encoder_only"]:
if VERBOSE:
print("masking average is higher than lowest")
counter += 1
if txt_avg_loss > average_loss_validation["decoder_only"]:
if dec_avg_loss > average_loss_validation["decoder_only"]:
if VERBOSE:
print("decoding only average is higher than lowest")
counter += 1
if counter > 1:
patience += 1
if VERBOSE:
print(f"losing a patience, current irritation: {patience}")
if counter == 0:
patience = max(0, patience - 1)
if VERBOSE:
print(f"all good, gaining a patience, current irritation: {patience}")
txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)