Fixed a memory bug
commit 71d602e36e (parent 46ee6055ec)
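The change below appears to be the memory fix named in the commit message: the per-epoch validation pass is wrapped in torch.no_grad(). A minimal sketch of the failure mode, using a stand-in model and data (the names here are illustrative, not taken from the repository):

import torch

model = torch.nn.Linear(16, 16)                  # stand-in for NANOSOCRATES
batches = [torch.randn(8, 16) for _ in range(100)]

# Without no_grad(), every stored loss keeps its autograd graph alive,
# so memory grows with the number of validation batches.
kept = []
for x in batches:
    kept.append(model(x).pow(2).mean())          # graph retained per element

# Under torch.no_grad() -- the approach this commit takes -- no graph is built,
# so the stored losses are plain scalar tensors and memory stays flat.
kept = []
with torch.no_grad():
    for x in batches:
        kept.append(model(x).pow(2).mean())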
@@ -43,15 +43,16 @@ FEED_FORWARD_MULTIPLIER = 4
 ATTENTION_HEADS = 8
 SENTENCE_LENGTH = 256
 NUMBER_OF_BLOCKS = 4
-MAX_EPOCHS = int(1e3)
+MAX_EPOCHS = int(3e3)
 PRETRAIN_EPOCHS = int(300)
-WARMUP_EPOCHS = int(4e3)
+WARMUP_EPOCHS = int(1e3)
 MINI_BATCH_SIZE = 300
 VALIDATION_STEPS = 25
 CHECKPOINT_STEPS = VALIDATION_STEPS * 4
 PATIENCE = 4
 CURRENT_EPOCH = 0
 VERBOSE = False
+LEARNING_RATE = 1.5

 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]

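For orientation, the epoch constants interact with the validation loop later in this diff: masking and completion batches are skipped while current_epoch <= PRETRAIN_EPOCHS. A small illustrative sketch of the phase split under the new MAX_EPOCHS (the constant values are copied from this hunk; the loop itself is hypothetical):

MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)

phases = {"pretrain": 0, "multitask": 0}
for current_epoch in range(MAX_EPOCHS):
    if current_epoch <= PRETRAIN_EPOCHS:
        phases["pretrain"] += 1    # only the RDF2TXT / TEXT2RDF tasks are scored
    else:
        phases["multitask"] += 1   # masking and completion tasks join the pass

print(phases)  # {'pretrain': 301, 'multitask': 2699}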
@@ -93,9 +94,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
 nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
 decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
-nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
-encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
-decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
 nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
 encoder_only_scheduler = Transformer.WarmupLR(
     encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
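The optimizers now take the new LEARNING_RATE constant instead of a hard-coded 1. The diff does not show Transformer.WarmupLR itself; assuming it follows the standard Transformer warmup schedule (a factor of d_model^-0.5 * min(step^-0.5, step * warmup^-1.5) multiplied by the optimizer's base lr), an equivalent can be sketched with PyTorch's LambdaLR. The warmup and model-size values below are placeholders, not read from the repository:

import torch

def noam_lambda(warmup_steps: int, d_model: int):
    # Classic Transformer schedule; the optimizer's base lr (LEARNING_RATE)
    # multiplies the returned factor at every step.
    def factor(step: int) -> float:
        step = max(step, 1)
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return factor

params = [torch.nn.Parameter(torch.zeros(1))]   # placeholder parameters
optim = torch.optim.AdamW(params, lr=1.5)       # mirrors LEARNING_RATE = 1.5
sched = torch.optim.lr_scheduler.LambdaLR(optim, noam_lambda(warmup_steps=1000, d_model=512))

for _ in range(5):
    optim.step()
    sched.step()                                # rises during warmup, then decays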
@@ -228,77 +229,78 @@ while current_epoch < MAX_EPOCHS:
     ENCODER_ONLY.eval()
     DECODER_ONLY.eval()

-    txt_avg_batch_losses = []
-    enc_avg_batch_losses = []
-    dec_avg_batch_losses = []
-
-    for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
-
-        src_x, tgt_y, pad_x, pad_y, tasktype = batch
-
-        enc_x = torch.tensor(src_x)
-
-        ACTUAL_BATCH_SIZE, _ = enc_x.shape
-        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
-        tgt = torch.tensor(tgt_y)
-        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
-
-        dec_x = Transformer.get_decoder_input(
-            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
-        )
-        dec_x[:, 1:] = tgt[:, :-1]
-        dec_x_pad = dec_x.eq(PAD_TOKEN)
-
-        # Task 1 and Task 2
-        if (
-            tasktype == Batch.TaskType.RDF2TXT
-            or tasktype == Batch.TaskType.TEXT2RDF
-        ):
-
-            pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
-
-            pred_logits = pred_logits.permute(0, 2, 1)
-
-            loss: torch.Tensor = nano_cross_entropy(
-                pred_logits, tgt
-            )
-
-            txt_avg_batch_losses.append(loss)
-
-            continue
-
-        # Pretrain first
-        if current_epoch <= PRETRAIN_EPOCHS:
-            continue
-
-        # Task 3
-        if tasktype == Batch.TaskType.MASKING:
-
-            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
-            pred_logits = pred_logits.permute(0, 2, 1)
-
-            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
-
-            enc_avg_batch_losses.append(loss.item())
-
-            continue
-
-        # Task 4
-        if tasktype == Batch.TaskType.COMPLETATION:
-
-            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
-
-            pred_logits = pred_logits.permute(0, 2, 1)
-
-            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
-
-            dec_avg_batch_losses.append(loss)
-
-            continue
+    with torch.no_grad():
+        txt_avg_batch_losses = []
+        enc_avg_batch_losses = []
+        dec_avg_batch_losses = []
+
+        for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
+
+            src_x, tgt_y, pad_x, pad_y, tasktype = batch
+
+            enc_x = torch.tensor(src_x)
+
+            ACTUAL_BATCH_SIZE, _ = enc_x.shape
+            enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+            tgt = torch.tensor(tgt_y)
+            tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+
+            dec_x = Transformer.get_decoder_input(
+                ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+            )
+            dec_x[:, 1:] = tgt[:, :-1]
+            dec_x_pad = dec_x.eq(PAD_TOKEN)
+
+            # Task 1 and Task 2
+            if (
+                tasktype == Batch.TaskType.RDF2TXT
+                or tasktype == Batch.TaskType.TEXT2RDF
+            ):
+
+                pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
+
+                pred_logits = pred_logits.permute(0, 2, 1)
+
+                loss: torch.Tensor = nano_cross_entropy(
+                    pred_logits, tgt
+                )
+
+                txt_avg_batch_losses.append(loss)
+
+                continue
+
+            # Pretrain first
+            if current_epoch <= PRETRAIN_EPOCHS:
+                continue
+
+            # Task 3
+            if tasktype == Batch.TaskType.MASKING:
+
+                pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+                pred_logits = pred_logits.permute(0, 2, 1)
+
+                loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+
+                enc_avg_batch_losses.append(loss.item())
+
+                continue
+
+            # Task 4
+            if tasktype == Batch.TaskType.COMPLETATION:
+
+                pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+
+                pred_logits = pred_logits.permute(0, 2, 1)
+
+                loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+
+                dec_avg_batch_losses.append(loss)
+
+                continue

     txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
     enc_avg_loss = float("inf")
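One follow-up note on the validation averages: the text and decoder branches append the loss tensor itself while the encoder branch appends loss.item(); both forms are safe under torch.no_grad(). A hypothetical helper (not part of the commit) that averages either form the way the code does, falling back to float("inf") when nothing was accumulated:

import torch

def mean_loss(batch_losses):
    # Accepts plain floats (from loss.item()) or zero-dim tensors collected
    # under torch.no_grad(); returns inf for an empty list.
    values = [float(l) for l in batch_losses]
    return sum(values) / len(values) if values else float("inf")

print(mean_loss([0.5, torch.tensor(1.5)]))  # 1.0
print(mean_loss([]))                        # inf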