Fixed a bug about mismatched batch sizes
This commit is contained in:
parent
bcc2fe7368
commit
d8e65bfb8a
@ -134,10 +134,14 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
|
|
||||||
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
enc_x = torch.tensor(src_x)
|
enc_x = torch.tensor(src_x)
|
||||||
|
|
||||||
|
ACTUAL_BATCH_SIZE, _ = enc_x.shape
|
||||||
enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
|
enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
|
||||||
dec_x = Transformer.get_decoder_input(
|
dec_x = Transformer.get_decoder_input(
|
||||||
MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
|
ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
|
||||||
)
|
)
|
||||||
dec_x_pad = dec_x.eq(PAD_TOKEN)
|
dec_x_pad = dec_x.eq(PAD_TOKEN)
|
||||||
tgt = torch.tensor(tgt_y)
|
tgt = torch.tensor(tgt_y)
|
||||||
@ -257,10 +261,13 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
||||||
|
|
||||||
enc_x = torch.tensor(src_x)
|
enc_x = torch.tensor(src_x)
|
||||||
|
|
||||||
|
ACTUAL_BATCH_SIZE, _, _ = enc_x.shape
|
||||||
enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
|
enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
|
||||||
dec_x = Transformer.get_decoder_input(
|
dec_x = Transformer.get_decoder_input(
|
||||||
MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
|
ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
|
||||||
)
|
)
|
||||||
|
|
||||||
dec_x_pad = dec_x.eq(PAD_TOKEN)
|
dec_x_pad = dec_x.eq(PAD_TOKEN)
|
||||||
tgt = torch.tensor(tgt_y)
|
tgt = torch.tensor(tgt_y)
|
||||||
tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
|
tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user