dev.train #8

Merged
gape_01 merged 50 commits from dev.train into dev 2025-10-17 22:20:14 +02:00
Showing only changes of commit f51ada866f - Show all commits

View File

@ -51,6 +51,7 @@ VALIDATION_STEPS = 50
CHECKPOINT_STEPS = VALIDATION_STEPS * 4 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4 PATIENCE = 4
CURRENT_EPOCH = 0 CURRENT_EPOCH = 0
VERBOSE = False
SOS_TOKEN = TOKENANO.encode("<SOS>")[0] SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
@ -125,7 +126,8 @@ while current_epoch < MAX_EPOCHS:
batch_counter = 0 batch_counter = 0
print(f"EPOCH {current_epoch} STARTING") if VERBOSE:
print(f"EPOCH {current_epoch} STARTING")
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE): for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
@ -146,7 +148,8 @@ while current_epoch < MAX_EPOCHS:
dec_x[:, 1:] = tgt[:, :-1] dec_x[:, 1:] = tgt[:, :-1]
dec_x_pad = dec_x.eq(PAD_TOKEN) dec_x_pad = dec_x.eq(PAD_TOKEN)
print(f"\tBATCH {batch_counter} Starting") if VERBOSE:
print(f"\tBATCH {batch_counter} Starting")
# Task 1 and Task 2 # Task 1 and Task 2
if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF: if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
@ -173,7 +176,8 @@ while current_epoch < MAX_EPOCHS:
# Task 3 # Task 3
if tasktype == Batch.TaskType.MASKING: if tasktype == Batch.TaskType.MASKING:
print(f"\tExecuting TASK 3 - BATCH {batch_counter}") if VERBOSE:
print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
encoder_only_optim.zero_grad() encoder_only_optim.zero_grad()
@ -192,7 +196,8 @@ while current_epoch < MAX_EPOCHS:
# Task 4 # Task 4
if tasktype == Batch.TaskType.COMPLETATION: if tasktype == Batch.TaskType.COMPLETATION:
print(f"\tExecuting TASK 4 - BATCH {batch_counter}") if VERBOSE:
print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
decoder_only_optim.zero_grad() decoder_only_optim.zero_grad()