Fixed bugs and added visibility (progress output for epochs, batches, and tasks)

This commit is contained in:
Christian Risi 2025-10-11 21:49:29 +02:00
parent 160b7dbfc0
commit bcc2fe7368
2 changed files with 18 additions and 2 deletions

View File

@@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(1e3) MAX_EPOCHS = int(1e3)
PRETRAIN_EPOCHS = int(10) PRETRAIN_EPOCHS = int(10)
WARMUP_EPOCHS = int(4e3) WARMUP_EPOCHS = int(4e3)
MINI_BATCH_SIZE = 100 MINI_BATCH_SIZE = 20
VALIDATION_STEPS = 5 VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4 PATIENCE = 4
@@ -124,8 +124,14 @@ while current_epoch < MAX_EPOCHS:
encoder_batch_losses = [] encoder_batch_losses = []
decoder_batch_losses = [] decoder_batch_losses = []
batch_counter = 0
print(f"EPOCH {current_epoch} STARTING")
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE): for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
batch_counter += 1
src_x, tgt_y, pad_x, pad_y, tasktype = batch src_x, tgt_y, pad_x, pad_y, tasktype = batch
enc_x = torch.tensor(src_x) enc_x = torch.tensor(src_x)
@@ -137,10 +143,16 @@ while current_epoch < MAX_EPOCHS:
tgt = torch.tensor(tgt_y) tgt = torch.tensor(tgt_y)
tgt_pad = torch.tensor(pad_y, dtype=torch.bool) tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
print(f"\tBATCH {batch_counter} Starting")
# Task 1 and Task 2 # Task 1 and Task 2
if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF: if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")
BATCH_LOSS = [] BATCH_LOSS = []
for token_idx in range(0, SENTENCE_LENGTH): for token_idx in range(0, SENTENCE_LENGTH):
nano_optim.zero_grad() nano_optim.zero_grad()
@@ -173,6 +185,8 @@ while current_epoch < MAX_EPOCHS:
# Task 3 # Task 3
if tasktype == Batch.TaskType.MASKING: if tasktype == Batch.TaskType.MASKING:
print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
encoder_only_optim.zero_grad() encoder_only_optim.zero_grad()
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad)) pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
@@ -190,6 +204,8 @@ while current_epoch < MAX_EPOCHS:
# Task 4 # Task 4
if tasktype == Batch.TaskType.COMPLETATION: if tasktype == Batch.TaskType.COMPLETATION:
print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
BATCH_LOSS = [] BATCH_LOSS = []
for token_idx in range(0, SENTENCE_LENGTH): for token_idx in range(0, SENTENCE_LENGTH):

View File

@@ -24,7 +24,7 @@ class NanoSocraDecoder(torch.nn.Module):
decoder_tensor = self.__decoder_embedder(decoder_embedder_input) decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
decoder_output, _, _, _, _, _ = self.__decoder( decoder_output, _, _, _, _, _ = self.__decoder(
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False) (decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
) )
logits: torch.Tensor = self.__detokener(decoder_output) logits: torch.Tensor = self.__detokener(decoder_output)