V0.0.1 Athene
This commit is contained in:
parent
49946727d8
commit
160b7dbfc0
BIN
Assets/Dataset/1-hop/small/holdout/evaluation.csv
(Stored with Git LFS)
Normal file
BIN
Assets/Dataset/1-hop/small/holdout/evaluation.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
|
BIN
Assets/Dataset/1-hop/small/holdout/test.csv
(Stored with Git LFS)
Normal file
BIN
Assets/Dataset/1-hop/small/holdout/test.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
|
BIN
Assets/Dataset/1-hop/small/holdout/train.csv
(Stored with Git LFS)
Normal file
BIN
Assets/Dataset/1-hop/small/holdout/train.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
|
File diff suppressed because one or more lines are too long
509
Playgrounds/nanosocrates-train.ipynb
Normal file
509
Playgrounds/nanosocrates-train.ipynb
Normal file
@ -0,0 +1,509 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "adbef43f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\Users\\Chris\\miniconda3\\envs\\deep_learning\\Lib\\site-packages\\torch\\utils\\_device.py:103: UserWarning: Aten Op fallback from XPU to CPU happends. This may have performance implications. If need debug the fallback ops please set environment variable `PYTORCH_DEBUG_XPU_FALLBACK=1` (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\build\\xpu\\ATen\\RegisterXPU_0.cpp:54528.)\n",
|
||||||
|
" return func(*args, **kwargs)\n",
|
||||||
|
"c:\\Users\\Chris\\miniconda3\\envs\\deep_learning\\Lib\\site-packages\\torch\\optim\\lr_scheduler.py:192: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "IndexError",
|
||||||
|
"evalue": "list index out of range",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||||
|
"\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 383\u001b[39m\n\u001b[32m 381\u001b[39m txt_min_train_losses = text_batch_losses[:][\u001b[32m0\u001b[39m]\n\u001b[32m 382\u001b[39m txt_avg_train_losses = text_batch_losses[:][\u001b[32m1\u001b[39m]\n\u001b[32m--> \u001b[39m\u001b[32m383\u001b[39m txt_max_train_losses = \u001b[43mtext_batch_losses\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 385\u001b[39m txt_min_loss = \u001b[38;5;28mmin\u001b[39m(txt_min_train_losses)\n\u001b[32m 386\u001b[39m txt_avg_min_loss = \u001b[38;5;28msum\u001b[39m(txt_min_train_losses) / \u001b[38;5;28mlen\u001b[39m(txt_min_train_losses)\n",
|
||||||
|
"\u001b[31mIndexError\u001b[39m: list index out of range"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import random\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"import Project_Model.Libs.Embedder as Embedder\n",
|
||||||
|
"import Project_Model.Libs.BPE as BPE\n",
|
||||||
|
"import Project_Model.Libs.Transformer as Transformer\n",
|
||||||
|
"import Project_Model.Libs.TransformerUtils as TUtils\n",
|
||||||
|
"import Project_Model.Libs.TorchShims as torch_shims\n",
|
||||||
|
"import Project_Model.Libs.Batch as Batch\n",
|
||||||
|
"\n",
|
||||||
|
"# set a fixed seed\n",
|
||||||
|
"torch.manual_seed(0)\n",
|
||||||
|
"random.seed(0)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# set a default device\n",
|
||||||
|
"DEVICE = torch_shims.get_default_device()\n",
|
||||||
|
"torch.set_default_device(DEVICE)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Get paths\n",
|
||||||
|
"VOCABULARY_PATH = Path(\"Assets/Model/small/bpe-small-16.json\")\n",
|
||||||
|
"TRAIN_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
|
||||||
|
"VALIDATION_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
|
||||||
|
"TEST_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
|
||||||
|
"CHECKPOINT_PATH = Path(\"Assets/Dataset/Tmp/NanoSocrates.zip\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# BPE Init\n",
|
||||||
|
"SPECIAL_VOC = BPE.default_special_tokens()\n",
|
||||||
|
"VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n",
|
||||||
|
"TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Constants\n",
|
||||||
|
"MASK_EXTRA_SPACE = 25\n",
|
||||||
|
"TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE\n",
|
||||||
|
"EMBEDDED_SIZE = 256\n",
|
||||||
|
"FEED_FORWARD_MULTIPLIER = 4\n",
|
||||||
|
"ATTENTION_HEADS = 8\n",
|
||||||
|
"SENTENCE_LENGTH = 256\n",
|
||||||
|
"NUMBER_OF_BLOCKS = 4\n",
|
||||||
|
"MAX_EPOCHS = int(1e3)\n",
|
||||||
|
"PRETRAIN_EPOCHS = int(2)\n",
|
||||||
|
"WARMUP_EPOCHS = int(4e3)\n",
|
||||||
|
"MINI_BATCH_SIZE = 10\n",
|
||||||
|
"VALIDATION_STEPS = 1\n",
|
||||||
|
"CHECKPOINT_STEPS = VALIDATION_STEPS * 4\n",
|
||||||
|
"PATIENCE = 4\n",
|
||||||
|
"CURRENT_EPOCH = 0\n",
|
||||||
|
"\n",
|
||||||
|
"SOS_TOKEN = TOKENANO.encode(\"<SOS>\")[0]\n",
|
||||||
|
"\n",
|
||||||
|
"PAD_TOKEN = TOKENANO.encode(\"<PAD>\")[0]\n",
|
||||||
|
"END_TOKEN = TOKENANO.encode(\"<END>\")[0]\n",
|
||||||
|
"SUBJ_TOKEN = TOKENANO.encode(\"<SUBJ>\")[0]\n",
|
||||||
|
"REL_TOKEN = TOKENANO.encode(\"<PRED>\")[0]\n",
|
||||||
|
"OBJ_TOKEN = TOKENANO.encode(\"<OBJ>\")[0]\n",
|
||||||
|
"\n",
|
||||||
|
"SPECIAL_TOKENS: set[int] = set(TOKENANO.encode(\"\".join(BPE.default_special_tokens())))\n",
|
||||||
|
"ALLOWED_TOKENS = set([SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN])\n",
|
||||||
|
"FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Spanned_Masker\n",
|
||||||
|
"MASKER = Transformer.SpannedMasker(\n",
|
||||||
|
" TOKEN_SPACE_SIZE,\n",
|
||||||
|
" FORBIDDEN_TOKENS\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"TRAIN_BATCHER = Batch.Batcher(\n",
|
||||||
|
" TRAIN_DATASET_PATH,\n",
|
||||||
|
" SENTENCE_LENGTH,\n",
|
||||||
|
" TOKENANO,\n",
|
||||||
|
" MASKER\n",
|
||||||
|
")\n",
|
||||||
|
"VALIDATION_BATCHER = Batch.Batcher(\n",
|
||||||
|
" VALIDATION_DATASET_PATH,\n",
|
||||||
|
" SENTENCE_LENGTH,\n",
|
||||||
|
" TOKENANO,\n",
|
||||||
|
" MASKER\n",
|
||||||
|
")\n",
|
||||||
|
"TEST_BATCHER = Batch.Batcher(\n",
|
||||||
|
" TEST_DATASET_PATH,\n",
|
||||||
|
" SENTENCE_LENGTH,\n",
|
||||||
|
" TOKENANO,\n",
|
||||||
|
" MASKER\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Model\n",
|
||||||
|
"NANOSOCRATES = Transformer.TrainingModel(\n",
|
||||||
|
" TOKEN_SPACE_SIZE,\n",
|
||||||
|
" EMBEDDED_SIZE,\n",
|
||||||
|
" FEED_FORWARD_MULTIPLIER,\n",
|
||||||
|
" ATTENTION_HEADS,\n",
|
||||||
|
" NUMBER_OF_BLOCKS\n",
|
||||||
|
")\n",
|
||||||
|
"_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(\n",
|
||||||
|
" NANOSOCRATES,\n",
|
||||||
|
" TOKEN_SPACE_SIZE,\n",
|
||||||
|
" EMBEDDED_SIZE\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Training constants\n",
|
||||||
|
"cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)\n",
|
||||||
|
"nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters())\n",
|
||||||
|
"encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters())\n",
|
||||||
|
"decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters())\n",
|
||||||
|
"\n",
|
||||||
|
"nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)\n",
|
||||||
|
"encoder_only_scheduler = Transformer.WarmupLR(encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)\n",
|
||||||
|
"decoder_only_scheduler = Transformer.WarmupLR(decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)\n",
|
||||||
|
"\n",
|
||||||
|
"current_epoch = CURRENT_EPOCH\n",
|
||||||
|
"patience = 0\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"average_loss_validation = {\n",
|
||||||
|
" \"txt\": float(\"inf\"),\n",
|
||||||
|
" \"encoder_only\": float(\"inf\"),\n",
|
||||||
|
" \"decoder_only\": float(\"inf\")\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"while current_epoch < MAX_EPOCHS:\n",
|
||||||
|
"\n",
|
||||||
|
" text_batch_losses = []\n",
|
||||||
|
" encoder_batch_losses = []\n",
|
||||||
|
" decoder_batch_losses = []\n",
|
||||||
|
"\n",
|
||||||
|
" for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):\n",
|
||||||
|
"\n",
|
||||||
|
" src_x, tgt_y, pad_x, pad_y, tasktype = batch\n",
|
||||||
|
"\n",
|
||||||
|
" enc_x = torch.tensor(src_x)\n",
|
||||||
|
" enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)\n",
|
||||||
|
" dec_x = Transformer.get_decoder_input(MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)\n",
|
||||||
|
" dec_x_pad = dec_x.eq(PAD_TOKEN)\n",
|
||||||
|
" tgt = torch.tensor(tgt_y)\n",
|
||||||
|
" tgt_pad = torch.tensor(pad_y, dtype=torch.bool)\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 1 and Task 2\n",
|
||||||
|
" if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS = []\n",
|
||||||
|
"\n",
|
||||||
|
" for token_idx in range(0, SENTENCE_LENGTH):\n",
|
||||||
|
"\n",
|
||||||
|
" nano_optim.zero_grad()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = NANOSOCRATES((\n",
|
||||||
|
" enc_x, enc_x_pad, dec_x, dec_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = pred_logits[:, token_idx, :]\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt[:, token_idx])\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" nano_optim.step()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if token_idx < SENTENCE_LENGTH - 1:\n",
|
||||||
|
" dec_x[:,token_idx + 1] = tgt[:, token_idx]\n",
|
||||||
|
"\n",
|
||||||
|
" MIN_BATCH_LOSS = min(BATCH_LOSS)\n",
|
||||||
|
" MAX_BATCH_LOSS = max(BATCH_LOSS)\n",
|
||||||
|
" AVG_BATCH_LOSS = sum(BATCH_LOSS) / MINI_BATCH_SIZE\n",
|
||||||
|
"\n",
|
||||||
|
" text_batch_losses.append([MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS])\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Pretrain first\n",
|
||||||
|
" if current_epoch < PRETRAIN_EPOCHS:\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 3\n",
|
||||||
|
" if tasktype == Batch.TaskType.MASKING:\n",
|
||||||
|
"\n",
|
||||||
|
" encoder_only_optim.zero_grad()\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = ENCODER_ONLY((\n",
|
||||||
|
" enc_x, enc_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt)\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" encoder_only_optim.step()\n",
|
||||||
|
"\n",
|
||||||
|
" encoder_batch_losses.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 4\n",
|
||||||
|
" if tasktype == Batch.TaskType.COMPLETATION:\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS = []\n",
|
||||||
|
"\n",
|
||||||
|
" for token_idx in range(0, SENTENCE_LENGTH):\n",
|
||||||
|
"\n",
|
||||||
|
" decoder_only_optim.zero_grad()\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = DECODER_ONLY((\n",
|
||||||
|
" enc_x, enc_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = pred_logits[:, token_idx, :]\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt[:, token_idx])\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" decoder_only_optim.step()\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if token_idx < SENTENCE_LENGTH - 1:\n",
|
||||||
|
" dec_x[:,token_idx + 1] = tgt[:, token_idx]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" MIN_BATCH_LOSS = min(BATCH_LOSS)\n",
|
||||||
|
" MAX_BATCH_LOSS = max(BATCH_LOSS)\n",
|
||||||
|
" AVG_BATCH_LOSS = sum(BATCH_LOSS) / MINI_BATCH_SIZE\n",
|
||||||
|
"\n",
|
||||||
|
" decoder_batch_losses.append([MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS])\n",
|
||||||
|
"\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" nano_scheduler.step()\n",
|
||||||
|
" encoder_only_scheduler.step()\n",
|
||||||
|
" decoder_only_scheduler.step()\n",
|
||||||
|
"\n",
|
||||||
|
" current_epoch += 1\n",
|
||||||
|
"\n",
|
||||||
|
" if current_epoch % VALIDATION_STEPS == 0:\n",
|
||||||
|
"\n",
|
||||||
|
" txt_avg_batch_losses = []\n",
|
||||||
|
" enc_avg_batch_losses = []\n",
|
||||||
|
" dec_avg_batch_losses = []\n",
|
||||||
|
"\n",
|
||||||
|
" for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):\n",
|
||||||
|
"\n",
|
||||||
|
" src_x, tgt_y, pad_x, pad_y, tasktype = batch\n",
|
||||||
|
"\n",
|
||||||
|
" enc_x = torch.tensor(src_x)\n",
|
||||||
|
" enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)\n",
|
||||||
|
" dec_x = Transformer.get_decoder_input(MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH)\n",
|
||||||
|
" dec_x_pad = dec_x.eq(PAD_TOKEN)\n",
|
||||||
|
" tgt = torch.tensor(tgt_y)\n",
|
||||||
|
" tgt_pad = torch.tensor(pad_y, dtype=torch.bool)\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 1 and Task 2\n",
|
||||||
|
" if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS = []\n",
|
||||||
|
"\n",
|
||||||
|
" for token_idx in range(0, SENTENCE_LENGTH):\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = NANOSOCRATES((\n",
|
||||||
|
" enc_x, enc_x_pad, dec_x, dec_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = pred_logits[:, token_idx, :]\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt[:, token_idx])\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if token_idx < SENTENCE_LENGTH - 1:\n",
|
||||||
|
" dec_x[:,token_idx + 1] = tgt[:, token_idx]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" AVG_BATCH_LOSS = sum(BATCH_LOSS) / MINI_BATCH_SIZE\n",
|
||||||
|
" txt_avg_batch_losses.append(AVG_BATCH_LOSS)\n",
|
||||||
|
"\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Pretrain first\n",
|
||||||
|
" if current_epoch < PRETRAIN_EPOCHS:\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 3\n",
|
||||||
|
" if tasktype == Batch.TaskType.MASKING:\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = ENCODER_ONLY((\n",
|
||||||
|
" enc_x, enc_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt)\n",
|
||||||
|
"\n",
|
||||||
|
" enc_avg_batch_losses.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Task 4\n",
|
||||||
|
" if tasktype == Batch.TaskType.COMPLETATION:\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS = []\n",
|
||||||
|
"\n",
|
||||||
|
" for token_idx in range(0, SENTENCE_LENGTH):\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = DECODER_ONLY((\n",
|
||||||
|
" enc_x, enc_x_pad\n",
|
||||||
|
" ))\n",
|
||||||
|
"\n",
|
||||||
|
" pred_logits = pred_logits[:, token_idx, :]\n",
|
||||||
|
"\n",
|
||||||
|
" loss: torch.Tensor= cross_entropy(pred_logits, tgt[:, token_idx])\n",
|
||||||
|
"\n",
|
||||||
|
" BATCH_LOSS.append(\n",
|
||||||
|
" loss.item()\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if token_idx < SENTENCE_LENGTH - 1:\n",
|
||||||
|
" dec_x[:,token_idx + 1] = tgt[:, token_idx]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" AVG_BATCH_LOSS = sum(BATCH_LOSS) / MINI_BATCH_SIZE\n",
|
||||||
|
"\n",
|
||||||
|
" dec_avg_batch_losses.append(AVG_BATCH_LOSS)\n",
|
||||||
|
"\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)\n",
|
||||||
|
" enc_avg_loss = float(\"inf\")\n",
|
||||||
|
" dec_avg_loss = float(\"inf\")\n",
|
||||||
|
"\n",
|
||||||
|
" if current_epoch >= PRETRAIN_EPOCHS:\n",
|
||||||
|
" enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)\n",
|
||||||
|
" dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)\n",
|
||||||
|
"\n",
|
||||||
"        # Early-stopping bookkeeping: every task is compared against its own best\n",
"        # validation loss, and the stored best is refreshed whenever it improves.\n",
"        if current_epoch < PRETRAIN_EPOCHS:\n",
"\n",
"            # Only the text tasks are validated while pretraining\n",
"            if txt_avg_loss < average_loss_validation[\"txt\"]:\n",
"                average_loss_validation[\"txt\"] = txt_avg_loss\n",
"            else:\n",
"                patience += 1\n",
"        else:\n",
"\n",
"            current_validation_losses = {\n",
"                \"txt\": txt_avg_loss,\n",
"                \"encoder_only\": enc_avg_loss,\n",
"                \"decoder_only\": dec_avg_loss,\n",
"            }\n",
"\n",
"            # Number of tasks whose validation loss got worse than their best\n",
"            counter = 0\n",
"\n",
"            for task_name, task_loss in current_validation_losses.items():\n",
"                if task_loss > average_loss_validation[task_name]:\n",
"                    counter += 1\n",
"                elif task_loss < average_loss_validation[task_name]:\n",
"                    average_loss_validation[task_name] = task_loss\n",
"\n",
"            # Lose patience only when at least two of the three tasks regressed\n",
"            if counter > 1:\n",
"                patience += 1\n",
|
||||||
|
"\n",
|
||||||
"        # Training statistics for this epoch.\n",
"        # text_batch_losses and decoder_batch_losses hold one [min, avg, max] triple\n",
"        # per batch, so every statistic is gathered column-wise: `losses[:][i]`\n",
"        # returns the i-th batch instead and raises IndexError with fewer than\n",
"        # three batches.\n",
"        # Training values use their own names so the validation averages computed\n",
"        # above (txt_avg_loss, enc_avg_loss, dec_avg_loss) are not overwritten.\n",
"        txt_min_loss = float(\"inf\")\n",
"        txt_avg_min_loss = float(\"inf\")\n",
"        txt_max_loss = float(\"inf\")\n",
"        txt_avg_max_loss = float(\"inf\")\n",
"        txt_avg_train_loss = float(\"inf\")\n",
"\n",
"        if text_batch_losses:\n",
"            txt_min_train_losses = [triple[0] for triple in text_batch_losses]\n",
"            txt_avg_train_losses = [triple[1] for triple in text_batch_losses]\n",
"            txt_max_train_losses = [triple[2] for triple in text_batch_losses]\n",
"\n",
"            txt_min_loss = min(txt_min_train_losses)\n",
"            txt_avg_min_loss = sum(txt_min_train_losses) / len(txt_min_train_losses)\n",
"            txt_max_loss = max(txt_max_train_losses)\n",
"            txt_avg_max_loss = sum(txt_max_train_losses) / len(txt_max_train_losses)\n",
"            txt_avg_train_loss = sum(txt_avg_train_losses) / len(txt_avg_train_losses)\n",
"\n",
"        enc_avg_train_loss = float(\"inf\")\n",
"\n",
"        dec_min_loss = float(\"inf\")\n",
"        dec_avg_min_loss = float(\"inf\")\n",
"        dec_max_loss = float(\"inf\")\n",
"        dec_avg_max_loss = float(\"inf\")\n",
"        dec_avg_train_loss = float(\"inf\")\n",
"\n",
"        # The lists are checked for content instead of testing the epoch number:\n",
"        # current_epoch was already incremented, so on the first epoch after\n",
"        # pretraining they are still empty and dividing by their length would fail.\n",
"        if encoder_batch_losses:\n",
"            enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)\n",
"\n",
"        if decoder_batch_losses:\n",
"            dec_min_train_losses = [triple[0] for triple in decoder_batch_losses]\n",
"            dec_avg_train_losses = [triple[1] for triple in decoder_batch_losses]\n",
"            dec_max_train_losses = [triple[2] for triple in decoder_batch_losses]\n",
"\n",
"            dec_min_loss = min(dec_min_train_losses)\n",
"            dec_avg_min_loss = sum(dec_min_train_losses) / len(dec_min_train_losses)\n",
"            dec_max_loss = max(dec_max_train_losses)\n",
"            dec_avg_max_loss = sum(dec_max_train_losses) / len(dec_max_train_losses)\n",
"            dec_avg_train_loss = sum(dec_avg_train_losses) / len(dec_avg_train_losses)\n",
"\n",
"\n",
"        SEPARATOR = \"===========================================================================================\"\n",
"        DEBUG_TEXT = \"\".join([\n",
"            f\"{SEPARATOR}\\n\",\n",
"            f\"EPOCH {current_epoch}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"            \"Train Losses:\\n\",\n",
"            \"\\tMin Losses:\\n\",\n",
"            f\"\\t\\tmin_txt: {txt_min_loss} - avg_txt: {txt_avg_min_loss}\\n\",\n",
"            f\"\\t\\tmin_dec: {dec_min_loss} - avg_dec: {dec_avg_min_loss}\\n\",\n",
"            \"\\tMax Losses:\\n\",\n",
"            f\"\\t\\tmax_txt: {txt_max_loss} - avg_txt: {txt_avg_max_loss}\\n\",\n",
"            f\"\\t\\tmax_dec: {dec_max_loss} - avg_dec: {dec_avg_max_loss}\\n\",\n",
"            \"\\tAvg Losses:\\n\",\n",
"            f\"\\t\\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"            \"Validation Losses:\\n\",\n",
"            f\"\\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\\n\",\n",
"            f\"{SEPARATOR}\\n\",\n",
"        ])\n",
"\n",
"        # The report was previously assembled and then discarded\n",
"        print(DEBUG_TEXT)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"        # Early stopping: the validation loss failed to improve too many times\n",
"        out_of_patience = patience >= PATIENCE\n",
"\n",
"        if out_of_patience:\n",
"            print(\n",
"                \"Model is likely overfitting, so let's stop here\"\n",
"            )\n",
"\n",
"        # SAVE MODEL\n",
"        if current_epoch % CHECKPOINT_STEPS == 0 or out_of_patience:\n",
"            # torch.save does not create missing parent directories\n",
"            CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)\n",
"            print(f\"Saving model at {CHECKPOINT_PATH.as_posix()}\")\n",
"            torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)\n",
"\n",
"        # The message above announces a stop, so leave the training loop for real\n",
"        # (only after the final checkpoint has been written)\n",
"        if out_of_patience:\n",
"            break\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "deep_learning",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
410
Playgrounds/nanosocrates-train.py
Normal file
410
Playgrounds/nanosocrates-train.py
Normal file
@ -0,0 +1,410 @@
|
|||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
import Project_Model.Libs.Embedder as Embedder
|
||||||
|
import Project_Model.Libs.BPE as BPE
|
||||||
|
import Project_Model.Libs.Transformer as Transformer
|
||||||
|
import Project_Model.Libs.TransformerUtils as TUtils
|
||||||
|
import Project_Model.Libs.TorchShims as torch_shims
|
||||||
|
import Project_Model.Libs.Batch as Batch
|
||||||
|
|
||||||
|
# Seed torch and the stdlib RNG so runs are repeatable
torch.manual_seed(0)
random.seed(0)


# Make the device picked by the project shim the default for new tensors
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)


# Locations of the vocabulary, the three dataset splits and the checkpoint
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")


# Tokenizer: learned vocabulary plus the project's special tokens
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)


