"""NanoSocrates training script: seeds, device, paths, tokenizer and hyper-parameters."""

import random
import sys
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch

# --- Reproducibility: seed both torch and the stdlib RNG ---------------------
torch.manual_seed(0)
random.seed(0)

# --- Device: every tensor created afterwards defaults to DEVICE --------------
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# --- Filesystem locations ----------------------------------------------------
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")

# --- Tokenizer ---------------------------------------------------------------
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# --- Model and training hyper-parameters -------------------------------------
# Extra ids appended after the tokenizer vocabulary; presumably reserved for
# the masker's sentinel tokens -- TODO confirm against SpannedMasker.
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = 1_000
PRETRAIN_EPOCHS = 10
WARMUP_EPOCHS = 4_000
MINI_BATCH_SIZE = 100
VALIDATION_STEPS = 5
# A checkpoint is written once every four validation rounds.
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0

# --- Special-token ids -------------------------------------------------------
# NOTE(review): every literal handed to encode() below is the empty string in
# this copy of the file, so all six ids come from the same lookup. Presumably
# each call originally carried its own special-token string (start, padding,
# end, subject, relation, object) -- restore them before training.
SOS_TOKEN = TOKENANO.encode("")[0]
PAD_TOKEN = TOKENANO.encode("")[0]
END_TOKEN = TOKENANO.encode("")[0]
SUBJ_TOKEN = TOKENANO.encode("")[0]
REL_TOKEN = TOKENANO.encode("")[0]
OBJ_TOKEN = TOKENANO.encode("")[0]

# Ids of all default special tokens, obtained by encoding their concatenation.
SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
# Special tokens handed to the masker as forbidden: everything except the
# three triple markers.
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS

# --- Spanned masker ----------------------------------------------------------
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)
# Batchers: one per dataset split, all sharing tokenizer, length and masker.
TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)

# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)

# Training constants: one loss / optimizer / scheduler per model view.
# Targets equal to PAD_TOKEN are ignored by every loss.
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters())
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters())
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters())
nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)

current_epoch = CURRENT_EPOCH
patience = 0

# Best validation loss seen so far for each task family.
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}


