Size the decoder input from the loaded batch instead of MINI_BATCH_SIZE.

Both hunks replace the fixed MINI_BATCH_SIZE argument of
Transformer.get_decoder_input() with ACTUAL_BATCH_SIZE, read from the
leading dimension of enc_x (presumably so a batch smaller than
MINI_BATCH_SIZE still gets a matching decoder input -- confirm against
the data loader).

ACTUAL_BATCH_SIZE is taken as enc_x.shape[0] in both hunks. The earlier
revision unpacked the shape as a 2-tuple in the first hunk and as a
3-tuple in the second; tuple unpacking raises ValueError whenever the
tensor rank differs from the number of targets, so the two hunks could
not both succeed on identically shaped batches. Indexing the first
dimension works for any rank >= 1.

NOTE(review): this patch was recovered from a copy flattened onto one
line. The 8-space base indentation is assumed, and whitespace-only
context lines could not be recovered, so the hunk headers (kept from the
original) may not match the line counts below. The "index" blob ids also
predate the shape[0] change. Regenerate with `git diff` before applying.

diff --git a/Playgrounds/nanosocrates-train.py b/Playgrounds/nanosocrates-train.py
index e3c44e9..80c73f0 100644
--- a/Playgrounds/nanosocrates-train.py
+++ b/Playgrounds/nanosocrates-train.py
@@ -134,10 +134,14 @@ while current_epoch < MAX_EPOCHS:
         src_x, tgt_y, pad_x, pad_y, tasktype = batch
+
+        enc_x = torch.tensor(src_x)
+
+        ACTUAL_BATCH_SIZE = enc_x.shape[0]
         enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
         dec_x = Transformer.get_decoder_input(
-            MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
         )
         dec_x_pad = dec_x.eq(PAD_TOKEN)
         tgt = torch.tensor(tgt_y)
@@ -257,10 +261,13 @@ while current_epoch < MAX_EPOCHS:
         src_x, tgt_y, pad_x, pad_y, tasktype = batch
         enc_x = torch.tensor(src_x)
+
+        ACTUAL_BATCH_SIZE = enc_x.shape[0]
         enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
         dec_x = Transformer.get_decoder_input(
-            MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
         )
+
         dec_x_pad = dec_x.eq(PAD_TOKEN)
         tgt = torch.tensor(tgt_y)
         tgt_pad = torch.tensor(pad_y, dtype=torch.bool)