Fixed several bugs for task 4

This commit is contained in:
Christian Risi
2025-10-12 16:30:30 +02:00
parent e0f8a36aa5
commit 07130ff489
5 changed files with 186 additions and 9 deletions

View File

@@ -57,7 +57,7 @@ MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
VERBOSE = True
LEARNING_RATE = 1.5
@@ -228,7 +228,7 @@ while current_epoch < MAX_EPOCHS:
decoder_only_optim.zero_grad()
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
@@ -316,7 +316,7 @@ while current_epoch < MAX_EPOCHS:
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)