# NanoSocrates/Playgrounds/nanosocrates-train-experiment-2.py
import random
import sys
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch
# set a fixed seed
torch.manual_seed(0)
random.seed(0)
# set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)
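# With a default device set, the torch.tensor(...) factory calls in the
# training loop below allocate directly on DEVICE, so no explicit .to(DEVICE)
# moves are needed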
# Get paths
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")
# BPE Init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)
# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
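# The MASK_EXTRA_SPACE ids above the real vocabulary are presumably reserved
# for the span-mask sentinel tokens that SpannedMasker writes into inputs
# (an assumption based on REAL_TOKEN_SPACE_SIZE being passed to the masker below)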
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = int(300)
WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0
VERBOSE = True
LEARNING_RATE = 1.5
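# CHECKPOINT_STEPS is a multiple of VALIDATION_STEPS, so checkpoint saves
# always coincide with a validation pass (where the save check lives)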
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<PRED>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]
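# Each special marker is assumed to encode to exactly one BPE id, hence the [0]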
SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS
# Spanned_Masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)
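# The masker may corrupt any token except the structural <SUBJ>/<PRED>/<OBJ>
# markers, so masked examples keep their triple structure (an assumption based
# on FORBIDDEN_TOKENS excluding ALLOWED_TOKENS)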
TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
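# All three batchers share the same masker instance; TEST_BATCHER is built
# here but not consumed in this script (presumably used by a separate
# evaluation run)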
# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
# Resume from the latest checkpoint when one exists
if CHECKPOINT_PATH.is_file():
    nanosocrates_dict = torch.load(CHECKPOINT_PATH, weights_only=True)
    NANOSOCRATES.load_state_dict(nanosocrates_dict)
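# weights_only=True restricts torch.load to tensors and plain containers,
# avoiding arbitrary-code unpickling when resuming from a checkpoint file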
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)
# Training constants
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
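# A minimal sketch of the schedule WarmupLR is assumed to implement here
# (the "Noam" schedule from "Attention Is All You Need"; the actual
# Project_Model.Libs implementation may differ). Under that assumption,
# LEARNING_RATE = 1.5 acts as a scale on the schedule, not a raw step size:
#
#     def noam_lr(step: int, d_model: int = EMBEDDED_SIZE,
#                 warmup: int = WARMUP_EPOCHS, scale: float = LEARNING_RATE) -> float:
#         step = max(step, 1)  # avoid 0 ** -0.5 on the first call
#         return scale * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)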
current_epoch = CURRENT_EPOCH
patience = 0
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}
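# Best validation loss observed so far for each task family; the
# patience-based early-stopping logic below compares against these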
while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    if VERBOSE:
        print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        # Teacher forcing: the decoder input is the target shifted right by
        # one position, starting from <SOS>
        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)

        if VERBOSE:
            print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2: full encoder-decoder training (RDF2TXT / TEXT2RDF)
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            if VERBOSE:
                print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            nano_optim.zero_grad()

            pred_logits: torch.Tensor = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            # CrossEntropyLoss expects logits as (batch, classes, seq_len)
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
            loss.backward()
            nano_optim.step()

            # .item() detaches the scalar so the computation graph can be freed
            text_batch_losses.append(loss.item())
            continue
        # Tasks 3 and 4 join only after the text tasks have been pretrained
        if current_epoch < PRETRAIN_EPOCHS:
            continue
        # Task 3: masked-span denoising with the encoder alone
        if tasktype == Batch.TaskType.MASKING:

            if VERBOSE:
                print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)
            # print(torch.max(tgt))  # sanity check: targets stay inside TOKEN_SPACE_SIZE

            loss: torch.Tensor = encoder_ce(pred_logits, tgt)
            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())
            continue
        # Task 4: next-token completion with the decoder alone
        if tasktype == Batch.TaskType.COMPLETATION:

            if VERBOSE:
                print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            decoder_only_optim.zero_grad()

            pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = decoder_ce(pred_logits, tgt)
            loss.backward()
            decoder_only_optim.step()

            decoder_batch_losses.append(loss.item())
            continue

    # One scheduler step per epoch: the warmup horizon is measured in
    # epochs (WARMUP_EPOCHS), not in optimizer steps
    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

        with torch.no_grad():

            txt_avg_batch_losses = []
            enc_avg_batch_losses = []
            dec_avg_batch_losses = []

            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

                src_x, tgt_y, pad_x, pad_y, tasktype = batch

                enc_x = torch.tensor(src_x)
                ACTUAL_BATCH_SIZE, _ = enc_x.shape
                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
                tgt = torch.tensor(tgt_y)
                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

                dec_x = Transformer.get_decoder_input(
                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
                )
                dec_x[:, 1:] = tgt[:, :-1]
                dec_x_pad = dec_x.eq(PAD_TOKEN)

                # Task 1 and Task 2
                if (
                    tasktype == Batch.TaskType.RDF2TXT
                    or tasktype == Batch.TaskType.TEXT2RDF
                ):
                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)
                    txt_avg_batch_losses.append(loss.item())
                    continue

                # current_epoch was incremented before validation, so <= here
                # mirrors the strict < used in the training loop above
                if current_epoch <= PRETRAIN_EPOCHS:
                    continue

                # Task 3
                if tasktype == Batch.TaskType.MASKING:
                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)
                    enc_avg_batch_losses.append(loss.item())
                    continue

                # Task 4
                if tasktype == Batch.TaskType.COMPLETATION:
                    pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)
                    dec_avg_batch_losses.append(loss.item())
                    continue

        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
        enc_avg_loss = float("inf")
        dec_avg_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

        if current_epoch < PRETRAIN_EPOCHS:
            if txt_avg_loss < average_loss_validation["txt"]:
                average_loss_validation["txt"] = txt_avg_loss
            else:
                patience += 1
        else:
            # Count how many task families regressed against their best
            # validation loss so far, updating each best on improvement
            counter = 0

            if txt_avg_loss > average_loss_validation["txt"]:
                counter += 1
            else:
                average_loss_validation["txt"] = txt_avg_loss

            if enc_avg_loss > average_loss_validation["encoder_only"]:
                counter += 1
            else:
                average_loss_validation["encoder_only"] = enc_avg_loss

            if dec_avg_loss > average_loss_validation["decoder_only"]:
                counter += 1
            else:
                average_loss_validation["decoder_only"] = dec_avg_loss

            # More than one regressing task costs a patience point;
            # no regressions earn one back
            if counter > 1:
                patience += 1
            if counter == 0:
                patience = max(0, patience - 1)

        txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
        enc_avg_train_loss = float("inf")
        dec_avg_train_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            try:
                enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
                dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
            except ZeroDivisionError:
                # An epoch may contain no masking or completion batches
                pass

        SEPARATOR = "================================================================================================================"
        DEBUG_TEXT = "".join(
            [
                f"{SEPARATOR}\n",
                f"EPOCH {current_epoch}\n",
                f"{SEPARATOR}\n",
                "Train Losses:\n",
                "\tAvg Losses:\n",
                f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
                f"{SEPARATOR}\n",
                "Validation Losses:\n",
                f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction_loss: {dec_avg_loss}\n",
                f"{SEPARATOR}\n",
            ]
        )
        print(DEBUG_TEXT)

        # Warn about patience
        if patience == PATIENCE:
            print("Model is likely overfitting, so let's stop here")

        # SAVE MODEL
        if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
            print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
            torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)

        if patience == PATIENCE:
            sys.exit(0)