Fixed a bug over task 4

This commit is contained in:
Christian Risi
2025-10-12 12:22:38 +02:00
parent ab3d68bc13
commit f463f699cf
2 changed files with 4 additions and 5 deletions

View File

@@ -203,7 +203,7 @@ while current_epoch < MAX_EPOCHS:
decoder_only_optim.zero_grad()
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
@@ -291,7 +291,7 @@ while current_epoch < MAX_EPOCHS:
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)