def _first_token_id(text: str) -> int:
    """Encode *text* with TOKENANO and return the first id it produces."""
    return TOKENANO.encode(text)[0]


# Vocabulary sizes: the model gets MASK_EXTRA_SPACE ids on top of the real
# vocabulary, while the masker below only receives the real size
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE

# Model hyper-parameters
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4

# Training schedule
# NOTE(review): WARMUP_EPOCHS is larger than MAX_EPOCHS while the schedulers are
# stepped once per epoch - confirm the warm-up is meant to outlast the training.
MAX_EPOCHS = 1_000
PRETRAIN_EPOCHS = 10
WARMUP_EPOCHS = 4_000
MINI_BATCH_SIZE = 100
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0

# Ids of the special tokens referenced by the training code
SOS_TOKEN = _first_token_id("<SOS>")
PAD_TOKEN = _first_token_id("<PAD>")
END_TOKEN = _first_token_id("<END>")
SUBJ_TOKEN = _first_token_id("<SUBJ>")
REL_TOKEN = _first_token_id("<PRED>")
OBJ_TOKEN = _first_token_id("<OBJ>")

# Every special token except the three triple markers is handed to the masker
# as forbidden
SPECIAL_TOKENS: set[int] = set(
    TOKENANO.encode("".join(BPE.default_special_tokens()))
)
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS


