Fixed bugs and added visibility
This commit is contained in:
parent
160b7dbfc0
commit
bcc2fe7368
@@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
|
||||
MAX_EPOCHS = int(1e3)
|
||||
PRETRAIN_EPOCHS = int(10)
|
||||
WARMUP_EPOCHS = int(4e3)
|
||||
-MINI_BATCH_SIZE = 100
|
||||
+MINI_BATCH_SIZE = 20
|
||||
VALIDATION_STEPS = 5
|
||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
||||
PATIENCE = 4
|
||||
@@ -124,8 +124,14 @@ while current_epoch < MAX_EPOCHS:
|
||||
encoder_batch_losses = []
|
||||
decoder_batch_losses = []
|
||||
|
||||
batch_counter = 0
|
||||
|
||||
print(f"EPOCH {current_epoch} STARTING")
|
||||
|
||||
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
|
||||
|
||||
batch_counter += 1
|
||||
|
||||
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
||||
|
||||
enc_x = torch.tensor(src_x)
|
||||
@@ -137,10 +143,16 @@ while current_epoch < MAX_EPOCHS:
|
||||
tgt = torch.tensor(tgt_y)
|
||||
tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
|
||||
|
||||
print(f"\tBATCH {batch_counter} Starting")
|
||||
|
||||
# Task 1 and Task 2
|
||||
if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
|
||||
|
||||
print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")
|
||||
|
||||
BATCH_LOSS = []
|
||||
|
||||
|
||||
for token_idx in range(0, SENTENCE_LENGTH):
|
||||
|
||||
nano_optim.zero_grad()
|
||||
@@ -173,6 +185,8 @@ while current_epoch < MAX_EPOCHS:
|
||||
# Task 3
|
||||
if tasktype == Batch.TaskType.MASKING:
|
||||
|
||||
print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
|
||||
|
||||
encoder_only_optim.zero_grad()
|
||||
|
||||
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
|
||||
@@ -190,6 +204,8 @@ while current_epoch < MAX_EPOCHS:
|
||||
# Task 4
|
||||
if tasktype == Batch.TaskType.COMPLETATION:
|
||||
|
||||
print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
|
||||
|
||||
BATCH_LOSS = []
|
||||
|
||||
for token_idx in range(0, SENTENCE_LENGTH):
|
||||
|
||||
@@ -24,7 +24,7 @@ class NanoSocraDecoder(torch.nn.Module):
|
||||
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
||||
|
||||
decoder_output, _, _, _, _, _ = self.__decoder(
|
||||
-(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False)
|
||||
+(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
|
||||
)
|
||||
|
||||
logits: torch.Tensor = self.__detokener(decoder_output)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user