Fixed several bugs for task 4
This commit is contained in:
@@ -57,7 +57,7 @@ MINI_BATCH_SIZE = 80
|
||||
VALIDATION_STEPS = 5
|
||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
||||
PATIENCE = 4
|
||||
CURRENT_EPOCH = 0 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
||||
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
|
||||
VERBOSE = True
|
||||
LEARNING_RATE = 1.5
|
||||
|
||||
@@ -228,7 +228,7 @@ while current_epoch < MAX_EPOCHS:
|
||||
|
||||
decoder_only_optim.zero_grad()
|
||||
|
||||
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
|
||||
pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
|
||||
pred_logits = pred_logits.permute(0, 2, 1)
|
||||
|
||||
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
|
||||
@@ -316,7 +316,7 @@ while current_epoch < MAX_EPOCHS:
|
||||
|
||||
|
||||
|
||||
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
|
||||
pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
|
||||
|
||||
pred_logits = pred_logits.permute(0, 2, 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user