# Spanned_Masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

# One batcher per dataset split, all sharing the tokenizer and the masker
TRAIN_BATCHER = Batch.Batcher(
    TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(
    TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)


# Full model plus the encoder-only / decoder-only models derived from it
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)


# One loss, optimizer and scheduler per model; padding targets are ignored
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters())
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters())
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters())

nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)

# Mutable training state
current_epoch = CURRENT_EPOCH
patience = 0

# Best validation loss seen so far for each task
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}
|
||||||
|
|
||||||
|
while current_epoch < MAX_EPOCHS:
|
||||||
|
|
||||||
|
NANOSOCRATES.train()
|
||||||
|
ENCODER_ONLY.train()
|
||||||
|
DECODER_ONLY.train()
|
||||||
|
|
||||||
|
text_batch_losses = []
|
||||||
|
encoder_batch_losses = []
|
||||||
|
decoder_batch_losses = []
|
||||||
|
|
||||||
|
# Training pass: each mini-batch is routed to the model matching its task type
for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

    src_x, tgt_y, pad_x, pad_y, tasktype = batch

    enc_x = torch.tensor(src_x)
    enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
    # NOTE(review): the decoder input is always built for MINI_BATCH_SIZE rows;
    # confirm the batcher never yields a shorter final batch.
    dec_x = Transformer.get_decoder_input(
        MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
    )
    # NOTE(review): this mask is computed once and not refreshed while dec_x is
    # filled in below - confirm the model does not need an up-to-date mask.
    dec_x_pad = dec_x.eq(PAD_TOKEN)
    tgt = torch.tensor(tgt_y)
    tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

    # Task 1 and Task 2
    if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

        BATCH_LOSS = []

        # One optimizer step per target position; the expected token is then
        # copied into the decoder input for the next position
        # NOTE(review): positions whose targets are all PAD presumably yield a
        # NaN loss (everything ignored) - verify before trusting the statistics.
        for token_idx in range(0, SENTENCE_LENGTH):

            nano_optim.zero_grad()

            pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            pred_logits = pred_logits[:, token_idx, :]

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt[:, token_idx])

            loss.backward()
            nano_optim.step()

            BATCH_LOSS.append(loss.item())

            if token_idx < SENTENCE_LENGTH - 1:
                dec_x[:, token_idx + 1] = tgt[:, token_idx]

        MIN_BATCH_LOSS = min(BATCH_LOSS)
        MAX_BATCH_LOSS = max(BATCH_LOSS)
        # BATCH_LOSS holds one value per token position, so the mean divides by
        # its length (dividing by MINI_BATCH_SIZE was not an average)
        AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

        text_batch_losses.append([MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS])
        continue

    # Pretrain first
    if current_epoch < PRETRAIN_EPOCHS:
        continue

    # Task 3
    if tasktype == Batch.TaskType.MASKING:

        encoder_only_optim.zero_grad()

        pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
        # CrossEntropyLoss expects the class dimension right after the batch one
        pred_logits = pred_logits.permute(0, 2, 1)

        loss: torch.Tensor = encoder_ce(pred_logits, tgt)

        loss.backward()
        encoder_only_optim.step()

        encoder_batch_losses.append(loss.item())

        continue

    # Task 4
    if tasktype == Batch.TaskType.COMPLETATION:

        BATCH_LOSS = []

        for token_idx in range(0, SENTENCE_LENGTH):

            decoder_only_optim.zero_grad()

            # NOTE(review): the model is fed enc_x only, so the dec_x update at
            # the end of this iteration does not reach it - confirm intent.
            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits[:, token_idx, :]

            loss: torch.Tensor = decoder_ce(pred_logits, tgt[:, token_idx])

            loss.backward()
            decoder_only_optim.step()

            BATCH_LOSS.append(loss.item())

            if token_idx < SENTENCE_LENGTH - 1:
                dec_x[:, token_idx + 1] = tgt[:, token_idx]

        MIN_BATCH_LOSS = min(BATCH_LOSS)
        MAX_BATCH_LOSS = max(BATCH_LOSS)
        # Mean over the per-position losses actually collected
        AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

        decoder_batch_losses.append(
            [MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS]
        )

        continue
