Fixed a memory bug
parent 46ee6055ec
commit 71d602e36e
@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
+PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
 LEARNING_RATE = 1.5

 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]

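The new PRETRAIN_EPOCHS constant points at a separate pretraining phase before the main run, but the loop that consumes it is outside this hunk. As a hypothetical illustration only (epoch_phase and the phase names below are stand-ins, not repository code), a constant like this is typically used to split the epoch counter into phases:

# Hypothetical sketch -- not code from this repository; shows how a
# PRETRAIN_EPOCHS-style constant is commonly consumed by a training loop.
MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)


def epoch_phase(epoch: int) -> str:
    """Name the training phase a given epoch belongs to."""
    if epoch < PRETRAIN_EPOCHS:
        return "pretrain"  # e.g. encoder-only / decoder-only warm start
    return "main"          # full-model training for the remaining epochs


assert epoch_phase(0) == "pretrain"
assert epoch_phase(PRETRAIN_EPOCHS) == "main"
assert epoch_phase(MAX_EPOCHS - 1) == "main"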
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
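Transformer.WarmupLR itself is not part of this diff, so the following is only a sketch, assuming a Noam-style inverse-square-root schedule built on torch.optim.lr_scheduler.LambdaLR. Under that assumption the scheduler multiplies the optimizer's base learning rate, so the previous hard-coded base of 1 silently ignored LEARNING_RATE; passing LEARNING_RATE turns the constant into a real global scale on the schedule. The same assumption would also explain the WARMUP_EPOCHS change above: with the old values (warmup 4e3, MAX_EPOCHS 1e3) a per-epoch schedule would still be ramping up when training ended, while the new values finish warmup a third of the way through the run.

# Sketch only, assuming a Noam-style schedule; the real Transformer.WarmupLR
# is not shown in this diff.
import torch


class WarmupLR(torch.optim.lr_scheduler.LambdaLR):
    """Scales the base lr by d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, optimizer, warmup_steps: int, d_model: int):
        def noam(step: int) -> float:
            step = max(step, 1)  # guard the step-0 evaluation made during __init__
            return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

        super().__init__(optimizer, lr_lambda=noam)


# With a multiplicative scheduler like this, the lr passed to AdamW is a base
# scale: effective lr = base_lr * noam(step). A base of 1 ignores LEARNING_RATE,
# while a base of LEARNING_RATE scales the whole curve by 1.5.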
@@ -228,6 +229,7 @@ while current_epoch < MAX_EPOCHS:
     ENCODER_ONLY.eval()
     DECODER_ONLY.eval()

     with torch.no_grad():
+        txt_avg_batch_losses = []
         enc_avg_batch_losses = []
         dec_avg_batch_losses = []
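The added txt_avg_batch_losses list mirrors the existing encoder/decoder lists; the code that fills and averages them is outside this hunk. Below is a minimal sketch of the usual pattern, with validate_batch and the dummy batch list as hypothetical stand-ins: appending plain floats via .item() keeps no tensors (and no device memory behind them) alive across validation batches, which avoids one common source of creeping memory use in validation loops.

# Sketch of the presumed bookkeeping; validate_batch and the batch list are
# hypothetical stand-ins, not code from this repository.
import torch


def validate_batch(batch: torch.Tensor):
    """Stand-in that returns three fake per-batch losses."""
    m = batch.float().mean()
    return m, m * 0.5, m * 2.0


validation_batches = [torch.randn(4, 8) for _ in range(3)]  # dummy data

with torch.no_grad():
    txt_avg_batch_losses = []
    enc_avg_batch_losses = []
    dec_avg_batch_losses = []

    for batch in validation_batches:
        txt_loss, enc_loss, dec_loss = validate_batch(batch)
        # .item() stores plain Python floats, so the lists never retain
        # tensors across batches.
        txt_avg_batch_losses.append(txt_loss.item())
        enc_avg_batch_losses.append(enc_loss.item())
        dec_avg_batch_losses.append(dec_loss.item())

    txt_val_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
    enc_val_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
    dec_val_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)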