import random
import sys

import torch
import pandas as pd
from pathlib import Path

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch

# set a fixed seed
torch.manual_seed(0)
random.seed(0)

# set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# Paths
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")

# BPE init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 300
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0
VERBOSE = True
LEARNING_RATE = 1.5

# NOTE: the special-token string literals are blank here in the source;
# the intended SOS/PAD/END/SUBJ/REL/OBJ tags must be filled in.
SOS_TOKEN = TOKENANO.encode("")[0]
PAD_TOKEN = TOKENANO.encode("")[0]
END_TOKEN = TOKENANO.encode("")[0]
SUBJ_TOKEN = TOKENANO.encode("")[0]
REL_TOKEN = TOKENANO.encode("")[0]
OBJ_TOKEN = TOKENANO.encode("")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = set([SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN])
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS

# Spanned masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

# Batchers
TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)

# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)

# Training constants
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)

nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)

current_epoch = CURRENT_EPOCH
patience = 0
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}

while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    if VERBOSE:
        print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        # Teacher forcing: decoder input is the target shifted right behind an SOS token
        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)

        if VERBOSE:
            print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2: RDF-to-text / text-to-RDF (full encoder-decoder)
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            if VERBOSE:
                print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            nano_optim.zero_grad()

            pred_logits: torch.Tensor = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
            loss.backward()
            nano_optim.step()

            text_batch_losses.append(loss.item())
            continue

        # Pretrain on Tasks 1 and 2 first
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3: span masking (encoder only)
        if tasktype == Batch.TaskType.MASKING:

            if VERBOSE:
                print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())
            continue

        # Task 4: completion (decoder only)
        if tasktype == Batch.TaskType.COMPLETATION:

            if VERBOSE:
                print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            decoder_only_optim.zero_grad()

            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
            loss.backward()
            decoder_only_optim.step()

            decoder_batch_losses.append(loss.item())
            continue

    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

        with torch.no_grad():

            txt_avg_batch_losses = []
            enc_avg_batch_losses = []
            dec_avg_batch_losses = []

            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

                src_x, tgt_y, pad_x, pad_y, tasktype = batch

                enc_x = torch.tensor(src_x)
                ACTUAL_BATCH_SIZE, _ = enc_x.shape
                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
                tgt = torch.tensor(tgt_y)
                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

                dec_x = Transformer.get_decoder_input(
                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
                )
                dec_x[:, 1:] = tgt[:, :-1]
                dec_x_pad = dec_x.eq(PAD_TOKEN)

                # Task 1 and Task 2
                if (
                    tasktype == Batch.TaskType.RDF2TXT
                    or tasktype == Batch.TaskType.TEXT2RDF
                ):
                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)
                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
                    txt_avg_batch_losses.append(loss.item())
                    continue

                # Pretrain first
                if current_epoch <= PRETRAIN_EPOCHS:
                    continue

                # Task 3
                if tasktype == Batch.TaskType.MASKING:
                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)
                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)
                    enc_avg_batch_losses.append(loss.item())
                    continue

                # Task 4
                if tasktype == Batch.TaskType.COMPLETATION:
                    pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)
                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)
                    dec_avg_batch_losses.append(loss.item())
                    continue
        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
        enc_avg_loss = float("inf")
        dec_avg_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

        # Early-stopping bookkeeping
        if current_epoch < PRETRAIN_EPOCHS:
            if txt_avg_loss < average_loss_validation["txt"]:
                average_loss_validation["txt"] = txt_avg_loss
            else:
                patience += 1
        else:
            # Count the tasks that failed to improve on their best validation loss
            counter = 0

            if txt_avg_loss > average_loss_validation["txt"]:
                counter += 1
            else:
                average_loss_validation["txt"] = txt_avg_loss

            if enc_avg_loss > average_loss_validation["encoder_only"]:
                counter += 1
            else:
                average_loss_validation["encoder_only"] = enc_avg_loss

            if dec_avg_loss > average_loss_validation["decoder_only"]:
                counter += 1
            else:
                average_loss_validation["decoder_only"] = dec_avg_loss

            if counter > 1:
                patience += 1

        txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
        enc_avg_train_loss = float("inf")
        dec_avg_train_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            try:
                enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
                dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
            except ZeroDivisionError:
                pass

        SEPARATOR = "=" * 112
        DEBUG_TEXT = "".join(
            [
                f"{SEPARATOR}\n",
                f"EPOCH {current_epoch}\n",
                f"{SEPARATOR}\n",
                "Train Losses:\n",
                "\tAvg Losses:\n",
                f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
                f"{SEPARATOR}\n",
                "Validation Losses:\n",
                f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
                f"{SEPARATOR}\n",
            ]
        )

        print(DEBUG_TEXT)

        # Warn about patience
        if patience == PATIENCE:
            print("Model is likely overfitting, so let's stop here")

        # Save model
        if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
            print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
            torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)

        # Stop training once patience has run out
        if patience == PATIENCE:
            break
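
# ----------------------------------------------------------------------------
# A minimal sketch (not executed by this script) of how the saved checkpoint
# could be reloaded later for evaluation. It assumes the same hyperparameters
# defined above; uncomment to use.
#
#   restored = Transformer.TrainingModel(
#       TOKEN_SPACE_SIZE,
#       EMBEDDED_SIZE,
#       FEED_FORWARD_MULTIPLIER,
#       ATTENTION_HEADS,
#       NUMBER_OF_BLOCKS,
#   )
#   restored.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
#   restored.eval()
# ----------------------------------------------------------------------------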