Fixed training
parent d8e65bfb8a
commit 56fbadd55e
365
Playgrounds/nanosocrates-train-experiment-2.py
Normal file
@@ -0,0 +1,365 @@
import random
import sys
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch


# set a fixed seed
torch.manual_seed(0)
random.seed(0)


# set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)


# Get paths
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")


# BPE Init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)


# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
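# NOTE (assumption): the extra slots reserve ids above the real vocabulary for the
# sentinel tokens that SpannedMasker substitutes into masked spans (T5-style span
# corruption), so only the model's token space is enlarged, never the tokenizer's.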
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(1e3)
PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(4e3)
MINI_BATCH_SIZE = 100
VALIDATION_STEPS = 50
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0

SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<PRED>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = set([SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN])
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS
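# NOTE (assumption): the remaining special tokens are protected from masking, while
# the structural triple markers <SUBJ>, <PRED> and <OBJ> stay maskable, presumably
# so the model also learns to recover the structure of a corrupted triple.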

# Spanned_Masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)


# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)
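# NOTE (assumption): decompose_nano_socrates is taken to return encoder-only and
# decoder-only views that share weights with the full model, so the masking and
# completion tasks below also update NANOSOCRATES itself.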


# Training constants
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
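
# NOTE (assumption): the optimizers use lr=1 because Transformer.WarmupLR is expected
# to supply the absolute learning rate, as in the "Attention Is All You Need" schedule:
#
#     lr(step) = EMBEDDED_SIZE ** -0.5 * min(step ** -0.5, step * WARMUP_EPOCHS ** -1.5)
#
# If the Libs implementation differs, the factor of 1 simply scales whatever it returns.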

current_epoch = CURRENT_EPOCH
patience = 0


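# Best (lowest) validation loss recorded so far for each task family; the validation
# step compares fresh averages against these to decide when the model has stopped improving.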
average_loss_validation = {
|
||||
"txt": float("inf"),
|
||||
"encoder_only": float("inf"),
|
||||
"decoder_only": float("inf"),
|
||||
}

# Last computed validation averages, initialised so the per-epoch report below
# is defined even before the first validation pass runs.
txt_avg_loss = float("inf")
enc_avg_loss = float("inf")
dec_avg_loss = float("inf")

while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)

        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
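        # Teacher forcing: the decoder input is the target shifted one position to the
        # right behind the <SOS> token, so position t is predicted from targets up to t - 1.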
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)

        print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            nano_optim.zero_grad()

            pred_logits: torch.Tensor = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)
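            # CrossEntropyLoss wants the class dimension second: logits as
            # (batch, vocab, sequence) against integer targets (batch, sequence),
            # hence the permute from (batch, sequence, vocab).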

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)

            loss.backward()
            nano_optim.step()

            text_batch_losses.append(loss.item())
            continue

        # Pretrain first: masking and completion batches are skipped until the
        # text tasks have had PRETRAIN_EPOCHS epochs on their own
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3
        if tasktype == Batch.TaskType.MASKING:

            print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)
            print(torch.max(tgt))  # debug: largest target token id in the batch
            loss: torch.Tensor = encoder_ce(pred_logits, tgt)

            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())

            continue

        # Task 4
        if tasktype == Batch.TaskType.COMPLETATION:

            print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            decoder_only_optim.zero_grad()

            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = decoder_ce(pred_logits, tgt)

            loss.backward()
            decoder_only_optim.step()

            decoder_batch_losses.append(loss.item())

            continue

    # One scheduler step per epoch
    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

        txt_avg_batch_losses = []
        enc_avg_batch_losses = []
        dec_avg_batch_losses = []

        # No gradients are needed while validating
        with torch.no_grad():

            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

                src_x, tgt_y, pad_x, pad_y, tasktype = batch

                enc_x = torch.tensor(src_x)

                ACTUAL_BATCH_SIZE, _ = enc_x.shape
                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
                tgt = torch.tensor(tgt_y)
                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

                dec_x = Transformer.get_decoder_input(
                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
                )
                dec_x[:, 1:] = tgt[:, :-1]
                dec_x_pad = dec_x.eq(PAD_TOKEN)

                # Task 1 and Task 2
                if (
                    tasktype == Batch.TaskType.RDF2TXT
                    or tasktype == Batch.TaskType.TEXT2RDF
                ):

                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)

                    txt_avg_batch_losses.append(loss.item())

                    continue

                # Pretrain first
                if current_epoch <= PRETRAIN_EPOCHS:
                    continue

                # Task 3
                if tasktype == Batch.TaskType.MASKING:

                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)

                    enc_avg_batch_losses.append(loss.item())

                    continue

                # Task 4
                if tasktype == Batch.TaskType.COMPLETATION:

                    pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)

                    dec_avg_batch_losses.append(loss.item())

                    continue

        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
        enc_avg_loss = float("inf")
        dec_avg_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

        if current_epoch < PRETRAIN_EPOCHS:

            if txt_avg_loss < average_loss_validation["txt"]:
                average_loss_validation["txt"] = txt_avg_loss
            else:
                patience += 1

        else:

            # Count how many tasks regressed against their best validation loss so far,
            # keeping the records of the ones that improved
            counter = 0

            if txt_avg_loss > average_loss_validation["txt"]:
                counter += 1
            else:
                average_loss_validation["txt"] = txt_avg_loss

            if enc_avg_loss > average_loss_validation["encoder_only"]:
                counter += 1
            else:
                average_loss_validation["encoder_only"] = enc_avg_loss

            if dec_avg_loss > average_loss_validation["decoder_only"]:
                counter += 1
            else:
                average_loss_validation["decoder_only"] = dec_avg_loss

            if counter > 1:
                patience += 1

    # Per-epoch training averages
    txt_avg_train_loss = sum(text_batch_losses) / len(text_batch_losses)

    enc_avg_train_loss = float("inf")
    dec_avg_train_loss = float("inf")

    if current_epoch > PRETRAIN_EPOCHS:
        try:
            enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
            dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
        except ZeroDivisionError:
            pass

    SEPARATOR = "================================================================================================================"
    DEBUG_TEXT = "".join(
        [
            f"{SEPARATOR}\n",
            f"EPOCH {current_epoch}\n",
            f"{SEPARATOR}\n",
            "Train Losses:\n",
            "\tAvg Losses:\n",
            f"\t\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
            f"{SEPARATOR}\n",
            "Validation Losses:\n",
            f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
            f"{SEPARATOR}\n",
        ]
    )

    print(DEBUG_TEXT)

    # Warn about patience
    if patience == PATIENCE:
        print("Model is likely overfitting, so let's stop here")

    # SAVE MODEL
    if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
        print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
        torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
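        # A checkpoint meant for resuming training would also need the optimizer,
        # scheduler and epoch state, e.g. (sketch, not wired to any loading code):
        #
        #     torch.save(
        #         {
        #             "epoch": current_epoch,
        #             "model": NANOSOCRATES.state_dict(),
        #             "nano_optim": nano_optim.state_dict(),
        #         },
        #         CHECKPOINT_PATH,
        #     )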

    # Stop once patience is exhausted, as announced above
    if patience == PATIENCE:
        break