import random
import sys
import torch
import pandas as pd
from pathlib import Path

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch
from Project_Model.Libs.Training.loss_saver import Log

# Set a fixed seed for reproducibility
torch.manual_seed(0)
random.seed(0)

# Set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# Paths
CHECKPOINT_DIR = "Assets/Dataset/Tmp"
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
CHECKPOINT_PATH = Path(f"{CHECKPOINT_DIR}/NanoSocrates.zip")
NANO_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/nano_optim.zip")
ENC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/enc_optim.zip")
DEC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/dec_optim.zip")
LAST_EPOCH_PATH = Path(f"{CHECKPOINT_DIR}/last_epoch.txt")

# Loss log writer
loss_saver = Log(f"{CHECKPOINT_DIR}/log_loss.csv")

# BPE init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = 300
PRETRAIN_EPOCHS = 20
WARMUP_EPOCHS = 30
MINI_BATCH_SIZE = 20
VALIDATION_STEPS = 10
CHECKPOINT_STEPS = VALIDATION_STEPS
PATIENCE = 4
CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
VERBOSE = False
LEARNING_RATE = 0.05
LABEL_SMOOTHING = 0.01

# Special-token ids.
# NOTE: the literal spellings below are assumptions (the originals were lost);
# they must match the entries returned by BPE.default_special_tokens().
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<REL>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]
MASK_TOKEN = TOKENANO.encode("<MASK>")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS

# Spanned masker: masks spans of ~4 tokens, never touching the forbidden specials
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS, average_span=4)

# Batchers
TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)

# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

# Resume from a checkpoint if one exists
if CHECKPOINT_PATH.is_file():
    nanosocrates_dict = torch.load(CHECKPOINT_PATH, weights_only=True)
    NANOSOCRATES.load_state_dict(nanosocrates_dict)

# Encoder-only and decoder-only sub-models derived from the full model
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)

# Losses: padding positions are ignored where applicable, with a little label smoothing
nano_cross_entropy = torch.nn.CrossEntropyLoss(
    ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING
)
encoder_ce = torch.nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
decoder_ce = torch.nn.CrossEntropyLoss(
    ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING
)
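# Three objectives drive training: the full encoder-decoder handles
# RDF2TXT / TEXT2RDF, the encoder-only sub-model (from decompose_nano_socrates)
# handles span masking, and the decoder-only sub-model handles completion.
# Each gets its own AdamW optimizer and warm-up scheduler below, so all three
# can be checkpointed and resumed independently.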
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)

if NANO_OPTIM_PATH.is_file():
    optim_dict = torch.load(NANO_OPTIM_PATH)
    nano_optim.load_state_dict(optim_dict)

if ENC_OPTIM_PATH.is_file():
    optim_dict = torch.load(ENC_OPTIM_PATH)
    encoder_only_optim.load_state_dict(optim_dict)

if DEC_OPTIM_PATH.is_file():
    optim_dict = torch.load(DEC_OPTIM_PATH)
    decoder_only_optim.load_state_dict(optim_dict)

nano_scheduler = Transformer.WarmupLR(
    nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE, last_epoch=CURRENT_EPOCH
)

current_epoch = CURRENT_EPOCH + 2
patience = 0
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}

while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    if VERBOSE:
        print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        # Teacher forcing: the decoder input is the target shifted right by one,
        # starting with the start-of-sequence token.
        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)

        if VERBOSE:
            for s in TUtils.decode_batch(enc_x, TOKENANO, MASK_TOKEN):
                print("Input")
                print(s)
            for s in TUtils.decode_batch(enc_x_pad, TOKENANO, MASK_TOKEN):
                print("Encoder Padding mask")
                print(s)
            for s in TUtils.decode_batch(tgt, TOKENANO, MASK_TOKEN):
                print("Desired Output")
                print(s)

            # Clone so the debug view does not overwrite the real decoder input
            a_dx = dec_x.clone()
            a_dx[:, -1] = END_TOKEN
            for s in TUtils.decode_batch(a_dx, TOKENANO, MASK_TOKEN):
                print("Decoder Input")
                print(s)

            print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2: RDF-to-text and text-to-RDF (full encoder-decoder)
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            if VERBOSE:
                print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            nano_optim.zero_grad()

            pred_logits: torch.Tensor = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            # CrossEntropyLoss expects logits as (batch, classes, sequence)
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
            loss.backward()
            nano_optim.step()

            text_batch_losses.append(loss.item())
            continue

        # Tasks 3 and 4 only start once the pretraining phase is over
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3: span masking (encoder only)
        if tasktype == Batch.TaskType.MASKING:

            if VERBOSE:
                print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
            loss.backward()
            encoder_only_optim.step()

            if VERBOSE:
                # Show the first example of the batch, mapping sentinel span ids
                # back to the generic mask token so they can be decoded
                exp_tokens: list[int] = tgt_y[0]
                exp_tokens = [MASK_TOKEN if x > TOKENANO.vocabulary_size else x for x in exp_tokens]
                exp_string = TOKENANO.decode(exp_tokens)

                enc_tokens: list[int] = src_x[0]
                enc_tokens = [MASK_TOKEN if x > TOKENANO.vocabulary_size else x for x in enc_tokens]
                enc_string = TOKENANO.decode(enc_tokens)

                print(f"PROMPT:\n{enc_string}")
                print(f"EXPECTED:\n{exp_string}")

            encoder_batch_losses.append(loss.item())
            continue
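        # Task 4 below trains the decoder-only sub-model on completion: given
        # the right-shifted target dec_x it must predict the next token at
        # every position, reusing the same padding masks as the other tasks.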
        if tasktype == Batch.TaskType.COMPLETATION:

            if VERBOSE:
                print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            decoder_only_optim.zero_grad()

            pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
            loss.backward()
            decoder_only_optim.step()

            decoder_batch_losses.append(loss.item())
            continue

    # One scheduler step per epoch
    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    # Periodic validation
    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

        with torch.no_grad():

            txt_avg_batch_losses = []
            enc_avg_batch_losses = []
            dec_avg_batch_losses = []

            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

                src_x, tgt_y, pad_x, pad_y, tasktype = batch

                enc_x = torch.tensor(src_x)
                ACTUAL_BATCH_SIZE, _ = enc_x.shape
                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
                tgt = torch.tensor(tgt_y)
                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

                dec_x = Transformer.get_decoder_input(
                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
                )
                dec_x[:, 1:] = tgt[:, :-1]
                dec_x_pad = dec_x.eq(PAD_TOKEN)

                # Task 1 and Task 2
                if (
                    tasktype == Batch.TaskType.RDF2TXT
                    or tasktype == Batch.TaskType.TEXT2RDF
                ):
                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
                    txt_avg_batch_losses.append(loss.item())
                    continue

                # Tasks 3 and 4 are not validated during pretraining
                if current_epoch <= PRETRAIN_EPOCHS:
                    continue

                # Task 3
                if tasktype == Batch.TaskType.MASKING:
                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)
                    enc_avg_batch_losses.append(loss.item())
                    continue

                # Task 4
                if tasktype == Batch.TaskType.COMPLETATION:
                    pred_logits = DECODER_ONLY((dec_x, enc_x_pad, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)
                    dec_avg_batch_losses.append(loss.item())
                    continue

            txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
            enc_avg_loss = float("inf")
            dec_avg_loss = float("inf")

            if current_epoch > PRETRAIN_EPOCHS:
                enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
                dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

            # Early-stopping bookkeeping
            if current_epoch < PRETRAIN_EPOCHS:
                # During pretraining only the text loss matters
                if txt_avg_loss < average_loss_validation["txt"]:
                    average_loss_validation["txt"] = txt_avg_loss
                else:
                    patience += 1
                    if VERBOSE:
                        print(f"losing a patience, current irritation: {patience}")
            else:
                # After pretraining, count how many of the three losses got worse
                counter = 0

                if txt_avg_loss > average_loss_validation["txt"]:
                    if VERBOSE:
                        print("txt average is higher than lowest")
                    counter += 1
                else:
                    average_loss_validation["txt"] = txt_avg_loss

                if enc_avg_loss > average_loss_validation["encoder_only"]:
                    if VERBOSE:
                        print("masking average is higher than lowest")
                    counter += 1
                else:
                    average_loss_validation["encoder_only"] = enc_avg_loss

                if dec_avg_loss > average_loss_validation["decoder_only"]:
                    if VERBOSE:
                        print("decoding only average is higher than lowest")
                    counter += 1
                else:
                    average_loss_validation["decoder_only"] = dec_avg_loss

                if counter > 1:
                    patience += 1
                    if VERBOSE:
                        print(f"losing a patience, current irritation: {patience}")

                if counter == 0:
                    patience = max(0, patience - 1)
                    if VERBOSE:
                        print(f"all good, gaining a patience, current irritation: {patience}")

            # Average training losses for the epoch
            txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
            enc_avg_train_loss = float("inf")
            dec_avg_train_loss = float("inf")
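            # encoder_batch_losses / decoder_batch_losses may still be empty in
            # the epochs right after pretraining ends, so the averaging below is
            # guarded against division by zero.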
            if current_epoch > PRETRAIN_EPOCHS:
                try:
                    enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
                    dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
                except ZeroDivisionError:
                    pass

            # Append one row to the loss log
            loss_saver.write(
                [
                    current_epoch,
                    txt_train_avg_loss,
                    enc_avg_train_loss,
                    dec_avg_train_loss,
                    txt_avg_loss,
                    enc_avg_loss,
                    dec_avg_loss,
                ]
            )

            SEPARATOR = "================================================================================================================"
            DEBUG_TEXT = "".join(
                [
                    f"{SEPARATOR}\n",
                    f"EPOCH {current_epoch}\n",
                    f"{SEPARATOR}\n",
                    "Train Losses:\n",
                    "\tAvg Losses:\n",
                    f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
                    f"{SEPARATOR}\n",
                    "Validation Losses:\n",
                    f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction_loss: {dec_avg_loss}\n",
                    f"{SEPARATOR}\n",
                ]
            )

            print(DEBUG_TEXT)

            # Warn about patience
            if patience == PATIENCE:
                print("Model is likely overfitting, so let's stop here")

    # Save a checkpoint: model, optimizers and last completed epoch
    if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
        print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
        torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
        torch.save(nano_optim.state_dict(), NANO_OPTIM_PATH)
        torch.save(encoder_only_optim.state_dict(), ENC_OPTIM_PATH)
        torch.save(decoder_only_optim.state_dict(), DEC_OPTIM_PATH)

        with open(LAST_EPOCH_PATH, "w", encoding="utf-8") as last_epoch_file:
            last_epoch_file.write(f"{current_epoch}")

    # Stop once patience has run out
    if patience == PATIENCE:
        exit(0)
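# A possible follow-up (sketch only, left commented out): reload the saved
# checkpoint and measure the text loss on the held-out test split.  It reuses
# the call signature already used above; TEST_BATCHER is defined but unused.
#
# NANOSOCRATES.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))
# NANOSOCRATES.eval()
# with torch.no_grad():
#     test_losses = []
#     for src_x, tgt_y, pad_x, pad_y, tasktype in TEST_BATCHER.batch(MINI_BATCH_SIZE):
#         if tasktype not in (Batch.TaskType.RDF2TXT, Batch.TaskType.TEXT2RDF):
#             continue
#         enc_x = torch.tensor(src_x)
#         enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
#         tgt = torch.tensor(tgt_y)
#         dec_x = Transformer.get_decoder_input(enc_x.shape[0], SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)
#         dec_x[:, 1:] = tgt[:, :-1]
#         pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x.eq(PAD_TOKEN)))
#         test_losses.append(nano_cross_entropy(pred_logits.permute(0, 2, 1), tgt).item())
#     print(f"Test txt loss: {sum(test_losses) / len(test_losses)}")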