Fixed a memory bug
parent 46ee6055ec
commit 71d602e36e
@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
 PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
+LEARNING_RATE = 1.5

 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]

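The constants above are only declared in this hunk; the training loop that consumes them sits outside the diff. As a rough, hypothetical illustration of how a PATIENCE-style check usually consumes the validation losses that a script like this produces every VALIDATION_STEPS epochs (nothing below beyond the constant value comes from this repository):

```python
# Hypothetical sketch, not code from this repository.
PATIENCE = 4

def should_stop(val_losses: list[float], patience: int = PATIENCE) -> bool:
    """Stop once no new best loss has appeared in the last `patience` validations."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) >= best_before

history = [2.1, 1.7, 1.5, 1.52, 1.55, 1.51, 1.56]
print(should_stop(history))  # True: the last 4 validations never beat 1.5
```

With CHECKPOINT_STEPS = VALIDATION_STEPS * 4, a checkpoint would then coincide with every fourth validation pass.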
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
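Transformer.WarmupLR is project code that this diff does not show. Given its (optimizer, warmup_steps, d_model) arguments, it presumably implements the inverse-square-root warmup schedule from the original Transformer paper, where the optimizer's base learning rate (now LEARNING_RATE = 1.5 instead of the hard-coded 1) acts as a plain multiplier. A minimal sketch of that assumed behaviour, built on PyTorch's stock LambdaLR; the Linear model and the 512 stand in for NANOSOCRATES and EMBEDDED_SIZE:

```python
import torch

def noam_lambda(warmup_steps: int, d_model: int):
    """Factor applied on top of the optimizer's base lr (assumed schedule)."""
    def factor(step: int) -> float:
        step = max(step, 1)  # LambdaLR starts counting at 0; avoid 0 ** -0.5
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return factor

model = torch.nn.Linear(8, 8)                       # stand-in for NANOSOCRATES
optim = torch.optim.AdamW(model.parameters(), 1.5)  # base lr = LEARNING_RATE
sched = torch.optim.lr_scheduler.LambdaLR(optim, noam_lambda(1000, 512))

for _ in range(5):
    optim.step()   # gradients omitted; only the schedule is of interest here
    sched.step()
print(optim.param_groups[0]["lr"])  # 1.5 * warmup factor after 5 steps
```

Under that reading, raising the base rate to 1.5 scales the whole curve by 1.5, while cutting WARMUP_EPOCHS from 4e3 to 1e3 moves the peak earlier and makes it higher.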
@@ -228,77 +229,78 @@ while current_epoch < MAX_EPOCHS:
     ENCODER_ONLY.eval()
     DECODER_ONLY.eval()

-    txt_avg_batch_losses = []
-    enc_avg_batch_losses = []
-    dec_avg_batch_losses = []
+    with torch.no_grad():
+        txt_avg_batch_losses = []
+        enc_avg_batch_losses = []
+        dec_avg_batch_losses = []

-    for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
+        for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

-        src_x, tgt_y, pad_x, pad_y, tasktype = batch
+            src_x, tgt_y, pad_x, pad_y, tasktype = batch

-        enc_x = torch.tensor(src_x)
+            enc_x = torch.tensor(src_x)

-        ACTUAL_BATCH_SIZE, _ = enc_x.shape
-        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
-        tgt = torch.tensor(tgt_y)
-        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+            ACTUAL_BATCH_SIZE, _ = enc_x.shape
+            enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+            tgt = torch.tensor(tgt_y)
+            tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

-        dec_x = Transformer.get_decoder_input(
-            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
-        )
-        dec_x[:, 1:] = tgt[:, :-1]
-        dec_x_pad = dec_x.eq(PAD_TOKEN)
+            dec_x = Transformer.get_decoder_input(
+                ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+            )
+            dec_x[:, 1:] = tgt[:, :-1]
+            dec_x_pad = dec_x.eq(PAD_TOKEN)

-        # Task 1 and Task 2
-        if (
-            tasktype == Batch.TaskType.RDF2TXT
-            or tasktype == Batch.TaskType.TEXT2RDF
-        ):
+            # Task 1 and Task 2
+            if (
+                tasktype == Batch.TaskType.RDF2TXT
+                or tasktype == Batch.TaskType.TEXT2RDF
+            ):

-            pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
+                pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))

-            pred_logits = pred_logits.permute(0, 2, 1)
+                pred_logits = pred_logits.permute(0, 2, 1)

-            loss: torch.Tensor = nano_cross_entropy(
-                pred_logits, tgt
-            )
+                loss: torch.Tensor = nano_cross_entropy(
+                    pred_logits, tgt
+                )

-            txt_avg_batch_losses.append(loss)
+                txt_avg_batch_losses.append(loss)

-            continue
+                continue

-        # Pretrain first
-        if current_epoch <= PRETRAIN_EPOCHS:
-            continue
+            # Pretrain first
+            if current_epoch <= PRETRAIN_EPOCHS:
+                continue

-        # Task 3
-        if tasktype == Batch.TaskType.MASKING:
+            # Task 3
+            if tasktype == Batch.TaskType.MASKING:

-            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
-            pred_logits = pred_logits.permute(0, 2, 1)
+                pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+                pred_logits = pred_logits.permute(0, 2, 1)

-            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+                loss: torch.Tensor = encoder_ce(pred_logits, tgt)

-            enc_avg_batch_losses.append(loss.item())
+                enc_avg_batch_losses.append(loss.item())

-            continue
+                continue

-        # Task 4
-        if tasktype == Batch.TaskType.COMPLETATION:
+            # Task 4
+            if tasktype == Batch.TaskType.COMPLETATION:

-            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+                pred_logits = DECODER_ONLY((enc_x, enc_x_pad))

-            pred_logits = pred_logits.permute(0, 2, 1)
+                pred_logits = pred_logits.permute(0, 2, 1)

-            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+                loss: torch.Tensor = decoder_ce(pred_logits, tgt)

-            dec_avg_batch_losses.append(loss)
+                dec_avg_batch_losses.append(loss)

-            continue
+                continue

     txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
     enc_avg_loss = float("inf")
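The substantive change in this last hunk is the re-indentation of the whole validation pass under with torch.no_grad():, which is presumably the memory bug named in the commit title: without it every validation forward pass builds an autograd graph, and each graph stays alive for as long as the appended loss tensor does. A minimal, self-contained illustration of the leaky pattern and the fixed one (the Linear model, random batches and MSE loss are stand-ins, not this project's models or batcher):

```python
import torch

model = torch.nn.Linear(16, 16)                   # stand-in for NANOSOCRATES
batches = [torch.randn(4, 16) for _ in range(8)]  # stand-in validation batches
criterion = torch.nn.MSELoss()

# Leaky pattern: each stored loss still references its autograd graph,
# so the activations of every validation batch pile up in memory.
losses = [criterion(model(x), x) for x in batches]

# Fixed pattern, as in this commit: under no_grad() no graph is built at all.
losses = []
with torch.no_grad():
    for x in batches:
        losses.append(criterion(model(x), x))
avg_loss = sum(losses) / len(losses)
print(float(avg_loss))
```

Note that the diff still appends raw loss tensors for the text and decoder tasks while the encoder branch appends loss.item(); once everything runs under no_grad() both are safe, though calling .item() (or .detach()) everywhere would keep the lists as plain floats.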