|
||||||
|
|
||||||
|
nano_scheduler.step()
|
||||||
|
encoder_only_scheduler.step()
|
||||||
|
decoder_only_scheduler.step()
|
||||||
|
|
||||||
|
current_epoch += 1
|
||||||
|
|
||||||
|
if current_epoch % VALIDATION_STEPS == 0:
|
||||||
|
|
||||||
|
NANOSOCRATES.eval()
|
||||||
|
ENCODER_ONLY.eval()
|
||||||
|
DECODER_ONLY.eval()
|
||||||
|
|
||||||
|
txt_avg_batch_losses = []
|
||||||
|
enc_avg_batch_losses = []
|
||||||
|
dec_avg_batch_losses = []
|
||||||
|
|
||||||
|
# Validation pass: same routing as training, but nothing is optimized, so
# gradient tracking is switched off to avoid building unused autograd graphs
with torch.no_grad():

    for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        # NOTE(review): built for MINI_BATCH_SIZE rows regardless of the batch
        # actually received - confirm the batcher never yields a shorter one.
        dec_x = Transformer.get_decoder_input(
            MINI_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x_pad = dec_x.eq(PAD_TOKEN)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        # Task 1 and Task 2
        if (
            tasktype == Batch.TaskType.RDF2TXT
            or tasktype == Batch.TaskType.TEXT2RDF
        ):

            BATCH_LOSS = []

            # Score every target position, feeding the expected token forward
            for token_idx in range(0, SENTENCE_LENGTH):

                pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                pred_logits = pred_logits[:, token_idx, :]

                loss: torch.Tensor = nano_cross_entropy(
                    pred_logits, tgt[:, token_idx]
                )

                BATCH_LOSS.append(loss.item())

                if token_idx < SENTENCE_LENGTH - 1:
                    dec_x[:, token_idx + 1] = tgt[:, token_idx]

            # One value per token position was collected, so divide by that
            # count (dividing by MINI_BATCH_SIZE was not an average)
            AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)
            txt_avg_batch_losses.append(AVG_BATCH_LOSS)

            continue

        # Pretrain first
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3
        if tasktype == Batch.TaskType.MASKING:

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            # CrossEntropyLoss expects the class dimension after the batch one
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = encoder_ce(pred_logits, tgt)

            enc_avg_batch_losses.append(loss.item())

            continue

        # Task 4
        if tasktype == Batch.TaskType.COMPLETATION:

            BATCH_LOSS = []

            for token_idx in range(0, SENTENCE_LENGTH):

                pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
                pred_logits = pred_logits[:, token_idx, :]

                loss: torch.Tensor = decoder_ce(pred_logits, tgt[:, token_idx])

                BATCH_LOSS.append(loss.item())

                if token_idx < SENTENCE_LENGTH - 1:
                    dec_x[:, token_idx + 1] = tgt[:, token_idx]

            # Mean over the per-position losses actually collected
            AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

            dec_avg_batch_losses.append(AVG_BATCH_LOSS)

            continue
