Fixed a memory bug

Christian Risi 2025-10-12 00:47:20 +02:00
parent 46ee6055ec
commit 71d602e36e


@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 300
VALIDATION_STEPS = 25
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0
VERBOSE = False
+LEARNING_RATE = 1.5
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
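
Note on this hunk: with the old values, WARMUP_EPOCHS (4e3) was larger than MAX_EPOCHS (1e3), so a warmup-based schedule would still be ramping up when training stopped; the new values (1e3 warmup epochs inside 3e3 total) let the learning rate peak and then decay. A minimal sanity check, not part of the commit:

MAX_EPOCHS = int(3e3)
WARMUP_EPOCHS = int(1e3)
# The old configuration (MAX_EPOCHS = 1e3, WARMUP_EPOCHS = 4e3) violated this.
assert WARMUP_EPOCHS < MAX_EPOCHS, "warmup must finish before training ends"
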
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
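
Why the optimizers now take LEARNING_RATE instead of a hard-coded 1: with a multiplicative warmup scheduler, the base lr passed to AdamW scales the whole schedule, so 1.5 acts as a peak-rate factor. Transformer.WarmupLR itself is not shown in this diff; the sketch below assumes it follows the usual inverse-square-root warmup schedule (Vaswani et al., 2017):

import torch

class WarmupLR(torch.optim.lr_scheduler.LambdaLR):
    # Assumed shape of Transformer.WarmupLR: linear warmup for warmup_steps,
    # then inverse-square-root decay, scaled by d_model ** -0.5.
    # The optimizer's base lr (LEARNING_RATE above) multiplies this factor.
    def __init__(self, optimizer, warmup_steps, d_model):
        def factor(step):
            step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
            return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
        super().__init__(optimizer, lr_lambda=factor)

Under this assumption, each training step calls optimizer.step() followed by scheduler.step().
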
@@ -228,6 +229,7 @@ while current_epoch < MAX_EPOCHS:
ENCODER_ONLY.eval()
DECODER_ONLY.eval()
with torch.no_grad():
txt_avg_batch_losses = []
enc_avg_batch_losses = []
dec_avg_batch_losses = []
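
The last hunk is the one most directly tied to the commit message: the validation loss lists are collected under torch.no_grad(). Evaluating without such a guard is a classic memory bug, because every forward pass keeps an autograd graph alive (nothing ever calls backward() to free it), so GPU memory grows across validation steps. A minimal sketch of the pattern, with placeholder names rather than the script's actual loop:

import torch

def validate(model, batches, loss_fn):
    # Sketch only: under no_grad the forward pass records no computation
    # graph, and .item() stores Python floats instead of GPU tensors.
    model.eval()
    losses = []
    with torch.no_grad():
        for inputs, targets in batches:
            logits = model(inputs)
            losses.append(loss_fn(logits, targets).item())
    model.train()
    return sum(losses) / len(losses)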