From 56fbadd55e6c7c1b9d5fc089a6f367aabea7a22b Mon Sep 17 00:00:00 2001
From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com>
Date: Sun, 12 Oct 2025 00:05:30 +0200
Subject: [PATCH] Fixed training

---
 .../nanosocrates-train-experiment-2.py | 365 ++++++++++++++++++
 1 file changed, 365 insertions(+)
 create mode 100644 Playgrounds/nanosocrates-train-experiment-2.py

diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py
new file mode 100644
index 0000000..31a68b8
--- /dev/null
+++ b/Playgrounds/nanosocrates-train-experiment-2.py
@@ -0,0 +1,365 @@
+import random
+import sys
+import torch
+import pandas as pd
+from pathlib import Path
+import Project_Model.Libs.Embedder as Embedder
+import Project_Model.Libs.BPE as BPE
+import Project_Model.Libs.Transformer as Transformer
+import Project_Model.Libs.TransformerUtils as TUtils
+import Project_Model.Libs.TorchShims as torch_shims
+import Project_Model.Libs.Batch as Batch
+
+# Set a fixed seed for reproducibility
+torch.manual_seed(0)
+random.seed(0)
+
+
+# Set a default device
+DEVICE = torch_shims.get_default_device()
+torch.set_default_device(DEVICE)
+
+
+# Paths
+VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
+TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
+VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
+TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
+CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")
+
+
+# BPE init
+SPECIAL_VOC = BPE.default_special_tokens()
+VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
+TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)
+
+
+# Constants
+MASK_EXTRA_SPACE = 100
+REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
+TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
+EMBEDDED_SIZE = 256
+FEED_FORWARD_MULTIPLIER = 4
+ATTENTION_HEADS = 8
+SENTENCE_LENGTH = 256
+NUMBER_OF_BLOCKS = 4
+MAX_EPOCHS = int(1e3)
+PRETRAIN_EPOCHS = 300
+WARMUP_EPOCHS = int(4e3)
+MINI_BATCH_SIZE = 100
+VALIDATION_STEPS = 50
+CHECKPOINT_STEPS = VALIDATION_STEPS * 4
+PATIENCE = 4
+CURRENT_EPOCH = 0
+
+# Special-token ids.
+# NOTE: the literals below are assumed from the variable names; they must match
+# the strings returned by BPE.default_special_tokens().
+SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
+PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
+END_TOKEN = TOKENANO.encode("<END>")[0]
+SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
+REL_TOKEN = TOKENANO.encode("<REL>")[0]
+OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]
+
+SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
+ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
+FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS
+
+
+# Spanned masker and batchers
+MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)
+
+TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
+VALIDATION_BATCHER = Batch.Batcher(
+    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
+)
+TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
+
+
+# Model
+NANOSOCRATES = Transformer.TrainingModel(
+    TOKEN_SPACE_SIZE,
+    EMBEDDED_SIZE,
+    FEED_FORWARD_MULTIPLIER,
+    ATTENTION_HEADS,
+    NUMBER_OF_BLOCKS,
+)
+_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
+    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
+)
+
+
+# Training constants.
+# The base learning rate is 1.0 on purpose: the WarmupLR scheduler rescales it.
+nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), 1)
+encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), 1)
+decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), 1)
+nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
+encoder_only_scheduler = Transformer.WarmupLR(
+    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+)
+decoder_only_scheduler = Transformer.WarmupLR(
+    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
+)
+
+current_epoch = CURRENT_EPOCH
+patience = 0
+
+# Best validation losses seen so far, per objective
+average_loss_validation = {
+    "txt": float("inf"),
+    "encoder_only": float("inf"),
+    "decoder_only": float("inf"),
+}
+
+while current_epoch < MAX_EPOCHS:
+
+    NANOSOCRATES.train()
+    ENCODER_ONLY.train()
+    DECODER_ONLY.train()
+
+    text_batch_losses = []
+    encoder_batch_losses = []
+    decoder_batch_losses = []
+
+    batch_counter = 0
+
+    print(f"EPOCH {current_epoch} STARTING")
+
+    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):
+
+        batch_counter += 1
+
+        src_x, tgt_y, pad_x, pad_y, tasktype = batch
+
+        enc_x = torch.tensor(src_x)
+
+        ACTUAL_BATCH_SIZE, _ = enc_x.shape
+        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+        tgt = torch.tensor(tgt_y)
+        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+
+        # Teacher forcing: decoder input is the target shifted right after SOS
+        dec_x = Transformer.get_decoder_input(
+            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+        )
+        dec_x[:, 1:] = tgt[:, :-1]
+        dec_x_pad = dec_x.eq(PAD_TOKEN)
+
+        print(f"\tBATCH {batch_counter} Starting")
+
+        # Task 1 and Task 2: RDF2TXT / TEXT2RDF (full encoder-decoder)
+        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:
+
+            print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")
+
+            nano_optim.zero_grad()
+
+            pred_logits: torch.Tensor = NANOSOCRATES(
+                (enc_x, enc_x_pad, dec_x, dec_x_pad)
+            )
+            # CrossEntropyLoss expects (batch, classes, sequence)
+            pred_logits = pred_logits.permute(0, 2, 1)
+
+            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
+
+            loss.backward()
+            nano_optim.step()
+
+            text_batch_losses.append(loss.item())
+            continue
+
+        # During pretraining, train only Task 1 and Task 2
+        if current_epoch < PRETRAIN_EPOCHS:
+            continue
+
+        # Task 3: span masking (encoder only)
+        if tasktype == Batch.TaskType.MASKING:
+
+            print(f"\tExecuting TASK 3 - BATCH {batch_counter}")
+
+            encoder_only_optim.zero_grad()
+
+            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+            pred_logits = pred_logits.permute(0, 2, 1)
+
+            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+
+            loss.backward()
+            encoder_only_optim.step()
+
+            encoder_batch_losses.append(loss.item())
+
+            continue
+
+        # Task 4: completion (decoder only)
+        if tasktype == Batch.TaskType.COMPLETATION:
+
+            print(f"\tExecuting TASK 4 - BATCH {batch_counter}")
+
+            decoder_only_optim.zero_grad()
+
+            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+            pred_logits = pred_logits.permute(0, 2, 1)
+
+            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+
+            loss.backward()
+            decoder_only_optim.step()
+
+            decoder_batch_losses.append(loss.item())
+
+            continue
+
+    nano_scheduler.step()
+    encoder_only_scheduler.step()
+    decoder_only_scheduler.step()
+
+    current_epoch += 1
+
+    if current_epoch % VALIDATION_STEPS == 0:
+
+        NANOSOCRATES.eval()
+        ENCODER_ONLY.eval()
+        DECODER_ONLY.eval()
+
+        txt_avg_batch_losses = []
+        enc_avg_batch_losses = []
+        dec_avg_batch_losses = []
+
+        with torch.no_grad():
+            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):
+
+                src_x, tgt_y, pad_x, pad_y, tasktype = batch
+
+                enc_x = torch.tensor(src_x)
+
+                ACTUAL_BATCH_SIZE, _ = enc_x.shape
+                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
+                tgt = torch.tensor(tgt_y)
+                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)
+
+                dec_x = Transformer.get_decoder_input(
+                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
+                )
+                dec_x[:, 1:] = tgt[:, :-1]
+                dec_x_pad = dec_x.eq(PAD_TOKEN)
+
+                # Task 1 and Task 2
+                if (
+                    tasktype == Batch.TaskType.RDF2TXT
+                    or tasktype == Batch.TaskType.TEXT2RDF
+                ):
+
+                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
+                    pred_logits = pred_logits.permute(0, 2, 1)
+
+                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
+
+                    txt_avg_batch_losses.append(loss.item())
+
+                    continue
+
+                # During pretraining, validate only Task 1 and Task 2
+                if current_epoch <= PRETRAIN_EPOCHS:
+                    continue
+
+                # Task 3
+                if tasktype == Batch.TaskType.MASKING:
+
+                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
+                    pred_logits = pred_logits.permute(0, 2, 1)
+
+                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)
+
+                    enc_avg_batch_losses.append(loss.item())
+
+                    continue
+
+                # Task 4
+                if tasktype == Batch.TaskType.COMPLETATION:
+
+                    pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
+                    pred_logits = pred_logits.permute(0, 2, 1)
+
+                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)
+
+                    dec_avg_batch_losses.append(loss.item())
+
+                    continue
+
+        # Validation averages
+        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
+        enc_avg_loss = float("inf")
+        dec_avg_loss = float("inf")
+
+        if current_epoch > PRETRAIN_EPOCHS:
+            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
+            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)
+
+        # Early-stopping bookkeeping
+        if current_epoch < PRETRAIN_EPOCHS:
+
+            if txt_avg_loss < average_loss_validation["txt"]:
+                average_loss_validation["txt"] = txt_avg_loss
+            else:
+                patience += 1
+        else:
+
+            counter = 0
+
+            if txt_avg_loss > average_loss_validation["txt"]:
+                counter += 1
+            else:
+                average_loss_validation["txt"] = txt_avg_loss
+
+            if enc_avg_loss > average_loss_validation["encoder_only"]:
+                counter += 1
+            else:
+                average_loss_validation["encoder_only"] = enc_avg_loss
+
+            if dec_avg_loss > average_loss_validation["decoder_only"]:
+                counter += 1
+            else:
+                average_loss_validation["decoder_only"] = dec_avg_loss
+
+            if counter > 1:
+                patience += 1
+
+        # Training averages (kept separate from the validation averages)
+        txt_avg_train_loss = sum(text_batch_losses) / len(text_batch_losses)
+
+        enc_avg_train_loss = float("inf")
+        dec_avg_train_loss = float("inf")
+
+        if current_epoch > PRETRAIN_EPOCHS:
+            try:
+                enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
+                dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
+            except ZeroDivisionError:
+                pass
+
+        SEPARATOR = "================================================================================================================"
+        DEBUG_TEXT = "".join(
+            [
+                f"{SEPARATOR}\n",
+                f"EPOCH {current_epoch}\n",
+                f"{SEPARATOR}\n",
+                "Train Losses:\n",
+                f"\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
+                f"{SEPARATOR}\n",
+                "Validation Losses:\n",
+                f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
+                f"{SEPARATOR}\n",
+            ]
+        )
+
+        print(DEBUG_TEXT)
+
+        # Warn about patience
+        if patience >= PATIENCE:
+            print("Model is likely overfitting, so let's stop here")
+
+        # SAVE MODEL
+        if current_epoch % CHECKPOINT_STEPS == 0 or patience >= PATIENCE:
+            print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
+            torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
+
+        # Stop once patience is exhausted
+        if patience >= PATIENCE:
+            break
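
Note: below is a minimal sketch (not part of the patch) of how the saved checkpoint could be reloaded for later evaluation. It assumes the checkpoint was produced by this script with the hyperparameters above and reuses only interfaces that already appear in it; torch.load and load_state_dict are standard PyTorch.

import torch
from pathlib import Path

import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer

# Constants mirrored from the training script
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")
MASK_EXTRA_SPACE = 100
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
NUMBER_OF_BLOCKS = 4

# Rebuild the tokenizer to recover the token-space size used at training time
tokenano = BPE.TokeNanoCore(
    BPE.load_nanos_vocabulary(VOCABULARY_PATH), BPE.default_special_tokens()
)
token_space_size = tokenano.vocabulary_size + MASK_EXTRA_SPACE

# Recreate a model with the same shape, then load the saved weights
model = Transformer.TrainingModel(
    token_space_size,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
model.eval()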