def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of ``values``, or ``inf`` when it is empty."""
    return sum(values) / len(values) if values else float("inf")


def _batch_to_tensors(batch):
    """Convert one batcher item into the tensors consumed by the models.

    Returns ``(enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype)``.
    """
    # The target padding mask is not needed: the losses already ignore
    # PAD_TOKEN targets through ``ignore_index``.
    src_x, tgt_y, pad_x, _pad_y, tasktype = batch
    enc_x = torch.tensor(src_x)
    enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
    # Read only the leading dimension so training and validation agree on how
    # the batch size is derived.
    dec_x = Transformer.get_decoder_input(
        enc_x.shape[0], SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
    )
    # NOTE(review): this mask is computed once from the initial decoder input
    # and is not refreshed while dec_x is filled during teacher forcing --
    # confirm this is what the model expects.
    dec_x_pad = dec_x.eq(PAD_TOKEN)
    tgt = torch.tensor(tgt_y)
    return enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype


def _teacher_forced_losses(forward, criterion, dec_x, tgt, optimizer=None):
    """Score each target position with teacher forcing; return the losses.

    ``forward`` is a zero-argument callable returning logits that are indexed
    as ``[:, position, :]``. After every position the gold token is written
    into ``dec_x`` (in place) as the next decoder input. When ``optimizer`` is
    given, each scored position also runs a backward pass and an optimizer
    step; otherwise the function only evaluates.
    """
    losses = []
    for token_idx in range(SENTENCE_LENGTH):
        step_target = tgt[:, token_idx]
        # A position whose targets are all padding has no supervised element:
        # the mean-reduced loss would be NaN and poison every statistic
        # built from it, so it is skipped.
        if not bool(step_target.eq(PAD_TOKEN).all()):
            if optimizer is not None:
                optimizer.zero_grad()
            step_logits = forward()[:, token_idx, :]
            loss: torch.Tensor = criterion(step_logits, step_target)
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
        if token_idx < SENTENCE_LENGTH - 1:
            dec_x[:, token_idx + 1] = step_target
    return losses


while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    # Per-batch [min, avg, max] rows for the sequence tasks, plain floats for
    # the masking task.
    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    print(f"EPOCH {current_epoch} STARTING")

    for batch_counter, batch in enumerate(
        TRAIN_BATCHER.batch(MINI_BATCH_SIZE), start=1
    ):
        enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype = _batch_to_tensors(batch)

        print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2
        if tasktype in (Batch.TaskType.RDF2TXT, Batch.TaskType.TEXT2RDF):
            print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            BATCH_LOSS = _teacher_forced_losses(
                lambda: NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad)),
                nano_cross_entropy,
                dec_x,
                tgt,
                nano_optim,
            )
            if BATCH_LOSS:
                # Average over the positions actually scored, not over the
                # mini-batch size.
                text_batch_losses.append(
                    [min(BATCH_LOSS), _mean(BATCH_LOSS), max(BATCH_LOSS)]
                )
            continue

        # Pretrain first
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3
        if tasktype == Batch.TaskType.MASKING:
            print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()
            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            # Move the class axis to position 1, where CrossEntropyLoss
            # expects it for sequence targets.
            pred_logits = pred_logits.permute(0, 2, 1)
            # Debug aid kept from the original script: largest target id.
            print(torch.max(tgt))
            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())
            continue

        # Task 4
        if tasktype == Batch.TaskType.COMPLETATION:
            print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            # NOTE(review): the decoder-only model is fed enc_x/enc_x_pad even
            # though dec_x is the tensor being teacher-forced -- confirm
            # whether dec_x/dec_x_pad was intended here.
            BATCH_LOSS = _teacher_forced_losses(
                lambda: DECODER_ONLY((enc_x, enc_x_pad)),
                decoder_ce,
                dec_x,
                tgt,
                decoder_only_optim,
            )
            if BATCH_LOSS:
                decoder_batch_losses.append(
                    [min(BATCH_LOSS), _mean(BATCH_LOSS), max(BATCH_LOSS)]
                )
            continue

    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    # Validation, reporting and checkpointing only happen every
    # VALIDATION_STEPS epochs (CHECKPOINT_STEPS is a multiple of it).
    if current_epoch % VALIDATION_STEPS != 0:
        continue

    NANOSOCRATES.eval()
    ENCODER_ONLY.eval()
    DECODER_ONLY.eval()

    txt_avg_batch_losses = []
    enc_avg_batch_losses = []
    dec_avg_batch_losses = []

    # No gradients are needed for validation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
            enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype = _batch_to_tensors(
                batch
            )

            # Task 1 and Task 2
            if tasktype in (Batch.TaskType.RDF2TXT, Batch.TaskType.TEXT2RDF):
                BATCH_LOSS = _teacher_forced_losses(
                    lambda: NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad)),
                    nano_cross_entropy,
                    dec_x,
                    tgt,
                )
                if BATCH_LOSS:
                    txt_avg_batch_losses.append(_mean(BATCH_LOSS))
                continue

            # Pretrain first
            if current_epoch < PRETRAIN_EPOCHS:
                continue

            # Task 3
            if tasktype == Batch.TaskType.MASKING:
                pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                pred_logits = pred_logits.permute(0, 2, 1)
                loss: torch.Tensor = encoder_ce(pred_logits, tgt)
                enc_avg_batch_losses.append(loss.item())
                continue

            # Task 4
            if tasktype == Batch.TaskType.COMPLETATION:
                # NOTE(review): same input question as in the training loop.
                BATCH_LOSS = _teacher_forced_losses(
                    lambda: DECODER_ONLY((enc_x, enc_x_pad)),
                    decoder_ce,
                    dec_x,
                    tgt,
                )
                if BATCH_LOSS:
                    dec_avg_batch_losses.append(_mean(BATCH_LOSS))
                continue

    # Validation averages; a task with no scored batch reports ``inf``
    # (always the case for tasks 3 and 4 while pretraining).
    txt_avg_loss = _mean(txt_avg_batch_losses)
    enc_avg_loss = _mean(enc_avg_batch_losses)
    dec_avg_loss = _mean(dec_avg_batch_losses)

    if current_epoch < PRETRAIN_EPOCHS:
        if txt_avg_loss < average_loss_validation["txt"]:
            average_loss_validation["txt"] = txt_avg_loss
        else:
            patience += 1
    else:
        # Compare every task with its own best and remember improvements;
        # patience grows when at least two of the three tasks got worse.
        counter = 0
        for task_key, task_loss in (
            ("txt", txt_avg_loss),
            ("encoder_only", enc_avg_loss),
            ("decoder_only", dec_avg_loss),
        ):
            if task_loss > average_loss_validation[task_key]:
                counter += 1
            elif task_loss < average_loss_validation[task_key]:
                average_loss_validation[task_key] = task_loss

        if counter > 1:
            patience += 1

    # Training statistics use their own names so they cannot overwrite the
    # validation values reported below.
    txt_min_train_losses = [row[0] for row in text_batch_losses]
    txt_avg_train_losses = [row[1] for row in text_batch_losses]
    txt_max_train_losses = [row[2] for row in text_batch_losses]

    txt_min_loss = min(txt_min_train_losses, default=float("inf"))
    txt_avg_min_loss = _mean(txt_min_train_losses)
    txt_max_loss = max(txt_max_train_losses, default=float("inf"))
    txt_avg_max_loss = _mean(txt_max_train_losses)
    txt_avg_train_loss = _mean(txt_avg_train_losses)

    # The task 3/4 lists are empty while pretraining (and for the epoch that
    # ends pretraining), in which case every figure below is ``inf``.
    enc_avg_train_loss = _mean(encoder_batch_losses)

    dec_min_train_losses = [row[0] for row in decoder_batch_losses]
    dec_avg_train_losses = [row[1] for row in decoder_batch_losses]
    dec_max_train_losses = [row[2] for row in decoder_batch_losses]

    dec_min_loss = min(dec_min_train_losses, default=float("inf"))
    dec_avg_min_loss = _mean(dec_min_train_losses)
    dec_max_loss = max(dec_max_train_losses, default=float("inf"))
    dec_avg_max_loss = _mean(dec_max_train_losses)
    dec_avg_train_loss = _mean(dec_avg_train_losses)

    SEPARATOR = "================================================================================================================"
    DEBUG_TEXT = "".join(
        [
            f"{SEPARATOR}\n",
            f"EPOCH {current_epoch}\n",
            f"{SEPARATOR}\n",
            f"Train Losses:\n",
            f"\tMin Losses:\n",
            f"\t\tmin_txt: {txt_min_loss} - avg_txt: {txt_avg_min_loss}\n",
            f"\t\tmin_dec: {dec_min_loss} - avg_dec: {dec_avg_min_loss}\n",
            f"\tMax Losses:\n",
            f"\t\tmax_txt: {txt_max_loss} - avg_txt: {txt_avg_max_loss}\n",
            f"\t\tmax_dec: {dec_max_loss} - avg_dec: {dec_avg_max_loss}\n",
            f"\tAvg Losses:\n",
            f"\t\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
            f"{SEPARATOR}\n",
            f"Validation Losses:\n",
            f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
            f"{SEPARATOR}\n",
        ]
    )
    print(DEBUG_TEXT)

    # Warn about patience
    stop_training = patience >= PATIENCE
    if stop_training:
        print("Model is likely overfitting, so let's stop here")

    # SAVE MODEL
    if current_epoch % CHECKPOINT_STEPS == 0 or stop_training:
        print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
        # Make sure the target directory exists before writing.
        CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)

    # Actually stop once patience is exhausted, as the message announces.
    if stop_training:
        break