{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a1f6c2e",
   "metadata": {},
   "source": [
    "# Toy training run: RDF -> abstract\n",
    "\n",
    "Trains `Transformer.TrainingModel` on the 1-hop toy dataset (`Assets/Dataset/1-hop/toy/rdf_text.csv`): the encoder reads the `RDFs` column and the decoder learns the `Abstract` column one token at a time with teacher forcing.\n",
    "\n",
    "- Every target position that has at least one non-PAD target is its own forward/backward/optimizer step on a fixed mini-batch, so epochs are slow.\n",
    "- **TODO:** the special-token literals passed to `TOKENANO.encode` are empty strings (the token names appear to have been stripped). Restore them before running; as written `PAD_TOKEN` and `END_TOKEN` cannot differ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b8d2e71",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddfb4457",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import Project_Model.Libs.Embedder as Embedder  # NOTE(review): not used in this notebook\n",
    "import Project_Model.Libs.BPE as BPE\n",
    "import Project_Model.Libs.Transformer as Transformer\n",
    "import Project_Model.Libs.TorchShims as torch_shims\n",
    "from Project_Model.Libs.Training.learning_rade_shedulers import Custom_lr\n",
    "from Project_Model.Libs.Training.logistic_collector import LogitsCollector  # import the external collector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7e4c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set a fixed seed for every RNG this notebook touches\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "# tensor factory calls without an explicit device now allocate on DEVICE\n",
    "DEVICE = torch_shims.get_default_device()\n",
    "torch.set_default_device(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e2a9f04",
   "metadata": {},
   "source": [
    "## Tokenizer and hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9a2f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BPE Init\n",
    "VOCABULARY_PATH = Path(\"Assets/Model/toy_10/toy_dictionary.json\")\n",
    "SPECIAL_VOC = BPE.default_special_tokens()\n",
    "\n",
    "VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n",
    "TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n",
    "\n",
    "# Constants\n",
    "TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1\n",
    "EMBEDDED_SIZE = 256\n",
    "FEED_FORWARD_MULTIPLIER = 4\n",
    "ATTENTION_HEADS = 4\n",
    "SENTENCE_LENGTH = 256\n",
    "NUMBER_OF_BLOCKS = 2\n",
    "MAX_EPOCHS = 1_000\n",
    "MAX_BATCH_SIZE = 32  # upper bound on the fixed training mini-batch\n",
    "WARMUP_STEPS = 4000  # warm-up length passed to Custom_lr\n",
    "\n",
    "# NOTE(review): both literals below are empty, so PAD and END are encoded from\n",
    "# the same string -- the special-token names look stripped; restore them.\n",
    "# The assert stops training from silently running with PAD == END.\n",
    "PAD_TOKEN = TOKENANO.encode(\"\")[0]\n",
    "END_TOKEN = TOKENANO.encode(\"\")[0]\n",
    "assert PAD_TOKEN != END_TOKEN, \"PAD and END special tokens must be distinct\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c71e58",
   "metadata": {},
   "source": [
    "## Toy dataset\n",
    "\n",
    "Encode every row, then take a fixed mini-batch from the top of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e4b7d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load CSV\n",
    "TOY_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
    "TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)\n",
    "assert len(TOY_DATASET) > 0, f\"no rows in {TOY_DATASET_PATH}\"\n",
    "\n",
    "# The decoder start prefix is identical for every row, so build it once.\n",
    "# NOTE(review): empty literal again -- the start-token name looks stripped.\n",
    "DECODER_DEFAULT_TOKENS = Transformer.pad_sequence(\n",
    "    TOKENANO.encode(\"\"), SENTENCE_LENGTH, PAD_TOKEN\n",
    ")  # pad with PAD up to SENTENCE_LENGTH\n",
    "\n",
    "TOY_BATCH_INPUT_LIST: list[list[int]] = []\n",
    "TOY_BATCH_PADDING_LIST: list[list[bool]] = []\n",
    "TOY_BATCH_TARGET_LIST: list[list[int]] = []\n",
    "TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []\n",
    "\n",
    "for rdfs, abstract in zip(TOY_DATASET[\"RDFs\"], TOY_DATASET[\"Abstract\"]):\n",
    "    input_tokens = TOKENANO.encode(rdfs)  # encoder input ids\n",
    "    output_tokens = TOKENANO.encode(abstract)[1:]  # decoder target ids (shifted left)\n",
    "\n",
    "    input_tokens, padding = Transformer.normalize_sequence(\n",
    "        input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN\n",
    "    )  # pad/trim + end token\n",
    "    output_tokens, _ = Transformer.normalize_sequence(\n",
    "        output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN\n",
    "    )  # pad/trim + end token\n",
    "\n",
    "    TOY_BATCH_INPUT_LIST.append(input_tokens)\n",
    "    TOY_BATCH_PADDING_LIST.append(padding)\n",
    "    TOY_BATCH_TARGET_LIST.append(output_tokens)\n",
    "    TOY_BATCH_DECODER_DEFAULT.append(list(DECODER_DEFAULT_TOKENS))  # copy so rows never alias\n",
    "\n",
    "# Fixed mini-batch from the top of the dataset (shuffle/slice later). It never\n",
    "# changes, so its tensors are built once here instead of once per epoch.\n",
    "BATCH_SIZE = min(MAX_BATCH_SIZE, len(TOY_BATCH_INPUT_LIST))\n",
    "enc = torch.tensor(TOY_BATCH_INPUT_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] encoder token ids\n",
    "pad = torch.tensor(TOY_BATCH_PADDING_LIST[:BATCH_SIZE], dtype=torch.bool)  # [B,T] True where encoder PAD is present\n",
    "tgt = torch.tensor(TOY_BATCH_TARGET_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] decoder targets (ground-truth)\n",
    "dec_init = torch.tensor(TOY_BATCH_DECODER_DEFAULT[:BATCH_SIZE], dtype=torch.long)  # [B,T] decoder start buffer\n",
    "\n",
    "# True for positions where at least one row has a non-PAD target; the rest are skipped in training.\n",
    "has_target: list[bool] = tgt.ne(PAD_TOKEN).any(dim=0).tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b59e0d36",
   "metadata": {},
   "source": [
    "## Model, loss and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f8a6e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "NANOSOCRATES = Transformer.TrainingModel(\n",
    "    TOKEN_SPACE_SIZE,\n",
    "    EMBEDDED_SIZE,\n",
    "    FEED_FORWARD_MULTIPLIER,\n",
    "    ATTENTION_HEADS,\n",
    "    NUMBER_OF_BLOCKS,\n",
    ")\n",
    "\n",
    "collector = LogitsCollector(PAD_TOKEN, END_TOKEN, TOKENANO)  # collects logits and decodes\n",
    "\n",
    "NANOSOCRATES.train()\n",
    "cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)\n",
    "optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())\n",
    "# NOTE(review): Custom_lr is never given `optimizer` and nothing here copies its\n",
    "# output into optimizer.param_groups, so unless Custom_lr reaches the optimizer\n",
    "# some other way, AdamW trains at its default lr and the warm-up has no effect.\n",
    "scheduler = Custom_lr(EMBEDDED_SIZE, WARMUP_STEPS)  # step each optimizer step\n",
    "\n",
    "LOSS_HISTORY: list[float] = []  # mean training loss of each completed epoch\n",
    "current_epoch = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e81f4a27",
   "metadata": {},
   "source": [
    "## Training loop\n",
    "\n",
    "Re-running this cell continues from `current_epoch` with the current weights; re-run the previous cell to start over from a freshly initialised model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7a3d9b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = tgt.size(1)  # sequence length\n",
    "\n",
    "while current_epoch < MAX_EPOCHS:\n",
    "    # decoder prefix buffer, rebuilt each epoch because teacher forcing writes into it\n",
    "    dec = dec_init.clone()  # [B,T]\n",
    "\n",
    "    step_losses: list[float] = []\n",
    "    collector.reset()  # start fresh for this epoch\n",
    "\n",
    "    for t in range(T):\n",
    "        target_t = tgt[:, t]  # [B] ground-truth token for position t\n",
    "\n",
    "        # An all-PAD column has nothing to learn: CrossEntropyLoss returns NaN when\n",
    "        # every target is ignored, yet optimizer.step() would still move the weights\n",
    "        # (AdamW weight decay + momentum). Skip forward/backward/step entirely;\n",
    "        # skipped positions are therefore not passed to the collector either.\n",
    "        if has_target[t]:\n",
    "            optimizer.zero_grad(set_to_none=True)  # clear grads for this token step\n",
    "\n",
    "            prefix = dec[:, : t + 1]  # [B, t+1] current decoder prefix\n",
    "            dec_pad_mask = prefix.eq(PAD_TOKEN)  # [B, t+1] True where PAD inside prefix\n",
    "\n",
    "            # one-step logits given prefix (trainer model expects 4 args now)\n",
    "            logits_t: torch.Tensor = NANOSOCRATES((enc, pad, prefix, dec_pad_mask))  # [B,V] logits for step t\n",
    "            collector.add(logits_t.detach())  # store logits for decoding later, without autograd history\n",
    "\n",
    "            loss_t = cross_entropy(logits_t, target_t)  # CE expects raw logits; PAD ignored\n",
    "            loss_t.backward()  # backprop for this step\n",
    "            optimizer.step()  # update params\n",
    "            scheduler.step()  # Noam/warmup: step per optimizer step\n",
    "\n",
    "            step_losses.append(float(loss_t.detach()))\n",
    "\n",
    "        # teacher forcing: reveal the correct token for next position\n",
    "        if t < T - 1:\n",
    "            dec[:, t + 1] = target_t  # write ground-truth into next slot\n",
    "\n",
    "    # epoch loss = mean of the per-position losses, not just the last step's\n",
    "    total_loss = sum(step_losses) / len(step_losses) if step_losses else float(\"nan\")\n",
    "    LOSS_HISTORY.append(total_loss)\n",
    "\n",
    "    current_epoch += 1\n",
    "    print(f\"EPOCH {current_epoch}\\n\\tLoss: {total_loss:.6f}\")  # simple log\n",
    "    collector.print_decoded()  # print decoded predictions for the batch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep_learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}