|
||||||
|
|
||||||
|
# Mean validation loss per task. The encoder-only and decoder-only tasks are
# not evaluated during pre-training, so they default to +inf until then.
txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
enc_avg_loss = float("inf")
dec_avg_loss = float("inf")

if current_epoch >= PRETRAIN_EPOCHS:
    enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
    dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

if current_epoch < PRETRAIN_EPOCHS:

    # Pre-training: only the text tasks drive the patience counter.
    if txt_avg_loss < average_loss_validation["txt"]:
        average_loss_validation["txt"] = txt_avg_loss
    else:
        patience += 1

else:

    # Count how many tasks got worse than their own best validation loss.
    # Each task is compared with its own loss (the text loss used to be
    # compared against all three bests).
    counter = 0

    for task_key, task_loss in (
        ("txt", txt_avg_loss),
        ("encoder_only", enc_avg_loss),
        ("decoder_only", dec_avg_loss),
    ):
        if task_loss > average_loss_validation[task_key]:
            counter += 1
        elif task_loss < average_loss_validation[task_key]:
            # Remember the improvement, mirroring the pre-training branch;
            # without this the stored bests would never move.
            average_loss_validation[task_key] = task_loss

    # Lose patience only when more than one task regressed.
    if counter > 1:
        patience += 1
