NanoSocrates/Playgrounds/nanosocrates-train.ipynb

510 lines
20 KiB
Plaintext
Raw Normal View History

2025-10-11 19:35:43 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "adbef43f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Chris\\miniconda3\\envs\\deep_learning\\Lib\\site-packages\\torch\\utils\\_device.py:103: UserWarning: Aten Op fallback from XPU to CPU happends. This may have performance implications. If need debug the fallback ops please set environment variable `PYTORCH_DEBUG_XPU_FALLBACK=1` (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\build\\xpu\\ATen\\RegisterXPU_0.cpp:54528.)\n",
" return func(*args, **kwargs)\n",
"c:\\Users\\Chris\\miniconda3\\envs\\deep_learning\\Lib\\site-packages\\torch\\optim\\lr_scheduler.py:192: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
" warnings.warn(\n"
]
},
{
"ename": "IndexError",
"evalue": "list index out of range",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 383\u001b[39m\n\u001b[32m 381\u001b[39m txt_min_train_losses = text_batch_losses[:][\u001b[32m0\u001b[39m]\n\u001b[32m 382\u001b[39m txt_avg_train_losses = text_batch_losses[:][\u001b[32m1\u001b[39m]\n\u001b[32m--> \u001b[39m\u001b[32m383\u001b[39m txt_max_train_losses = \u001b[43mtext_batch_losses\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 385\u001b[39m txt_min_loss = \u001b[38;5;28mmin\u001b[39m(txt_min_train_losses)\n\u001b[32m 386\u001b[39m txt_avg_min_loss = \u001b[38;5;28msum\u001b[39m(txt_min_train_losses) / \u001b[38;5;28mlen\u001b[39m(txt_min_train_losses)\n",
"\u001b[31mIndexError\u001b[39m: list index out of range"
]
}
],
"source": [
"import random\n",
"import sys\n",
"import torch\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"import Project_Model.Libs.Embedder as Embedder\n",
"import Project_Model.Libs.BPE as BPE\n",
"import Project_Model.Libs.Transformer as Transformer\n",
"import Project_Model.Libs.TransformerUtils as TUtils\n",
"import Project_Model.Libs.TorchShims as torch_shims\n",
"import Project_Model.Libs.Batch as Batch\n",
"\n",
"# Reproducibility: seed both PyTorch and the stdlib RNG\n",
"torch.manual_seed(0)\n",
"random.seed(0)\n",
"\n",
"\n",
"# Make the device chosen by the project shim the default for new tensors\n",
"DEVICE = torch_shims.get_default_device()\n",
"torch.set_default_device(DEVICE)\n",
"\n",
"\n",
"# Filesystem locations (the three dataset splits currently share one toy CSV)\n",
"VOCABULARY_PATH = Path(\"Assets/Model/small/bpe-small-16.json\")\n",
"TRAIN_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
"VALIDATION_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
"TEST_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
"CHECKPOINT_PATH = Path(\"Assets/Dataset/Tmp/NanoSocrates.zip\")\n",
"\n",
"\n",
"# Tokenizer: stored vocabulary plus the default special tokens\n",
"SPECIAL_VOC = BPE.default_special_tokens()\n",
"VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n",
"TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n",
"\n",
"\n",
"# Hyperparameters\n",
"# MASK_EXTRA_SPACE ids are reserved on top of the vocabulary\n",
"# (presumably for the masker's sentinel tokens - TODO confirm)\n",
"MASK_EXTRA_SPACE = 25\n",
"TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE\n",
"EMBEDDED_SIZE = 256\n",
"FEED_FORWARD_MULTIPLIER = 4\n",
"ATTENTION_HEADS = 8\n",
"SENTENCE_LENGTH = 256\n",
"NUMBER_OF_BLOCKS = 4\n",
"MAX_EPOCHS = 1_000\n",
"PRETRAIN_EPOCHS = 2\n",
"WARMUP_EPOCHS = 4_000\n",
"MINI_BATCH_SIZE = 10\n",
"VALIDATION_STEPS = 1\n",
"CHECKPOINT_STEPS = 4 * VALIDATION_STEPS\n",
"PATIENCE = 4\n",
"CURRENT_EPOCH = 0\n",
"\n",
"# Ids of the special tokens: each tag is encoded and its first id is kept\n",
"SOS_TOKEN, PAD_TOKEN, END_TOKEN, SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN = (\n",
"    TOKENANO.encode(tag)[0]\n",
"    for tag in (\"<SOS>\", \"<PAD>\", \"<END>\", \"<SUBJ>\", \"<PRED>\", \"<OBJ>\")\n",
")\n",
"\n",
"# Every special token except the three triple markers is forbidden for the masker\n",
"SPECIAL_TOKENS: set[int] = {*TOKENANO.encode(\"\".join(BPE.default_special_tokens()))}\n",
"ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}\n",
"FORBIDDEN_TOKENS = SPECIAL_TOKENS.difference(ALLOWED_TOKENS)\n",
"\n",
"\n",
"# Span masker working over the whole token space, never touching forbidden ids\n",
"MASKER = Transformer.SpannedMasker(TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)\n",
"\n",
"# One batcher per dataset split; they share sentence length, tokenizer and masker\n",
"TRAIN_BATCHER, VALIDATION_BATCHER, TEST_BATCHER = (\n",
"    Batch.Batcher(dataset_path, SENTENCE_LENGTH, TOKENANO, MASKER)\n",
"    for dataset_path in (\n",
"        TRAIN_DATASET_PATH,\n",
"        VALIDATION_DATASET_PATH,\n",
"        TEST_DATASET_PATH,\n",
"    )\n",
")\n",
"\n",
"\n",
"# Full model, then the encoder-only / decoder-only models derived from it\n",
"NANOSOCRATES = Transformer.TrainingModel(\n",
"    TOKEN_SPACE_SIZE, EMBEDDED_SIZE, FEED_FORWARD_MULTIPLIER, ATTENTION_HEADS, NUMBER_OF_BLOCKS\n",
")\n",
"_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(\n",
"    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE\n",
")\n",
"\n",
"\n",
"# Loss: target positions equal to PAD_TOKEN are ignored\n",
"cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)\n",
"\n",
"# One AdamW optimizer (default settings) per model\n",
"nano_optim, encoder_only_optim, decoder_only_optim = (\n",
"    torch.optim.AdamW(module.parameters())\n",
"    for module in (NANOSOCRATES, ENCODER_ONLY, DECODER_ONLY)\n",
")\n",
"\n",
"# One warm-up learning-rate scheduler per optimizer\n",
"nano_scheduler, encoder_only_scheduler, decoder_only_scheduler = (\n",
"    Transformer.WarmupLR(optimizer, WARMUP_EPOCHS, EMBEDDED_SIZE)\n",
"    for optimizer in (nano_optim, encoder_only_optim, decoder_only_optim)\n",
")\n",
"\n",
"# Mutable training state\n",
"patience = 0\n",
"current_epoch = CURRENT_EPOCH\n",
"\n",
"# Best validation loss seen so far for each task family\n",
"average_loss_validation = dict.fromkeys(\n",
"    (\"txt\", \"encoder_only\", \"decoder_only\"), float(\"inf\")\n",
")\n",
"\n",
"def _mean(values, default=float(\"inf\")):\n",
"    \"\"\"Arithmetic mean of `values`; `default` when `values` is empty.\"\"\"\n",
"    return sum(values) / len(values) if values else default\n",
"\n",
"\n",
"def _summarize(rows):\n",
"    \"\"\"Reduce per-batch `[min, avg, max]` rows to epoch-level statistics.\n",
"\n",
"    Returns `(lowest min, mean of mins, highest max, mean of maxes, mean of avgs)`;\n",
"    every entry is `inf` when `rows` is empty.\n",
"    \"\"\"\n",
"    # Column-wise extraction: `rows[:][k]` would return the k-th ROW instead\n",
"    mins = [row[0] for row in rows]\n",
"    avgs = [row[1] for row in rows]\n",
"    maxes = [row[2] for row in rows]\n",
"    return (\n",
"        min(mins, default=float(\"inf\")),\n",
"        _mean(mins),\n",
"        max(maxes, default=float(\"inf\")),\n",
"        _mean(maxes),\n",
"        _mean(avgs),\n",
"    )\n",
"\n",
"\n",
"def _batch_tensors(batch):\n",
"    \"\"\"Unpack one batcher item into the tensors shared by every task.\n",
"\n",
"    Returns `(enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype)`.\n",
"    \"\"\"\n",
"    src_x, tgt_y, pad_x, _, tasktype = batch\n",
"\n",
"    enc_x = torch.tensor(src_x)\n",
"    enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)\n",
"    tgt = torch.tensor(tgt_y)\n",
"\n",
"    # Size the decoder input from the real batch so a short final batch\n",
"    # still lines up with `tgt`\n",
"    dec_x = Transformer.get_decoder_input(tgt.shape[0], SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)\n",
"    dec_x_pad = dec_x.eq(PAD_TOKEN)\n",
"\n",
"    return enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype\n",
"\n",
"\n",
"def _token_losses(model, model_input, dec_x, dec_x_pad, tgt, optimizer=None):\n",
"    \"\"\"Run the per-token teacher-forced pass; return one loss per token position.\n",
"\n",
"    With an `optimizer`, every position is followed by a backward pass and an\n",
"    optimizer step (training). Without one, only the loss is computed.\n",
"    `dec_x` and `dec_x_pad` are updated in place as gold tokens are revealed.\n",
"    \"\"\"\n",
"    losses = []\n",
"\n",
"    for token_idx in range(SENTENCE_LENGTH):\n",
"\n",
"        if optimizer is not None:\n",
"            optimizer.zero_grad()\n",
"\n",
"        pred_logits = model(model_input)[:, token_idx, :]\n",
"        loss: torch.Tensor = cross_entropy(pred_logits, tgt[:, token_idx])\n",
"\n",
"        if optimizer is not None:\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"        losses.append(loss.item())\n",
"\n",
"        if token_idx < SENTENCE_LENGTH - 1:\n",
"            # Teacher forcing: the gold token becomes the next decoder input.\n",
"            # The padding mask is refreshed too, so it always equals\n",
"            # dec_x.eq(PAD_TOKEN) instead of going stale after the first step.\n",
"            dec_x[:, token_idx + 1] = tgt[:, token_idx]\n",
"            dec_x_pad[:, token_idx + 1] = tgt[:, token_idx].eq(PAD_TOKEN)\n",
"\n",
"    return losses\n",
"\n",
"\n",
"while current_epoch < MAX_EPOCHS:\n",
"\n",
"    # Per-epoch training losses: [min, avg, max] rows for the per-token tasks,\n",
"    # plain floats for the masking task\n",
"    text_batch_losses = []\n",
"    encoder_batch_losses = []\n",
"    decoder_batch_losses = []\n",
"\n",
"    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):\n",
"\n",
"        enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype = _batch_tensors(batch)\n",
"\n",
"        # Task 1 and Task 2\n",
"        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:\n",
"\n",
"            BATCH_LOSS = _token_losses(\n",
"                NANOSOCRATES, (enc_x, enc_x_pad, dec_x, dec_x_pad), dec_x, dec_x_pad, tgt, nano_optim\n",
"            )\n",
"\n",
"            # Average over the token steps that produced the losses\n",
"            # (not over the mini-batch size)\n",
"            text_batch_losses.append([min(BATCH_LOSS), _mean(BATCH_LOSS), max(BATCH_LOSS)])\n",
"            continue\n",
"\n",
"        # Pretrain first\n",
"        if current_epoch < PRETRAIN_EPOCHS:\n",
"            continue\n",
"\n",
"        # Task 3\n",
"        if tasktype == Batch.TaskType.MASKING:\n",
"\n",
"            encoder_only_optim.zero_grad()\n",
"\n",
"            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))\n",
"\n",
"            # NOTE(review): CrossEntropyLoss expects the class dimension at index 1\n",
"            # for multi-dimensional input - confirm the layout ENCODER_ONLY returns\n",
"            loss: torch.Tensor = cross_entropy(pred_logits, tgt)\n",
"\n",
"            loss.backward()\n",
"            encoder_only_optim.step()\n",
"\n",
"            encoder_batch_losses.append(loss.item())\n",
"            continue\n",
"\n",
"        # Task 4\n",
"        if tasktype == Batch.TaskType.COMPLETATION:\n",
"\n",
"            # NOTE(review): DECODER_ONLY is fed (enc_x, enc_x_pad), so the\n",
"            # teacher-forced dec_x never reaches it - confirm this is intended\n",
"            BATCH_LOSS = _token_losses(\n",
"                DECODER_ONLY, (enc_x, enc_x_pad), dec_x, dec_x_pad, tgt, decoder_only_optim\n",
"            )\n",
"\n",
"            decoder_batch_losses.append([min(BATCH_LOSS), _mean(BATCH_LOSS), max(BATCH_LOSS)])\n",
"            continue\n",
"\n",
"    # NOTE(review): during the pretrain epochs the encoder-only / decoder-only\n",
"    # optimizers have not stepped yet; stepping their schedulers here is what\n",
"    # triggers PyTorch's \"lr_scheduler.step() before optimizer.step()\" warning\n",
"    nano_scheduler.step()\n",
"    encoder_only_scheduler.step()\n",
"    decoder_only_scheduler.step()\n",
"\n",
"    current_epoch += 1\n",
"\n",
"    if current_epoch % VALIDATION_STEPS == 0:\n",
"\n",
"        txt_avg_batch_losses = []\n",
"        enc_avg_batch_losses = []\n",
"        dec_avg_batch_losses = []\n",
"\n",
"        # Validation only measures losses, so autograd bookkeeping is switched off.\n",
"        # NOTE(review): the models are not put in eval() mode here - confirm\n",
"        # whether they contain dropout or other train-only layers\n",
"        with torch.no_grad():\n",
"\n",
"            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):\n",
"\n",
"                enc_x, enc_x_pad, dec_x, dec_x_pad, tgt, tasktype = _batch_tensors(batch)\n",
"\n",
"                # Task 1 and Task 2\n",
"                if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:\n",
"\n",
"                    BATCH_LOSS = _token_losses(\n",
"                        NANOSOCRATES, (enc_x, enc_x_pad, dec_x, dec_x_pad), dec_x, dec_x_pad, tgt\n",
"                    )\n",
"                    txt_avg_batch_losses.append(_mean(BATCH_LOSS))\n",
"                    continue\n",
"\n",
"                # Pretrain first\n",
"                if current_epoch < PRETRAIN_EPOCHS:\n",
"                    continue\n",
"\n",
"                # Task 3\n",
"                if tasktype == Batch.TaskType.MASKING:\n",
"\n",
"                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))\n",
"                    loss: torch.Tensor = cross_entropy(pred_logits, tgt)\n",
"\n",
"                    enc_avg_batch_losses.append(loss.item())\n",
"                    continue\n",
"\n",
"                # Task 4\n",
"                if tasktype == Batch.TaskType.COMPLETATION:\n",
"\n",
"                    BATCH_LOSS = _token_losses(\n",
"                        DECODER_ONLY, (enc_x, enc_x_pad), dec_x, dec_x_pad, tgt\n",
"                    )\n",
"                    dec_avg_batch_losses.append(_mean(BATCH_LOSS))\n",
"                    continue\n",
"\n",
"        # Validation averages; a task with no validated batch reports inf\n",
"        txt_avg_loss = _mean(txt_avg_batch_losses)\n",
"        enc_avg_loss = _mean(enc_avg_batch_losses)\n",
"        dec_avg_loss = _mean(dec_avg_batch_losses)\n",
"\n",
"        # Patience is cumulative: it is never reset on improvement\n",
"        if current_epoch < PRETRAIN_EPOCHS:\n",
"\n",
"            if txt_avg_loss < average_loss_validation[\"txt\"]:\n",
"                average_loss_validation[\"txt\"] = txt_avg_loss\n",
"            else:\n",
"                patience += 1\n",
"\n",
"        else:\n",
"\n",
"            validation_losses = {\n",
"                \"txt\": txt_avg_loss,\n",
"                \"encoder_only\": enc_avg_loss,\n",
"                \"decoder_only\": dec_avg_loss,\n",
"            }\n",
"\n",
"            # Each task is compared against its OWN best value, and the best\n",
"            # value is refreshed whenever the task did not get worse\n",
"            counter = 0\n",
"            for task_name, task_loss in validation_losses.items():\n",
"                if task_loss > average_loss_validation[task_name]:\n",
"                    counter += 1\n",
"                else:\n",
"                    average_loss_validation[task_name] = task_loss\n",
"\n",
"            # More than one task getting worse costs one unit of patience\n",
"            if counter > 1:\n",
"                patience += 1\n",
"\n",
"        # Training statistics for the epoch (inf where a task did not train)\n",
"        (\n",
"            txt_min_loss,\n",
"            txt_avg_min_loss,\n",
"            txt_max_loss,\n",
"            txt_avg_max_loss,\n",
"            txt_avg_train_loss,\n",
"        ) = _summarize(text_batch_losses)\n",
"\n",
"        (\n",
"            dec_min_loss,\n",
"            dec_avg_min_loss,\n",
"            dec_max_loss,\n",
"            dec_avg_max_loss,\n",
"            dec_avg_train_loss,\n",
"        ) = _summarize(decoder_batch_losses)\n",
"\n",
"        enc_avg_train_loss = _mean(encoder_batch_losses)\n",
"\n",
"        SEPARATOR = \"===========================================================================================\"\n",
"        DEBUG_TEXT = \"\".join([\n",
"            f\"{SEPARATOR}\\n\",\n",
"            f\"EPOCH {current_epoch}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"            \"Train Losses:\\n\",\n",
"            \"\\tMin Losses:\\n\",\n",
"            f\"\\t\\tmin_txt: {txt_min_loss} - avg_txt: {txt_avg_min_loss}\\n\",\n",
"            f\"\\t\\tmin_dec: {dec_min_loss} - avg_dec: {dec_avg_min_loss}\\n\",\n",
"            \"\\tMax Losses:\\n\",\n",
"            f\"\\t\\tmax_txt: {txt_max_loss} - avg_txt: {txt_avg_max_loss}\\n\",\n",
"            f\"\\t\\tmax_dec: {dec_max_loss} - avg_dec: {dec_avg_max_loss}\\n\",\n",
"            \"\\tAvg Losses:\\n\",\n",
"            f\"\\t\\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"            \"Validation Losses:\\n\",\n",
"            f\"\\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"        ])\n",
"        print(DEBUG_TEXT)\n",
"\n",
"    # Warn about patience\n",
"    out_of_patience = patience >= PATIENCE\n",
"    if out_of_patience:\n",
"        print(\n",
"            \"Model is likely overfitting, so let's stop here\"\n",
"        )\n",
"\n",
"    # SAVE MODEL\n",
"    if current_epoch % CHECKPOINT_STEPS == 0 or out_of_patience:\n",
"        print(f\"Saving model at {CHECKPOINT_PATH.as_posix()}\")\n",
"        # torch.save does not create missing directories\n",
"        CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)\n",
"        torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)\n",
"\n",
"    # Actually stop once patience has run out (after the final checkpoint)\n",
"    if out_of_patience:\n",
"        break\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deep_learning",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}