From 71d602e36e550b45c165615c7664825fd63a6c80 Mon Sep 17 00:00:00 2001
From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com>
Date: Sun, 12 Oct 2025 00:47:20 +0200
Subject: [PATCH] Fixed a memory bug

---
 .../nanosocrates-train-experiment-2.py       | 108 +++++++++---------
 1 file changed, 55 insertions(+), 53 deletions(-)

diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py
index e94756b..1f15a82 100644
--- a/Playgrounds/nanosocrates-train-experiment-2.py
+++ b/Playgrounds/nanosocrates-train-experiment-2.py
@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
 PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
+LEARNING_RATE = 1.5


 SOS_TOKEN = TOKENANO.encode("")[0]
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
@@ -228,77 +229,78 @@ while current_epoch < MAX_EPOCHS:
     ENCODER_ONLY.eval()
     DECODER_ONLY.eval()

-    txt_avg_batch_losses = []
-    enc_avg_batch_losses = []
-    dec_avg_batch_losses = []
+    with torch.no_grad():
+        txt_avg_batch_losses = []
+        enc_avg_batch_losses = []
+        dec_avg_batch_losses = []

-    for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
+        for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

-        src_x, tgt_y, pad_x, pad_y, tasktype = batch
+            src_x, tgt_y, pad_x, pad_y, tasktype = batch

-        enc_x = torch.tensor(src_x)
+            enc_x = torch.tensor(src_x)

-        ACTUAL_BATCH_SIZE, _ = enc_x.shape
-        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
-        tgt = torch.tensor(tgt_y)
-        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+            ACTUAL_BATCH_SIZE, _ = enc_x.shape
+            enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+            tgt = torch.tensor(tgt_y)
+            tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

-        dec_x = Transformer.get_decoder_input(
-            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
-        )
-        dec_x[:, 1:] = tgt[:, :-1]
-        dec_x_pad = dec_x.eq(PAD_TOKEN)
-
-        # Task 1 and Task 2
-        if (
-            tasktype == Batch.TaskType.RDF2TXT
-            or tasktype == Batch.TaskType.TEXT2RDF
-        ):
-
-
-            pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
-
-            pred_logits = pred_logits.permute(0, 2, 1)
-
-            loss: torch.Tensor = nano_cross_entropy(
-                pred_logits, tgt
+            dec_x = Transformer.get_decoder_input(
+                ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
             )
+            dec_x[:, 1:] = tgt[:, :-1]
+            dec_x_pad = dec_x.eq(PAD_TOKEN)

-            txt_avg_batch_losses.append(loss)
+            # Task 1 and Task 2
+            if (
+                tasktype == Batch.TaskType.RDF2TXT
+                or tasktype == Batch.TaskType.TEXT2RDF
+            ):

-            continue
-        # Pretrain first
-        if current_epoch <= PRETRAIN_EPOCHS:
-            continue
+                pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))

-        # Task 3
-        if tasktype == Batch.TaskType.MASKING:
+                pred_logits = pred_logits.permute(0, 2, 1)

-            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
-            pred_logits = pred_logits.permute(0, 2, 1)
+                loss: torch.Tensor = nano_cross_entropy(
+                    pred_logits, tgt
+                )

-            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+                txt_avg_batch_losses.append(loss)

-            enc_avg_batch_losses.append(loss.item())
+                continue

-            continue
+            # Pretrain first
+            if current_epoch <= PRETRAIN_EPOCHS:
+                continue

-        # Task 4
-        if tasktype == Batch.TaskType.COMPLETATION:
+            # Task 3
+            if tasktype == Batch.TaskType.MASKING:
+
+                pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+                pred_logits = pred_logits.permute(0, 2, 1)
+
+                loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+
+                enc_avg_batch_losses.append(loss.item())
+
+                continue
+
+            # Task 4
+            if tasktype == Batch.TaskType.COMPLETATION:

-            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+                pred_logits = DECODER_ONLY((enc_x, enc_x_pad))

-            pred_logits = pred_logits.permute(0, 2, 1)
+                pred_logits = pred_logits.permute(0, 2, 1)

-            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+                loss: torch.Tensor = decoder_ce(pred_logits, tgt)

-            dec_avg_batch_losses.append(loss)
+                dec_avg_batch_losses.append(loss)

-            continue
+                continue

     txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
     enc_avg_loss = float("inf")