Fixed misprint in task 3
This commit is contained in:
parent
4281f8724b
commit
f98f5a2611
@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
|
|||||||
MAX_EPOCHS = int(3e3)
|
MAX_EPOCHS = int(3e3)
|
||||||
PRETRAIN_EPOCHS = int(300)
|
PRETRAIN_EPOCHS = int(300)
|
||||||
WARMUP_EPOCHS = int(1e3)
|
WARMUP_EPOCHS = int(1e3)
|
||||||
MINI_BATCH_SIZE = 300
|
MINI_BATCH_SIZE = 80
|
||||||
VALIDATION_STEPS = 5
|
VALIDATION_STEPS = 5
|
||||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
||||||
PATIENCE = 4
|
PATIENCE = 4
|
||||||
@ -185,7 +185,7 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
|
|
||||||
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
|
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
|
||||||
pred_logits = pred_logits.permute(0, 2, 1)
|
pred_logits = pred_logits.permute(0, 2, 1)
|
||||||
print(torch.max(tgt))
|
# print(torch.max(tgt))
|
||||||
loss: torch.Tensor = encoder_ce(pred_logits, tgt)
|
loss: torch.Tensor = encoder_ce(pred_logits, tgt)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user