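Added more logging to the patience checks; also reduced model/training hyperparameters and made the encoder-only/decoder-only checks compare enc_avg_loss/dec_avg_loss instead of txt_avg_loss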
This commit is contained in:
parent
4968d79403
commit
7585f556f8
@ -47,15 +47,15 @@ REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
|
|||||||
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
|
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
|
||||||
EMBEDDED_SIZE = 256
|
EMBEDDED_SIZE = 256
|
||||||
FEED_FORWARD_MULTIPLIER = 4
|
FEED_FORWARD_MULTIPLIER = 4
|
||||||
ATTENTION_HEADS = 8
|
ATTENTION_HEADS = 4
|
||||||
SENTENCE_LENGTH = 256
|
SENTENCE_LENGTH = 256
|
||||||
NUMBER_OF_BLOCKS = 4
|
NUMBER_OF_BLOCKS = 2
|
||||||
MAX_EPOCHS = int(3e3)
|
MAX_EPOCHS = int(100)
|
||||||
PRETRAIN_EPOCHS = int(300)
|
PRETRAIN_EPOCHS = int(20)
|
||||||
WARMUP_EPOCHS = int(1e3)
|
WARMUP_EPOCHS = int(30)
|
||||||
MINI_BATCH_SIZE = 80
|
MINI_BATCH_SIZE = 80
|
||||||
VALIDATION_STEPS = 5
|
VALIDATION_STEPS = 2
|
||||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
CHECKPOINT_STEPS = VALIDATION_STEPS * 1000
|
||||||
PATIENCE = 4
|
PATIENCE = 4
|
||||||
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
||||||
VERBOSE = True
|
VERBOSE = True
|
||||||
@ -341,24 +341,39 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
average_loss_validation["txt"] = txt_avg_loss
|
average_loss_validation["txt"] = txt_avg_loss
|
||||||
else:
|
else:
|
||||||
patience += 1
|
patience += 1
|
||||||
|
if VERBOSE:
|
||||||
|
print(f"losing a patience, current irritation: {patience}")
|
||||||
else:
|
else:
|
||||||
|
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
if txt_avg_loss > average_loss_validation["txt"]:
|
if txt_avg_loss > average_loss_validation["txt"]:
|
||||||
|
|
||||||
|
if VERBOSE:
|
||||||
|
print("txt average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
if txt_avg_loss > average_loss_validation["encoder_only"]:
|
if enc_avg_loss > average_loss_validation["encoder_only"]:
|
||||||
|
if VERBOSE:
|
||||||
|
print("masking average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
if txt_avg_loss > average_loss_validation["decoder_only"]:
|
if dec_avg_loss > average_loss_validation["decoder_only"]:
|
||||||
|
if VERBOSE:
|
||||||
|
print("decoding only average is higher than lowest")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
if counter > 1:
|
if counter > 1:
|
||||||
patience += 1
|
patience += 1
|
||||||
|
if VERBOSE:
|
||||||
|
print(f"losing a patience, current irritation: {patience}")
|
||||||
|
|
||||||
|
|
||||||
if counter == 0:
|
if counter == 0:
|
||||||
patience = max(0, patience - 1)
|
patience = max(0, patience - 1)
|
||||||
|
if VERBOSE:
|
||||||
|
print(f"all good, gaining a patience, current irritation: {patience}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
|
txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user