Fixed a bug caused by mismatched batch sizes

Christian Risi 2025-10-11 22:09:46 +02:00
parent bcc2fe7368
commit d8e65bfb8a


@@ -134,10 +134,14 @@ while current_epoch < MAX_EPOCHS:
     src_x, tgt_y, pad_x, pad_y, tasktype = batch
     enc_x = torch.tensor(src_x)
+    ACTUAL_BATCH_SIZE, _ = enc_x.shape
     enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
     dec_x = Transformer.get_decoder_input(
-        MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+        ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
     )
     dec_x_pad = dec_x.eq(PAD_TOKEN)
     tgt = torch.tensor(tgt_y)
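
The bug is easiest to see on the final mini-batch: when the dataset size is not a multiple of MINI_BATCH_SIZE, the loader delivers a smaller last batch, but the decoder input was still built with the constant MINI_BATCH_SIZE, so its batch dimension no longer matched enc_x. Below is a minimal sketch of the failure and the fix; get_decoder_input here is a hypothetical stand-in whose behavior is only assumed from the call site above.

    import torch

    SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH = 1, 0, 8
    MINI_BATCH_SIZE = 32

    # Hypothetical stand-in for Transformer.get_decoder_input:
    # a (batch, seq_len) tensor starting with SOS, padded with PAD.
    def get_decoder_input(batch_size, sos, pad, seq_len):
        dec = torch.full((batch_size, seq_len), pad, dtype=torch.long)
        dec[:, 0] = sos
        return dec

    # A final partial batch: e.g. 100 samples with MINI_BATCH_SIZE=32 leaves 4.
    enc_x = torch.randint(2, 100, (4, SENTENCE_LENGTH))

    # Buggy: decoder input has 32 rows while enc_x has 4.
    dec_bad = get_decoder_input(MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)
    assert dec_bad.shape[0] != enc_x.shape[0]  # batch-dimension mismatch

    # Fixed: read the batch size off the tensor that was actually loaded.
    ACTUAL_BATCH_SIZE, _ = enc_x.shape
    dec_ok = get_decoder_input(ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)
    assert dec_ok.shape[0] == enc_x.shape[0]

The same fix is applied again below, in the code path whose encoder input is 3-D.
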
@@ -257,10 +261,13 @@ while current_epoch < MAX_EPOCHS:
     src_x, tgt_y, pad_x, pad_y, tasktype = batch
     enc_x = torch.tensor(src_x)
+    ACTUAL_BATCH_SIZE, _, _ = enc_x.shape
     enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
     dec_x = Transformer.get_decoder_input(
-        MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+        ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
     )
     dec_x_pad = dec_x.eq(PAD_TOKEN)
     tgt = torch.tensor(tgt_y)
     tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
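
This second hunk unpacks three dimensions (ACTUAL_BATCH_SIZE, _, _ = enc_x.shape), so this path evidently feeds a 3-D encoder input. A side note, not part of the commit: Tensor.size(0) reads the leading dimension regardless of rank, which would let both paths share one unpacking pattern. A sketch:

    import torch

    # The two hunks unpack 2-D and 3-D shapes separately; size(0) reads the
    # leading (batch) dimension for any rank, so one pattern covers both.
    enc_x_2d = torch.zeros(4, 8)
    enc_x_3d = torch.zeros(4, 8, 16)

    assert enc_x_2d.size(0) == enc_x_3d.size(0) == 4  # same batch size either way
    ACTUAL_BATCH_SIZE = enc_x_3d.size(0)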