|
||||||
|
|
||||||
|
# Each row of text_batch_losses is [min, avg, max] for one training batch.
txt_min_train_losses = [row[0] for row in text_batch_losses]
txt_avg_train_losses = [row[1] for row in text_batch_losses]
txt_max_train_losses = [row[2] for row in text_batch_losses]

txt_min_loss = min(txt_min_train_losses)
txt_avg_min_loss = sum(txt_min_train_losses) / len(txt_min_train_losses)
txt_max_loss = max(txt_max_train_losses)
txt_avg_max_loss = sum(txt_max_train_losses) / len(txt_max_train_losses)
# Kept under its own name so the validation value in txt_avg_loss is not
# overwritten before it is reported below.
txt_avg_train_loss = sum(txt_avg_train_losses) / len(txt_avg_train_losses)

# Encoder-only / decoder-only statistics stay at +inf during pre-training.
enc_avg_train_loss = float("inf")

dec_min_loss = float("inf")
dec_avg_min_loss = float("inf")
dec_max_loss = float("inf")
dec_avg_max_loss = float("inf")
dec_avg_train_loss = float("inf")

if current_epoch >= PRETRAIN_EPOCHS:
    enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)

    dec_min_train_losses = [row[0] for row in decoder_batch_losses]
    dec_avg_train_losses = [row[1] for row in decoder_batch_losses]
    dec_max_train_losses = [row[2] for row in decoder_batch_losses]

    dec_min_loss = min(dec_min_train_losses)
    dec_avg_min_loss = sum(dec_min_train_losses) / len(dec_min_train_losses)
    dec_max_loss = max(dec_max_train_losses)
    dec_avg_max_loss = sum(dec_max_train_losses) / len(dec_max_train_losses)
    dec_avg_train_loss = sum(dec_avg_train_losses) / len(dec_avg_train_losses)

SEPARATOR = "================================================================================================================"

