diff --git a/Playgrounds/nanosocrates-train.py b/Playgrounds/nanosocrates-train.py
index 3a15a54..e3c44e9 100644
--- a/Playgrounds/nanosocrates-train.py
+++ b/Playgrounds/nanosocrates-train.py
@@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
 MAX_EPOCHS = int(1e3)
 PRETRAIN_EPOCHS = int(10)
 WARMUP_EPOCHS = int(4e3)
-MINI_BATCH_SIZE = 100
+MINI_BATCH_SIZE = 20
 VALIDATION_STEPS = 5
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
@@ -124,8 +124,14 @@ while current_epoch < MAX_EPOCHS:
     encoder_batch_losses = []
     decoder_batch_losses = []
 
+    batch_counter = 0
+
+    print(f"EPOCH {current_epoch} STARTING")
+
     for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
+        batch_counter += 1
+
         src_x, tgt_y, pad_x, pad_y, tasktype = batch
 
         enc_x = torch.tensor(src_x)
 
@@ -137,10 +143,16 @@ while current_epoch < MAX_EPOCHS:
         tgt = torch.tensor(tgt_y)
         tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
 
+        print(f"\tBATCH {batch_counter} Starting")
+
         # Task 1 and Task 2
         if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
+
+            print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")
+
             BATCH_LOSS = []
+
             for token_idx in range(0, SENTENCE_LENGTH):
 
                 nano_optim.zero_grad()
 
@@ -173,6 +185,8 @@ while current_epoch < MAX_EPOCHS:
 
         # Task 3
         if tasktype == Batch.TaskType.MASKING:
+            print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
+
             encoder_only_optim.zero_grad()
 
             pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
@@ -190,6 +204,8 @@ while current_epoch < MAX_EPOCHS:
 
         # Task 4
        if tasktype == Batch.TaskType.COMPLETATION:
+            print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
+
             BATCH_LOSS = []
 
             for token_idx in range(0, SENTENCE_LENGTH):
diff --git a/Project_Model/Libs/Transformer/Models/NanoSocraDecoder.py b/Project_Model/Libs/Transformer/Models/NanoSocraDecoder.py
index 04f1789..3abceea 100644
--- a/Project_Model/Libs/Transformer/Models/NanoSocraDecoder.py
+++ b/Project_Model/Libs/Transformer/Models/NanoSocraDecoder.py
@@ -24,7 +24,7 @@ class NanoSocraDecoder(torch.nn.Module):
         decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
 
         decoder_output, _, _, _, _, _ = self.__decoder(
-            (decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False)
+            (decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
         )
 
         logits: torch.Tensor = self.__detokener(decoder_output)
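
Note on the `NanoSocraDecoder` hunk: the change flips the sixth element of the tuple passed to `self.__decoder` from `False` to `True`. The decoder internals are not part of this diff, so this is an assumption, but a trailing boolean in a (query, key, value, padding, padding, flag) tuple most plausibly toggles the causal (look-ahead) self-attention mask. With it off, a teacher-forced decoder can attend to future target tokens, so training loss looks deceptively low while autoregressive generation degrades. Below is a minimal sketch of what such a flag would typically control; every name in it is hypothetical and not this repository's API:

```python
# Illustrative sketch only: the real flag semantics live inside
# NanoSocraDecoder's attention blocks, which this diff does not show.
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True above the diagonal marks positions a token must NOT attend to.
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

def masked_attention_scores(
    scores: torch.Tensor,       # (batch, heads, seq, seq) raw QK^T scores
    tgt_padding: torch.Tensor,  # (batch, seq) bool, True = PAD position
    use_causal_mask: bool,      # hypothetical analogue of the tuple's last element
) -> torch.Tensor:
    seq_len = scores.size(-1)
    # Padding mask: block attention to padded key positions.
    mask = tgt_padding[:, None, None, :]  # (batch, 1, 1, seq)
    if use_causal_mask:
        # Causal mask: additionally block attention to future positions.
        # This is what the False -> True change would switch on.
        mask = mask | causal_mask(seq_len).to(scores.device)[None, None, :, :]
    return scores.masked_fill(mask, float("-inf"))
```

If that reading is right, the flag change is the substantive fix here, and the smaller `MINI_BATCH_SIZE` plus the per-batch `print` calls look like supporting changes for debugging the retrained model.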