Fixed a memory bug: run validation under torch.no_grad(), and retune MAX_EPOCHS, WARMUP_EPOCHS, and the optimizer learning rate
parent 46ee6055ec
commit 71d602e36e
@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
 PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
+LEARNING_RATE = 1.5
 
 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
 
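The constants above only matter through how the training loop consumes them. The repository's actual loop is not part of this hunk, so the sketch below is hypothetical; it only illustrates the usual wiring of MAX_EPOCHS, VALIDATION_STEPS, CHECKPOINT_STEPS and PATIENCE: validate and checkpoint on a fixed cadence, and stop early once validation has failed to improve PATIENCE times in a row.

# Hypothetical sketch, not the repository's loop: typical use of the constants above.
MAX_EPOCHS = int(3e3)
VALIDATION_STEPS = 25
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4

best_loss = float("inf")
stale_validations = 0

for epoch in range(MAX_EPOCHS):
    # ... one epoch of training would run here ...

    if (epoch + 1) % VALIDATION_STEPS == 0:
        val_loss = 1.0 / (epoch + 1)       # placeholder for a real validation pass
        if val_loss < best_loss:
            best_loss, stale_validations = val_loss, 0
        else:
            stale_validations += 1
        if stale_validations >= PATIENCE:  # early stopping
            break

    if (epoch + 1) % CHECKPOINT_STEPS == 0:
        pass                               # e.g. torch.save(model.state_dict(), path)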
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
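Transformer.WarmupLR itself is not part of this diff, so the following is an assumption: schedulers with this signature (optimizer, warmup steps, model width) usually implement the inverse-square-root warmup from "Attention Is All You Need", in which case the optimizer learning rate set above (previously 1, now LEARNING_RATE = 1.5) acts as a constant multiplier on the schedule. The d_model value of 512 below is also an assumption; EMBEDDED_SIZE is defined elsewhere in the file.

# Assumed schedule (Transformer.WarmupLR's real code is not shown in this diff):
# lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
def warmup_lr(step: int, warmup_steps: int, d_model: int, base_lr: float = 1.5) -> float:
    step = max(step, 1)  # guard against step 0
    return base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# With WARMUP_EPOCHS = 1000 the rate ramps up linearly, peaks at step 1000,
# then decays as 1/sqrt(step); d_model=512 stands in for EMBEDDED_SIZE.
peak = warmup_lr(1000, warmup_steps=1000, d_model=512)   # ~= 1.5 * (512 * 1000) ** -0.5
later = warmup_lr(4000, warmup_steps=1000, d_model=512)  # half the peak value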
@@ -228,6 +229,7 @@ while current_epoch < MAX_EPOCHS:
     ENCODER_ONLY.eval()
     DECODER_ONLY.eval()
 
+    with torch.no_grad():
         txt_avg_batch_losses = []
         enc_avg_batch_losses = []
         dec_avg_batch_losses = []
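The hunk above is the actual memory fix: validation previously ran its forward passes without torch.no_grad(), so every evaluation step recorded an autograd graph, and holding on to the resulting loss tensors kept those graphs and their activations alive, growing memory across epochs. A minimal, self-contained illustration of the pattern, using a hypothetical model and random data rather than the repository's code:

import torch

model = torch.nn.Linear(16, 4)             # hypothetical stand-in for the real models
loss_fn = torch.nn.CrossEntropyLoss()
model.eval()

val_losses = []
with torch.no_grad():                      # nothing inside this block records a graph
    for _ in range(10):
        x = torch.randn(8, 16)
        target = torch.randint(0, 4, (8,))
        loss = loss_fn(model(x), target)
        val_losses.append(loss.item())     # .item() yields a plain float, no graph kept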