# Training figures come from the aggregates above; the validation line uses
# txt_avg_loss / enc_avg_loss / dec_avg_loss as computed by the validation
# pass earlier in the script.
DEBUG_TEXT = "".join(
    [
        f"{SEPARATOR}\n",
        f"EPOCH {current_epoch}\n",
        f"{SEPARATOR}\n",
        "Train Losses:\n",
        "\tMin Losses:\n",
        f"\t\tmin_txt: {txt_min_loss} - avg_txt: {txt_avg_min_loss}\n",
        f"\t\tmin_dec: {dec_min_loss} - avg_dec: {dec_avg_min_loss}\n",
        "\tMax Losses:\n",
        f"\t\tmax_txt: {txt_max_loss} - avg_txt: {txt_avg_max_loss}\n",
        f"\t\tmax_dec: {dec_max_loss} - avg_dec: {dec_avg_max_loss}\n",
        "\tAvg Losses:\n",
        f"\t\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
        f"{SEPARATOR}\n",
        "Validation Losses:\n",
        f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
        f"{SEPARATOR}\n",
    ]
)
|
||||||
|
|
||||||
|
print(DEBUG_TEXT)
|
||||||
|
|
||||||
|
# Warn about patience
|
||||||
|
if patience == PATIENCE:
|
||||||
|
print("Model is likely overfitting, so let's stop here")
|
||||||
|
|
||||||
|
# SAVE MODEL
|
||||||
|
if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
|
||||||
|
print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
|
||||||
|
torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)
|
||||||
@ -3,18 +3,32 @@ import sys
|
|||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from Project_Model.Libs.Batch.Enums.TaskType import TaskType
|
from ..Enums import TaskType
|
||||||
import Project_Model.Libs.BPE as BPE
|
import Project_Model.Libs.BPE as BPE
|
||||||
|
|
||||||
# from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
|
# from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
|
||||||
from Project_Model.Libs.Transformer import SpannedMasker, truncate_rdf_list, normalize_sequence
|
from Project_Model.Libs.Transformer import (
|
||||||
from TokenCompletation import TokenCompletationTransformer
|
SpannedMasker,
|
||||||
|
truncate_rdf_list,
|
||||||
|
normalize_sequence,
|
||||||
|
)
|
||||||
|
|
||||||
from Project_Model.Libs.BPE import SpecialToken
|
from Project_Model.Libs.BPE import SpecialToken
|
||||||
|
|
||||||
|
|
||||||
MAX_LENGHT = 128
|
|
||||||
|
|
||||||
|
|
||||||
class Batcher:
|
class Batcher:
|
||||||
|
|
||||||
def __init__(self, dataset_path: Path, tokenizer: BPE.TokeNanoCore, masker: SpannedMasker, seed:int = 0) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_path: Path,
|
||||||
|
max_length: int,
|
||||||
|
tokenizer: BPE.TokeNanoCore,
|
||||||
|
masker: SpannedMasker,
|
||||||
|
seed: int = 0,
|
||||||
|
) -> None:
|
||||||
# ABSTRACT, TRIPLE
|
# ABSTRACT, TRIPLE
|
||||||
# tasks:
|
# tasks:
|
||||||
# rdf2text: X: TRIPLE, Y: ABSTRACT
|
# rdf2text: X: TRIPLE, Y: ABSTRACT
|
||||||
@ -26,15 +40,22 @@ class Batcher:
|
|||||||
self._dataset_path = dataset_path
|
self._dataset_path = dataset_path
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self._masker = masker
|
self._masker = masker
|
||||||
|
self.__max_length = max_length
|
||||||
self._seed = seed
|
self._seed = seed
|
||||||
# self._token_completation = TokenCompletationTransformer(sotl,eos)
|
# self._token_completation = TokenCompletationTransformer(sotl,eos)
|
||||||
self._completation_task_token_truncator = truncate_rdf_list
|
self._completation_task_token_truncator = truncate_rdf_list
|
||||||
|
|
||||||
|
def batch(self, batch_size) -> Generator[
|
||||||
|
tuple[
|
||||||
|
list[list[int]],
|
||||||
def batch(self, batch_size)-> Generator[tuple[list[list[int]], list[list[int]], list[list[int]],list[list[int]], TaskType],Any,Any]:
|
list[list[int]],
|
||||||
|
list[list[int]],
|
||||||
|
list[list[int]],
|
||||||
|
TaskType
|
||||||
|
],
|
||||||
|
Any,
|
||||||
|
Any,
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Yields: X,Y,padding_X
|
Yields: X,Y,padding_X
|
||||||
"""
|
"""
|
||||||
@ -45,18 +66,34 @@ class Batcher:
|
|||||||
|
|
||||||
tokenized_batch = pd.DataFrame()
|
tokenized_batch = pd.DataFrame()
|
||||||
# encode
|
# encode
|
||||||
tokenized_batch[["Abstract","RDFs"]] = (
|
tokenized_batch[["Abstract", "RDFs"]] = batch[["Abstract", "RDFs"]].map(
|
||||||
batch[["Abstract","RDFs"]]
|
lambda t: self._tokenizer.encode(t)
|
||||||
.map(lambda t: self._tokenizer.encode(t))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
X, Y, padding_X, padding_Y = self.__rdf2txt_transformation(tokenized_batch)
|
X, Y, padding_X, padding_Y = self.__rdf2txt_transformation(tokenized_batch)
|
||||||
yield X, Y, padding_X, padding_Y, TaskType.RDF2TXT
|
yield X, Y, padding_X, padding_Y, TaskType.RDF2TXT
|
||||||
X,Y, padding_X, padding_Y, = self.__txt2rdf_transformation(tokenized_batch)
|
(
|
||||||
|
X,
|
||||||
|
Y,
|
||||||
|
padding_X,
|
||||||
|
padding_Y,
|
||||||
|
) = self.__txt2rdf_transformation(tokenized_batch)
|
||||||
yield X, Y, padding_X, padding_Y, TaskType.TEXT2RDF
|
yield X, Y, padding_X, padding_Y, TaskType.TEXT2RDF
|
||||||
X,Y, padding_X, padding_Y, = self.__masking_trasformation(tokenized_batch)
|
(
|
||||||
|
X,
|
||||||
|
Y,
|
||||||
|
padding_X,
|
||||||
|
padding_Y,
|
||||||
|
) = self.__masking_trasformation(tokenized_batch)
|
||||||
yield X, Y, padding_X, padding_Y, TaskType.MASKING
|
yield X, Y, padding_X, padding_Y, TaskType.MASKING
|
||||||
X,Y, padding_X, padding_Y, = self.__token_completation_task(tokenized_batch, RNG.randint(0,sys.maxsize))
|
(
|
||||||
|
X,
|
||||||
|
Y,
|
||||||
|
padding_X,
|
||||||
|
padding_Y,
|
||||||
|
) = self.__token_completation_task(
|
||||||
|
tokenized_batch, RNG.randint(0, sys.maxsize)
|
||||||
|
)
|
||||||
yield X, Y, padding_X, padding_Y, TaskType.COMPLETATION
|
yield X, Y, padding_X, padding_Y, TaskType.COMPLETATION
|
||||||
|
|
||||||
# output = pd.concat([rdf2txt_batch,txt2rdf_batch,completation_batch],ignore_index=True)
|
# output = pd.concat([rdf2txt_batch,txt2rdf_batch,completation_batch],ignore_index=True)
|
||||||
@ -64,7 +101,6 @@ class Batcher:
|
|||||||
# self.decode_debug(output)
|
# self.decode_debug(output)
|
||||||
# yield output
|
# yield output
|
||||||
|
|
||||||
|
|
||||||
def __random_subset_rdfs(self, batch: pd.DataFrame, seed=0):
|
def __random_subset_rdfs(self, batch: pd.DataFrame, seed=0):
|
||||||
# WIP
|
# WIP
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
@ -72,20 +108,16 @@ class Batcher:
|
|||||||
def to_list(x):
|
def to_list(x):
|
||||||
return x.split(SpecialToken.START_TRIPLE.value)[1:]
|
return x.split(SpecialToken.START_TRIPLE.value)[1:]
|
||||||
|
|
||||||
batch["RDFs"] = batch["RDFs"].map(
|
batch["RDFs"] = batch["RDFs"].map(to_list)
|
||||||
to_list
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode_debug(self, batch: pd.DataFrame):
|
def decode_debug(self, batch: pd.DataFrame):
|
||||||
decoded = pd.DataFrame()
|
decoded = pd.DataFrame()
|
||||||
decoded[["X","Y"]] = (
|
decoded[["X", "Y"]] = batch[["X", "Y"]].map(lambda t: self._tokenizer.decode(t))
|
||||||
batch[["X","Y"]]
|
|
||||||
.map(lambda t: self._tokenizer.decode(t))
|
|
||||||
)
|
|
||||||
print(decoded)
|
print(decoded)
|
||||||
|
|
||||||
|
def __normalization(
|
||||||
def __normalization(self, X:list[list[int]], Y: list[list[int]])-> tuple[list[list[int]], list[list[int]], list[list[int]], list[list[int]]]:
|
self, X: list[list[int]], Y: list[list[int]]
|
||||||
|
) -> tuple[list[list[int]], list[list[int]], list[list[int]], list[list[int]]]:
|
||||||
pad_token = self._tokenizer.encode(SpecialToken.PAD.value)[0]
|
pad_token = self._tokenizer.encode(SpecialToken.PAD.value)[0]
|
||||||
end_token = self._tokenizer.encode(SpecialToken.END_OF_SEQUENCE.value)[0]
|
end_token = self._tokenizer.encode(SpecialToken.END_OF_SEQUENCE.value)[0]
|
||||||
out_X = []
|
out_X = []
|
||||||
@ -94,32 +126,33 @@ class Batcher:
|
|||||||
padding_Y = []
|
padding_Y = []
|
||||||
|
|
||||||
for x in X:
|
for x in X:
|
||||||
out_x, padding_x = normalize_sequence(x,MAX_LENGHT,pad_token,end_token,True)
|
out_x, padding_x = normalize_sequence(
|
||||||
|
x, self.__max_length, pad_token, end_token, True
|
||||||
|
)
|
||||||
out_X.append(out_x)
|
out_X.append(out_x)
|
||||||
padding_X.append(padding_x)
|
padding_X.append(padding_x)
|
||||||
|
|
||||||
for y in Y:
|
for y in Y:
|
||||||
out_y, padding_y = normalize_sequence(y,MAX_LENGHT,pad_token,end_token,True)
|
out_y, padding_y = normalize_sequence(
|
||||||
|
y, self.__max_length, pad_token, end_token, True
|
||||||
|
)
|
||||||
out_Y.append(out_y)
|
out_Y.append(out_y)
|
||||||
padding_Y.append(padding_y)
|
padding_Y.append(padding_y)
|
||||||
|
|
||||||
return out_X, out_Y, padding_X, padding_Y
|
return out_X, out_Y, padding_X, padding_Y
|
||||||
|
|
||||||
|
|
||||||
def __rdf2txt_transformation(self, batch: pd.DataFrame):
    """Build the RDF-to-text task from a tokenized batch.

    The encoded RDF_TO_TEXT task token is prepended to every entry of the
    ``RDFs`` column to form the inputs; the ``Abstract`` column provides the
    targets. Both lists are then handed to ``__normalization``, whose result
    is returned unchanged.
    """
    task_prefix = self._tokenizer.encode(SpecialToken.RDF_TO_TEXT.value)

    sources = [task_prefix + rdf_tokens for rdf_tokens in batch["RDFs"]]
    targets = batch["Abstract"].to_list()

    return self.__normalization(sources, targets)
|
||||||
|
|
||||||
|
|
||||||
def __txt2rdf_transformation(self, batch: pd.DataFrame):
    """Build the text-to-RDF task from a tokenized batch.

    The encoded TEXT_TO_RDF task token is prepended to every entry of the
    ``Abstract`` column to form the inputs; the ``RDFs`` column provides the
    targets. Both lists are then handed to ``__normalization``, whose result
    is returned unchanged.
    """
    task_prefix = self._tokenizer.encode(SpecialToken.TEXT_TO_RDF.value)

    sources = [task_prefix + text_tokens for text_tokens in batch["Abstract"]]
    targets = batch["RDFs"].to_list()

    return self.__normalization(sources, targets)
|
||||||
|
|
||||||
|
|
||||||
def __masking_trasformation(self, batch: pd.DataFrame):
|
def __masking_trasformation(self, batch: pd.DataFrame):
|
||||||
X = []
|
X = []
|
||||||
Y = []
|
Y = []
|
||||||
@ -129,27 +162,29 @@ class Batcher:
|
|||||||
Y.append(y)
|
Y.append(y)
|
||||||
return self.__normalization(X, Y)
|
return self.__normalization(X, Y)
|
||||||
|
|
||||||
|
|
||||||
def __token_completation_task(self, batch: pd.DataFrame, minibatch_seed: int):
    """Build the RDF completion task from a tokenized batch.

    Every entry of the ``RDFs`` column is passed to the configured truncator
    together with the fixed ratio 0.5, the first token id of the encoded
    CONTINUE_RDF and END_TRIPLE special tokens, and ``minibatch_seed``. The
    truncator's (x, y) pairs are collected and handed to ``__normalization``,
    whose result is returned unchanged.
    """
    encode = self._tokenizer.encode
    continue_id = encode(SpecialToken.CONTINUE_RDF.value)[0]
    end_triple_id = encode(SpecialToken.END_TRIPLE.value)[0]

    # The same seed is used for every row of the mini-batch.
    pairs = [
        self._completation_task_token_truncator(
            rdf, 0.5, continue_id, end_triple_id, minibatch_seed
        )
        for rdf in batch["RDFs"]
    ]

    sources = [src for src, _ in pairs]
    targets = [tgt for _, tgt in pairs]

    return self.__normalization(sources, targets)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
DATASET_PATH = Path("Assets/Dataset/Tmp/rdf_text.csv")
|
DATASET_PATH = Path("Assets/Dataset/Tmp/rdf_text.csv")
|
||||||
VOCABULARY_path = "Assets/Dataset/Tmp/trimmed.json"
|
VOCABULARY_path = "Assets/Dataset/Tmp/trimmed.json"
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
VOCABULARY = BPE.load_nanos_vocabulary(Path(VOCABULARY_path))
|
VOCABULARY = BPE.load_nanos_vocabulary(Path(VOCABULARY_path))
|
||||||
SPECIAL_LIST = BPE.default_special_tokens()
|
SPECIAL_LIST = BPE.default_special_tokens()
|
||||||
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
|
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_LIST)
|
||||||
|
|||||||
2
Project_Model/Libs/Batch/Classes/__init__.py
Normal file
2
Project_Model/Libs/Batch/Classes/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .Batcher import Batcher
|
||||||
|
from .TokenCompletation import TokenCompletationTransformer
|
||||||
5
Project_Model/Libs/Batch/Enums/__init__.py
Normal file
5
Project_Model/Libs/Batch/Enums/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from .TaskType import TaskType
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TaskType"
|
||||||
|
]
|
||||||
5
Project_Model/Libs/Batch/__init__.py
Normal file
5
Project_Model/Libs/Batch/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from .Classes import *
|
||||||
|
from .Enums import *
|
||||||
|
|
||||||
|
from . import Classes
|
||||||
|
from . import Enums
|
||||||
@ -10,7 +10,7 @@ class SpannedMasker:
|
|||||||
max_vocabulary: int,
|
max_vocabulary: int,
|
||||||
forbidden_tokens: set[int],
|
forbidden_tokens: set[int],
|
||||||
change_token_probability: float = 0.15,
|
change_token_probability: float = 0.15,
|
||||||
average_span: int = 1,
|
average_span: int = 2,
|
||||||
seed: int = random.randint(0, sys.maxsize),
|
seed: int = random.randint(0, sys.maxsize),
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequen
|
|||||||
from .inference_masking import inference_masking
|
from .inference_masking import inference_masking
|
||||||
from .truncate_rdf_list import truncate_rdf_list
|
from .truncate_rdf_list import truncate_rdf_list
|
||||||
from .decode_out import tensor2token
|
from .decode_out import tensor2token
|
||||||
|
from .decoder_input import get_decoder_input
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -17,4 +18,5 @@ __all__ = [
|
|||||||
"inference_masking",
|
"inference_masking",
|
||||||
"truncate_rdf_list",
|
"truncate_rdf_list",
|
||||||
"tensor2token",
|
"tensor2token",
|
||||||
|
"get_decoder_input"
|
||||||
]
|
]
|
||||||
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from Project_Model.Libs.Transformer import normalize_sequence
|
from ..Utils import normalize_sequence
|
||||||
# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
|
# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user