Fixed misprint in task 3

This commit is contained in:
Christian Risi 2025-10-12 01:16:09 +02:00
parent 4281f8724b
commit f98f5a2611

View File

@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(3e3) MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300) PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(1e3) WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 300 MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5 VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4 PATIENCE = 4
@ -185,7 +185,7 @@ while current_epoch < MAX_EPOCHS:
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad)) pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
pred_logits = pred_logits.permute(0, 2, 1) pred_logits = pred_logits.permute(0, 2, 1)
print(torch.max(tgt)) # print(torch.max(tgt))
loss: torch.Tensor = encoder_ce(pred_logits, tgt) loss: torch.Tensor = encoder_ce(pred_logits, tgt)
loss.backward() loss.backward()