Fixed a memory bug

Christian Risi 2025-10-12 00:47:20 +02:00
parent 46ee6055ec
commit 71d602e36e

@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
 PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
+LEARNING_RATE = 1.5
 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
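A quick sanity check on the new epoch constants (a sketch, assuming the warmup scheduler is stepped once per epoch, as the variable names suggest): under the old values the warmup span (4e3) was longer than the entire run (1e3 epochs), so the learning rate would still have been ramping up when training stopped; with the new values warmup ends a third of the way into the run.

# Hypothetical check of the epoch budget; assumes one scheduler step per epoch.
old_cfg = {"MAX_EPOCHS": int(1e3), "WARMUP_EPOCHS": int(4e3)}
new_cfg = {"MAX_EPOCHS": int(3e3), "WARMUP_EPOCHS": int(1e3)}

for name, cfg in (("old", old_cfg), ("new", new_cfg)):
    frac = cfg["WARMUP_EPOCHS"] / cfg["MAX_EPOCHS"]
    print(f"{name}: warmup covers {frac:.0%} of the run")
# old: warmup covers 400% of the run  -> training ends mid-warmup
# new: warmup covers 33% of the run   -> peak learning rate reached at epoch 1000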
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
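The repository's Transformer.WarmupLR is not shown in this diff; the sketch below assumes it implements the usual "Attention Is All You Need" schedule, in which case the lr passed to AdamW acts as a constant multiplier on the warmup curve, so replacing the hard-coded 1 with LEARNING_RATE = 1.5 scales every step's learning rate by 1.5.

import torch

# Hypothetical stand-in for Transformer.WarmupLR, assuming the Noam schedule:
# lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
def noam_lambda(warmup_steps: int, d_model: int):
    def factor(step: int) -> float:
        step = max(step, 1)  # avoid 0 ** -0.5 on the scheduler's first call
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return factor

model = torch.nn.Linear(512, 512)                      # stand-in for NANOSOCRATES
optim = torch.optim.AdamW(model.parameters(), lr=1.5)  # LEARNING_RATE as the base
sched = torch.optim.lr_scheduler.LambdaLR(optim, noam_lambda(1000, 512))

for step in range(3):
    optim.step()
    sched.step()
    print(step, optim.param_groups[0]["lr"])  # base 1.5 times the warmup factor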
@@ -228,77 +229,78 @@ while current_epoch < MAX_EPOCHS:
 ENCODER_ONLY.eval()
 DECODER_ONLY.eval()
-txt_avg_batch_losses = []
-enc_avg_batch_losses = []
-dec_avg_batch_losses = []
-for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
-    src_x, tgt_y, pad_x, pad_y, tasktype = batch
-    enc_x = torch.tensor(src_x)
-    ACTUAL_BATCH_SIZE, _ = enc_x.shape
-    enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
-    tgt = torch.tensor(tgt_y)
-    tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
-    dec_x = Transformer.get_decoder_input(
-        ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
-    )
-    dec_x[:, 1:] = tgt[:, :-1]
-    dec_x_pad = dec_x.eq(PAD_TOKEN)
-    # Task 1 and Task 2
-    if (
-        tasktype == Batch.TaskType.RDF2TXT
-        or tasktype == Batch.TaskType.TEXT2RDF
-    ):
-        pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
-        pred_logits = pred_logits.permute(0, 2, 1)
-        loss: torch.Tensor = nano_cross_entropy(
-            pred_logits, tgt
-        )
-        txt_avg_batch_losses.append(loss)
-        continue
-    # Pretrain first
-    if current_epoch <= PRETRAIN_EPOCHS:
-        continue
-    # Task 3
-    if tasktype == Batch.TaskType.MASKING:
-        pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
-        pred_logits = pred_logits.permute(0, 2, 1)
-        loss: torch.Tensor = encoder_ce(pred_logits, tgt)
-        enc_avg_batch_losses.append(loss.item())
-        continue
-    # Task 4
-    if tasktype == Batch.TaskType.COMPLETATION:
-        pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
-        pred_logits = pred_logits.permute(0, 2, 1)
-        loss: torch.Tensor = decoder_ce(pred_logits, tgt)
-        dec_avg_batch_losses.append(loss)
-        continue
+with torch.no_grad():
+    txt_avg_batch_losses = []
+    enc_avg_batch_losses = []
+    dec_avg_batch_losses = []
+    for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
+        src_x, tgt_y, pad_x, pad_y, tasktype = batch
+        enc_x = torch.tensor(src_x)
+        ACTUAL_BATCH_SIZE, _ = enc_x.shape
+        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+        tgt = torch.tensor(tgt_y)
+        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+        dec_x = Transformer.get_decoder_input(
+            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+        )
+        dec_x[:, 1:] = tgt[:, :-1]
+        dec_x_pad = dec_x.eq(PAD_TOKEN)
+        # Task 1 and Task 2
+        if (
+            tasktype == Batch.TaskType.RDF2TXT
+            or tasktype == Batch.TaskType.TEXT2RDF
+        ):
+            pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
+            pred_logits = pred_logits.permute(0, 2, 1)
+            loss: torch.Tensor = nano_cross_entropy(
+                pred_logits, tgt
+            )
+            txt_avg_batch_losses.append(loss)
+            continue
+        # Pretrain first
+        if current_epoch <= PRETRAIN_EPOCHS:
+            continue
+        # Task 3
+        if tasktype == Batch.TaskType.MASKING:
+            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+            pred_logits = pred_logits.permute(0, 2, 1)
+            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+            enc_avg_batch_losses.append(loss.item())
+            continue
+        # Task 4
+        if tasktype == Batch.TaskType.COMPLETATION:
+            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+            pred_logits = pred_logits.permute(0, 2, 1)
+            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+            dec_avg_batch_losses.append(loss)
+            continue
 txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
 enc_avg_loss = float("inf")
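Why this fixes the memory bug (a minimal sketch with illustrative names, not the repo's code): without torch.no_grad(), every validation forward pass records an autograd graph, and appending those loss tensors to a Python list keeps all of the graphs alive until the list is dropped, so memory grows with every validation batch; inside no_grad() the same loop only accumulates small, graph-free tensors.

import torch

# Minimal sketch of the fixed pattern; model, criterion and val_batches are illustrative.
model = torch.nn.Linear(64, 64)
criterion = torch.nn.CrossEntropyLoss()
val_batches = [(torch.randn(8, 64), torch.randint(0, 64, (8,))) for _ in range(4)]

model.eval()
losses = []
with torch.no_grad():                        # no autograd graphs are recorded here
    for x, y in val_batches:
        logits = model(x)
        losses.append(criterion(logits, y))  # graph-free tensors, safe to accumulate

avg_loss = sum(losses) / len(losses)
print(float(avg_loss))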