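Fixed bugs and added visibility: set the final boolean flag passed to the decoder to True, reduced MINI_BATCH_SIZE from 100 to 20, and added progress prints for each epoch, batch, and task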
This commit is contained in:
parent
160b7dbfc0
commit
bcc2fe7368
@ -46,7 +46,7 @@ NUMBER_OF_BLOCKS = 4
|
|||||||
MAX_EPOCHS = int(1e3)
|
MAX_EPOCHS = int(1e3)
|
||||||
PRETRAIN_EPOCHS = int(10)
|
PRETRAIN_EPOCHS = int(10)
|
||||||
WARMUP_EPOCHS = int(4e3)
|
WARMUP_EPOCHS = int(4e3)
|
||||||
MINI_BATCH_SIZE = 100
|
MINI_BATCH_SIZE = 20
|
||||||
VALIDATION_STEPS = 5
|
VALIDATION_STEPS = 5
|
||||||
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
|
||||||
PATIENCE = 4
|
PATIENCE = 4
|
||||||
@ -124,8 +124,14 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
encoder_batch_losses = []
|
encoder_batch_losses = []
|
||||||
decoder_batch_losses = []
|
decoder_batch_losses = []
|
||||||
|
|
||||||
|
batch_counter = 0
|
||||||
|
|
||||||
|
print(f"EPOCH {current_epoch} STARTING")
|
||||||
|
|
||||||
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
|
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
|
||||||
|
|
||||||
|
batch_counter += 1
|
||||||
|
|
||||||
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
src_x, tgt_y, pad_x, pad_y, tasktype = batch
|
||||||
|
|
||||||
enc_x = torch.tensor(src_x)
|
enc_x = torch.tensor(src_x)
|
||||||
@ -137,10 +143,16 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
tgt = torch.tensor(tgt_y)
|
tgt = torch.tensor(tgt_y)
|
||||||
tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
|
tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
|
||||||
|
|
||||||
|
print(f"\tBATCH {batch_counter} Starting")
|
||||||
|
|
||||||
# Task 1 and Task 2
|
# Task 1 and Task 2
|
||||||
if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
|
if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
|
||||||
|
|
||||||
|
print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")
|
||||||
|
|
||||||
BATCH_LOSS = []
|
BATCH_LOSS = []
|
||||||
|
|
||||||
|
|
||||||
for token_idx in range(0, SENTENCE_LENGTH):
|
for token_idx in range(0, SENTENCE_LENGTH):
|
||||||
|
|
||||||
nano_optim.zero_grad()
|
nano_optim.zero_grad()
|
||||||
@ -173,6 +185,8 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
# Task 3
|
# Task 3
|
||||||
if tasktype == Batch.TaskType.MASKING:
|
if tasktype == Batch.TaskType.MASKING:
|
||||||
|
|
||||||
|
print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
|
||||||
|
|
||||||
encoder_only_optim.zero_grad()
|
encoder_only_optim.zero_grad()
|
||||||
|
|
||||||
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
|
pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
|
||||||
@ -190,6 +204,8 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
# Task 4
|
# Task 4
|
||||||
if tasktype == Batch.TaskType.COMPLETATION:
|
if tasktype == Batch.TaskType.COMPLETATION:
|
||||||
|
|
||||||
|
print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
|
||||||
|
|
||||||
BATCH_LOSS = []
|
BATCH_LOSS = []
|
||||||
|
|
||||||
for token_idx in range(0, SENTENCE_LENGTH):
|
for token_idx in range(0, SENTENCE_LENGTH):
|
||||||
|
|||||||
@ -24,7 +24,7 @@ class NanoSocraDecoder(torch.nn.Module):
|
|||||||
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
||||||
|
|
||||||
decoder_output, _, _, _, _, _ = self.__decoder(
|
decoder_output, _, _, _, _, _ = self.__decoder(
|
||||||
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False)
|
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, True)
|
||||||
)
|
)
|
||||||
|
|
||||||
logits: torch.Tensor = self.__detokener(decoder_output)
|
logits: torch.Tensor = self.__detokener(decoder_output)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user