diff --git a/Playgrounds/encoder-decoder.ipynb b/Playgrounds/encoder-decoder.ipynb new file mode 100644 index 0000000..1540d13 --- /dev/null +++ b/Playgrounds/encoder-decoder.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "7a311d4b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7712], [7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7712], [7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7712]]\n", + "3\n", + "Embedder Tensor: torch.Size([3, 16, 256])\n", + "Values:\n", + "tensor([[[-0.6981, 0.0804, -2.1672, ..., 0.3919, 0.3341, 1.0794],\n", + " [ 2.5818, -0.2308, 0.6001, ..., -0.0500, -0.0408, -0.9852],\n", + " [-0.6967, 0.8109, 1.3108, ..., 2.1693, 1.4143, -0.1236],\n", + " ...,\n", + " [ 2.1226, 2.5695, -1.6178, ..., -0.0652, -0.0802, 0.1103],\n", + " [ 0.8770, -2.4782, 0.8536, ..., 2.0471, -1.5702, 0.7387],\n", + " [ 1.4284, -0.4654, 0.1394, ..., 1.6520, 0.6728, 1.3851]],\n", + "\n", + " [[-0.6981, 0.0804, -2.1672, ..., 0.3919, 0.3341, 1.0794],\n", + " [ 2.5818, -0.2308, 0.6001, ..., -0.0500, -0.0408, -0.9852],\n", + " [-0.6967, 0.8109, 1.3108, ..., 2.1693, 1.4143, -0.1236],\n", + " ...,\n", + " [ 2.1226, 2.5695, -1.6178, ..., -0.0652, -0.0802, 0.1103],\n", + " [ 0.8770, -2.4782, 0.8536, ..., 2.0471, -1.5702, 0.7387],\n", + " [ 1.4284, -0.4654, 0.1394, ..., 1.6520, 0.6728, 1.3851]],\n", + "\n", + " [[-0.6981, 0.0804, -2.1672, ..., 0.3919, 0.3341, 1.0794],\n", + " [ 2.5818, -0.2308, 0.6001, ..., -0.0500, -0.0408, -0.9852],\n", + " [-0.6967, 0.8109, 1.3108, ..., 2.1693, 1.4143, -0.1236],\n", + " ...,\n", + " [ 2.1226, 2.5695, -1.6178, ..., -0.0652, -0.0802, 0.1103],\n", + " [ 0.8770, -2.4782, 0.8536, ..., 2.0471, -1.5702, 0.7387],\n", + " [ 1.4284, -0.4654, 0.1394, ..., 1.6520, 0.6728, 1.3851]]],\n", + " grad_fn=)\n", + "ENCODER Tensor: torch.Size([3, 1, 256])\n", + "Values:\n", + "tensor([[[ 8.0069e-01, 4.0532e-01, -1.8316e+00, -1.3902e+00, -1.1784e+00,\n", + " 1.3667e+00, -9.7890e-01, 6.0696e-01, -1.4899e+00, 5.5765e-01,\n", + " 4.5991e-02, 5.1214e-01, 3.1901e-01, 4.7577e-01, -2.9585e-01,\n", + " -1.0811e+00, -1.5281e+00, -6.3773e-01, -9.5954e-01, 1.8497e+00,\n", + " -1.1789e+00, -9.7387e-01, 1.1931e-01, -7.2703e-01, 5.3108e-01,\n", + " -6.4877e-01, -4.5188e-01, 1.5185e+00, -8.3408e-01, 3.2824e-01,\n", + " -1.8166e+00, 1.9548e+00, -5.2419e-01, -1.0693e+00, -1.8510e+00,\n", + " 1.5440e+00, -3.2370e-01, -1.3990e+00, -4.6940e-01, 6.5840e-02,\n", + " -9.2057e-01, 1.2513e+00, -5.9168e-01, 7.8198e-01, -1.3121e+00,\n", + " 1.1492e+00, -2.3695e-01, -1.8935e+00, 1.1639e+00, -5.8169e-01,\n", + " 2.5051e-01, -8.1654e-01, -1.0328e+00, 1.4285e+00, -8.1485e-01,\n", + " 1.0614e+00, -3.3834e-01, -4.1667e-02, -1.1920e-01, 3.1383e-01,\n", + " -5.9857e-01, 1.7327e-01, -1.6854e+00, -1.5174e+00, -2.6508e-01,\n", + " -6.0082e-01, 5.1468e-01, 2.7909e-01, -2.5296e-01, -1.4670e+00,\n", + " -1.3587e+00, -8.8864e-02, 3.2825e-01, 1.0950e+00, -1.0371e+00,\n", + " 1.1744e+00, 5.2984e-01, 4.1751e-01, -9.8803e-01, 3.5631e-01,\n", + " 4.7484e-01, 2.2435e-01, 1.4022e+00, 1.2242e+00, 1.1447e+00,\n", + " -5.4052e-01, -9.1786e-01, -1.2299e+00, 1.1656e+00, 9.1570e-01,\n", + " 1.8956e+00, 7.4344e-01, 4.2187e-01, -9.5426e-02, -3.2428e-01,\n", + " 9.6364e-01, -2.3252e-01, 2.9036e-01, -2.4432e+00, 9.8019e-01,\n", + " -4.6697e-02, 8.3910e-01, -4.3541e-01, -7.1915e-01, 
-7.5638e-01,\n", + " 9.0217e-01, 2.0919e+00, -7.9533e-01, -1.5413e-01, -6.9260e-01,\n", + " -1.3086e+00, 7.8925e-01, 1.8855e-01, 7.4043e-01, -3.8834e-01,\n", + " 1.0272e-02, 1.0763e+00, 4.2142e-01, 6.6520e-01, 4.5996e-01,\n", + " -8.5060e-01, -9.0101e-01, -4.2090e-01, 2.5596e-01, -1.4946e+00,\n", + " 1.0925e-01, -7.5359e-01, -3.0447e-01, 1.0679e+00, 1.9398e+00,\n", + " 8.1472e-01, 1.3498e+00, 1.1107e+00, 6.3288e-01, 3.1149e-01,\n", + " -1.9333e+00, -1.5274e+00, 2.1794e-01, -3.1895e-02, 1.0756e+00,\n", + " 1.0215e+00, 1.6938e+00, -1.0939e+00, 2.2690e+00, -7.0921e-01,\n", + " 6.4212e-01, -6.5468e-01, 1.6839e+00, 5.7296e-01, -1.4031e+00,\n", + " 3.9133e-01, -5.3541e-01, 4.3439e-01, -1.6785e+00, 5.2030e-03,\n", + " 4.5155e-01, -7.0953e-01, -1.9656e-01, -3.8671e-02, -1.0927e+00,\n", + " -3.0405e-01, -1.3818e-02, -3.7748e-01, 1.4412e+00, -1.4254e-01,\n", + " 7.9939e-01, -8.5402e-01, -1.0330e+00, 1.7661e+00, -3.6084e-01,\n", + " 1.5622e+00, 1.0240e+00, 1.9056e-01, -4.1480e-01, 6.9056e-01,\n", + " 1.7204e+00, -9.9218e-01, -1.6504e-01, -1.1807e+00, 1.0827e+00,\n", + " 1.5973e+00, 1.4849e-01, -2.2867e+00, 7.7322e-01, -6.8401e-01,\n", + " -6.0493e-01, 1.0616e+00, -1.8034e-01, -1.8828e+00, 1.1031e-01,\n", + " 2.5452e-01, -4.2489e-02, 8.1171e-01, 1.3429e+00, -6.5058e-01,\n", + " -1.3531e+00, -1.2263e+00, 1.1226e+00, 1.2407e+00, -9.7453e-01,\n", + " 9.4696e-01, 6.6186e-01, -5.0804e-01, 1.2647e-01, -1.1777e+00,\n", + " 6.8443e-02, -1.3043e-01, 2.9595e-01, -1.5330e+00, -6.5733e-01,\n", + " 1.1291e+00, 6.9629e-01, 4.4690e-01, 8.0151e-01, -1.2406e+00,\n", + " 2.6085e+00, -2.0310e-01, -1.0226e+00, -6.9182e-02, 7.6600e-01,\n", + " -9.9842e-01, 2.0896e+00, 2.6334e-01, -1.1559e-01, -6.6876e-01,\n", + " -6.6295e-01, -1.6461e-01, 2.8270e+00, 3.2727e-01, 1.3724e+00,\n", + " -1.0749e+00, 3.7782e-01, -1.5472e+00, 3.0822e-01, 5.7273e-02,\n", + " 3.9136e-01, 8.2948e-01, 2.1438e-01, -9.8623e-01, 5.6053e-01,\n", + " -1.5617e+00, -3.9595e-01, 1.0451e-02, -1.1860e+00, -1.4994e-01,\n", + " 1.6566e+00, 2.0369e+00, -4.3995e-01, -4.4262e-01, -3.1014e-01,\n", + " 5.9083e-01, -1.0765e+00, -5.2906e-01, 4.6039e-02, -1.0154e+00,\n", + " 5.9942e-01]],\n", + "\n", + " [[ 1.2683e+00, -4.3200e-01, -1.3333e+00, -3.6705e-01, -5.8895e-01,\n", + " 9.9266e-01, -4.2914e-01, 9.2765e-01, -1.0935e+00, 1.4975e+00,\n", + " -5.3739e-01, -2.8332e-01, 9.1166e-01, 1.5010e+00, -2.1787e-01,\n", + " -1.4258e+00, -2.7524e-01, -1.2602e+00, 2.0117e-01, 2.3906e+00,\n", + " -9.6397e-01, -7.5872e-01, 3.3948e-01, -7.9353e-01, 9.1668e-01,\n", + " 8.7734e-04, -3.0271e-01, 1.7087e+00, -1.0273e+00, 1.5174e+00,\n", + " -2.6405e-02, 1.4236e+00, -9.9093e-01, 5.4787e-01, -1.0904e+00,\n", + " 5.2156e-01, -6.3470e-01, -7.7688e-01, -1.2538e+00, -3.9307e-01,\n", + " -7.6707e-01, 1.3733e+00, -7.2709e-01, 1.1185e+00, -1.5860e+00,\n", + " -2.6148e-01, -3.7984e-01, -1.3604e+00, 9.2864e-02, -7.9642e-01,\n", + " 1.0956e+00, 3.1202e-01, -4.1234e-01, 3.6488e-02, -1.4639e+00,\n", + " 1.0947e+00, -7.9230e-01, 4.6913e-01, -2.3407e-01, 4.1768e-02,\n", + " -1.5921e+00, 6.9743e-01, -7.0222e-01, -5.4705e-01, -6.5663e-01,\n", + " -4.1810e-01, 2.7744e-01, 7.9178e-01, 7.5886e-01, -7.6302e-01,\n", + " -1.2204e+00, -1.1103e+00, -1.3646e-01, 1.9589e+00, -1.3637e+00,\n", + " 9.0804e-01, 2.3094e-01, -5.5953e-02, -6.7626e-01, 1.4242e+00,\n", + " 1.0167e+00, 1.0705e+00, 2.2947e+00, 9.1274e-01, 1.2281e+00,\n", + " -7.0638e-01, -1.2249e+00, -8.9208e-02, 1.1016e+00, 1.1940e+00,\n", + " 3.5834e-01, 1.2961e+00, -4.6674e-01, 3.4572e-01, -4.3458e-01,\n", + " 1.1008e+00, 3.7783e-01, -6.5841e-01, 
-2.3127e+00, 1.4617e+00,\n", + " -1.2826e-01, 1.3463e-01, -8.5268e-01, -8.4144e-01, -1.8594e+00,\n", + " 1.9260e-01, 1.6432e+00, -2.0640e-02, -5.0030e-01, -1.5334e-01,\n", + " -6.1072e-01, -1.3694e-01, -3.7308e-01, 1.6603e+00, 1.1246e-01,\n", + " 6.0823e-02, 7.8749e-01, -1.7002e-01, 1.2058e+00, 8.5615e-01,\n", + " 1.2525e-01, -1.0584e+00, -4.7931e-01, 1.4088e-01, -1.8149e+00,\n", + " 1.4654e+00, -1.0936e+00, 5.3182e-01, 9.5694e-01, 3.2472e+00,\n", + " 3.4877e-01, 1.8491e+00, -1.5184e-01, 1.4711e+00, -7.6064e-01,\n", + " -2.2144e+00, -1.8952e+00, -4.9502e-01, -6.6836e-01, 1.4946e+00,\n", + " 6.7616e-01, 1.1501e+00, -9.4747e-01, 1.1009e+00, -1.4211e+00,\n", + " 3.9528e-01, -9.5220e-01, 1.4886e+00, 7.1784e-01, -1.9941e+00,\n", + " 6.7901e-02, -1.3109e-01, 1.1695e+00, 1.2861e-01, -2.8123e-01,\n", + " -6.1611e-01, 1.5513e-01, -3.9289e-01, -4.5543e-02, -2.8628e-01,\n", + " 2.6118e-01, 2.2623e-01, -6.3705e-01, 7.3591e-01, -7.8799e-01,\n", + " 2.5053e-01, -1.5923e-01, -4.9584e-01, 1.9009e+00, -2.3263e-01,\n", + " 1.2213e+00, 1.0313e+00, 2.0177e-02, -6.2209e-01, -3.5161e-01,\n", + " 1.5143e+00, -7.2332e-02, 2.3909e-02, -2.1261e+00, 8.5199e-01,\n", + " 1.9084e+00, 4.6845e-02, -2.3554e+00, 1.3735e+00, -7.3909e-01,\n", + " -8.3949e-01, -3.9314e-01, -4.3324e-01, -9.6804e-01, -5.3124e-01,\n", + " -6.5091e-01, -1.1738e+00, 1.3315e+00, 6.5606e-01, -1.4131e-01,\n", + " -1.7712e+00, -1.1628e+00, 9.6813e-01, 8.7314e-01, -9.8027e-01,\n", + " 6.9376e-01, 5.3878e-01, -1.6169e+00, 2.2860e-01, -6.2179e-01,\n", + " -1.1043e-01, -3.9658e-01, 2.8712e-01, 8.2201e-02, 2.0888e-01,\n", + " -5.9884e-01, 7.3092e-01, 6.9128e-01, 5.3977e-01, -1.5728e+00,\n", + " 1.6878e+00, -8.2669e-01, -9.8076e-01, -3.4203e-01, 4.6939e-02,\n", + " -1.3158e-01, 2.1923e+00, -6.6483e-02, -4.0687e-01, -1.2715e+00,\n", + " -8.1549e-01, -1.2047e+00, 1.3547e+00, -4.2072e-01, 1.1674e+00,\n", + " -5.1421e-01, 1.3055e+00, -1.1277e+00, 1.8372e+00, -1.1215e+00,\n", + " 1.4797e+00, 2.8354e-01, -6.3974e-01, -1.2869e+00, -2.7897e-01,\n", + " -1.0397e+00, 1.8622e-01, -5.0397e-02, -4.4865e-02, -7.6067e-01,\n", + " 1.7715e+00, 1.5040e+00, -2.6854e-01, -5.2802e-01, -5.3407e-01,\n", + " 2.0313e-02, -2.6276e-01, -7.0748e-01, -8.7328e-01, -3.4108e-01,\n", + " 1.4313e+00]],\n", + "\n", + " [[ 7.7464e-01, -4.2187e-01, -2.0571e+00, -8.6709e-01, -1.5722e+00,\n", + " 4.9540e-01, -1.5270e+00, 1.0499e+00, -1.9579e+00, -2.5298e-02,\n", + " 4.3419e-01, 5.8822e-01, 1.3392e+00, 6.9604e-01, -9.7883e-01,\n", + " -9.1354e-01, -9.1852e-01, -6.0951e-01, -6.6255e-02, 1.3907e+00,\n", + " -6.2912e-01, -2.7524e-01, 1.9520e-02, -2.7154e-01, 1.5162e-01,\n", + " 1.3318e-02, -8.9196e-01, 9.0976e-01, -1.3544e+00, 2.4276e-01,\n", + " -7.4038e-01, 9.7062e-01, 3.2011e-01, 3.4486e-01, -2.3374e+00,\n", + " 1.3311e+00, -3.1871e-02, -1.4468e+00, -1.5968e+00, 3.0418e-01,\n", + " -7.7136e-01, 1.3427e+00, -1.2493e+00, 1.4114e+00, -1.2475e+00,\n", + " 7.0239e-01, -9.6120e-02, -4.4365e-01, 5.3238e-01, -1.4933e+00,\n", + " 5.4476e-01, -1.8490e-02, -5.9936e-01, 1.0878e+00, -1.8892e+00,\n", + " 1.2810e+00, -1.0747e+00, 5.3514e-01, 1.7422e-01, 1.1354e+00,\n", + " -7.4837e-01, 4.0327e-01, -1.8950e+00, -7.2336e-01, 2.4441e-01,\n", + " -1.3650e-01, -4.8344e-01, 3.3921e-02, 5.0889e-01, -1.3769e+00,\n", + " -2.5907e-01, -2.7549e-01, -1.9128e-01, 1.9751e+00, -7.1191e-01,\n", + " 5.1910e-01, 1.0902e-01, 2.9995e-01, -3.5180e-01, -6.2139e-01,\n", + " 7.2905e-01, -5.3177e-01, 4.3340e-01, 1.0071e+00, 1.7586e+00,\n", + " -3.9963e-01, -2.5139e-01, -9.4213e-01, 9.2847e-01, 1.1298e+00,\n", + " 7.8545e-01, 
1.3188e+00, 3.7466e-01, 9.0773e-01, -4.0454e-02,\n", + " 1.3444e+00, 6.0301e-01, 8.9929e-02, -2.0754e+00, 4.8614e-01,\n", + " -9.7160e-01, 8.2446e-01, -1.1813e+00, -9.6185e-01, -9.2922e-02,\n", + " 6.0154e-01, 1.6640e+00, -1.0461e+00, 1.5868e-01, -5.7239e-01,\n", + " -6.2726e-01, 3.2848e-01, 5.9609e-01, 1.5563e+00, -4.0883e-01,\n", + " 4.4902e-01, 1.4004e+00, 2.2426e-01, 3.8314e-01, -2.0641e-01,\n", + " -1.6465e-01, -6.4645e-01, 1.5772e-01, 6.8907e-01, -1.2703e+00,\n", + " 1.8914e-01, -6.2678e-01, 3.0179e-01, 1.2687e+00, 1.6849e+00,\n", + " 1.5690e+00, 1.0999e+00, 1.5820e+00, -6.4808e-01, 5.1003e-01,\n", + " -1.6674e+00, -1.2224e+00, 1.9769e-01, -1.3883e-01, 1.2179e+00,\n", + " 1.2971e+00, 4.6259e-01, -5.8717e-01, 1.4532e+00, -1.0540e+00,\n", + " 2.8689e-01, -1.3895e+00, 1.4014e+00, -4.0430e-01, -2.6099e+00,\n", + " -1.0293e+00, -1.1097e+00, 8.6266e-01, -1.0535e+00, 7.1789e-01,\n", + " 6.0642e-01, -1.2493e+00, -3.7762e-01, -4.1281e-02, -7.3049e-01,\n", + " -7.2913e-04, -7.3122e-02, -2.3850e-01, 1.2546e+00, 1.8802e-01,\n", + " 1.3135e+00, -5.0367e-01, 1.2456e-01, 2.7475e+00, -1.2486e+00,\n", + " 1.4441e+00, 8.7469e-01, -5.6901e-01, -1.2145e-01, 3.1091e-01,\n", + " 1.9406e+00, -8.1891e-01, 3.1316e-02, -1.2867e+00, 8.0780e-01,\n", + " 7.0041e-01, 2.8903e-01, -1.6387e+00, 6.6553e-01, -1.3696e+00,\n", + " -7.9454e-01, 3.3899e-01, -5.5822e-01, -8.1969e-01, -1.2410e-01,\n", + " -3.7024e-01, -7.2536e-01, 7.5648e-01, 1.6899e+00, -1.7404e-01,\n", + " -1.7191e+00, -7.2603e-01, 1.5046e+00, 8.3216e-01, -1.5304e+00,\n", + " -1.8264e-01, 3.3451e-01, -5.6636e-02, 6.1099e-01, -9.8517e-01,\n", + " 4.4856e-01, -8.6275e-01, 6.9264e-02, -1.1572e+00, 2.3373e-01,\n", + " 5.9896e-01, 1.2384e-01, 1.0309e+00, 1.4273e+00, -8.4776e-01,\n", + " 2.6236e+00, -9.0133e-01, -4.0009e-01, -4.9727e-01, 3.7945e-01,\n", + " -9.0712e-01, 1.5725e+00, 1.6298e-01, 1.1544e-01, -4.3125e-01,\n", + " -8.7131e-01, -2.5880e-01, 2.9032e+00, 2.7154e-01, 1.3677e+00,\n", + " -8.8544e-01, 5.6083e-01, -1.8256e+00, 9.4832e-01, -1.0762e+00,\n", + " 7.5421e-01, 6.5008e-01, -8.6361e-01, -1.4911e+00, -7.5930e-02,\n", + " -1.6896e+00, 1.5223e-02, -1.5283e-01, -1.8741e+00, 1.1400e-01,\n", + " 1.8822e+00, 2.6615e+00, 2.1607e-01, -5.6243e-01, 3.6730e-01,\n", + " 4.0374e-01, -1.1973e+00, -5.3006e-01, -3.4750e-01, -4.4187e-01,\n", + " 7.4358e-01]]], grad_fn=)\n" + ] + } + ], + "source": [ + "import random\n", + "import torch\n", + "from pathlib import Path\n", + "import Project_Model.Libs.Embedder as Embedder\n", + "import Project_Model.Libs.BPE as BPE\n", + "import Project_Model.Libs.Transformer as Transformer\n", + "\n", + "# set a fixed seed\n", + "torch.manual_seed(0)\n", + "random.seed(0)\n", + "\n", + "TEXT = (\n", + " \"The Dark Knight is a 2008 superhero film directed by Christopher Nolan,\"\n", + ")\n", + "OUT_TEXT = \"\"\n", + "VOCABULARY_PATH = Path(\"Assets/Model/toy_10/toy_dictionary.json\")\n", + "SPECIAL_VOC = BPE.default_special_tokens()\n", + "\n", + "VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n", + "TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n", + "\n", + "PAD_TOKEN = TOKENANO.encode(\"\")[0]\n", + "END_TOKEN = TOKENANO.encode(\"\")[0]\n", + "\n", + "ENCODER_INPUT = TOKENANO.encode(TEXT)\n", + "DECODER_INPUT = TOKENANO.encode(OUT_TEXT)\n", + "MAX_LEN = len(ENCODER_INPUT) + 1\n", + "\n", + "EN_IN, PAD_MASK = Transformer.normalize_sequence(ENCODER_INPUT, MAX_LEN, PAD_TOKEN, END_TOKEN)\n", + "DEC_IN, _ = Transformer.normalize_sequence(DECODER_INPUT, MAX_LEN, PAD_TOKEN, END_TOKEN)\n", + "BATCH_LEN = 3\n", + "\n", + 
"INPUT_TOKENIZATION = [\n", + " EN_IN\n", + "] * BATCH_LEN\n", + "OUTPUT_TOKENIZATION = [\n", + " DEC_IN\n", + "] * BATCH_LEN\n", + "\n", + "\n", + "print(INPUT_TOKENIZATION)\n", + "\n", + "TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size\n", + "EMBEDDED_SIZE = 256\n", + "FEED_FORWARD_DIM = EMBEDDED_SIZE * 4\n", + "\n", + "EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)\n", + "encoder_tensor: torch.Tensor = EMBEDDER(INPUT_TOKENIZATION)\n", + "ENCODER = torch.nn.Sequential(\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + ")\n", + "decoder_tensor: torch.Tensor = EMBEDDER(OUTPUT_TOKENIZATION)\n", + "DECODER = torch.nn.Sequential(\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + ")\n", + "\n", + "print(len(INPUT_TOKENIZATION))\n", + "print(f\"Embedder Tensor: {encoder_tensor.shape}\")\n", + "print(f\"Values:\\n{encoder_tensor}\")\n", + "\n", + "BATCH_SIZE, TOKENS, DIMENSIONS = encoder_tensor.shape\n", + "PAD_MASK = torch.tensor([PAD_MASK] * BATCH_LEN)\n", + "\n", + "encoder_out, _ = ENCODER((encoder_tensor, PAD_MASK))\n", + "tensor: torch.Tensor\n", + "tensor, _, _, _ = DECODER((decoder_tensor, encoder_out, encoder_out, None))\n", + "\n", + "print(f\"ENCODER Tensor: {tensor.shape}\")\n", + "print(f\"Values:\\n{tensor}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deep_learning", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Playgrounds/encoder.ipynb b/Playgrounds/encoder.ipynb index fb195e0..02b1e14 100644 --- a/Playgrounds/encoder.ipynb +++ b/Playgrounds/encoder.ipynb @@ -10,38 +10,59 @@ "name": "stdout", "output_type": "stream", "text": [ - "[7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7700]\n", - "16\n", - "Embedder Tensor: torch.Size([13, 256])\n", + "[[7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7700], [7706, 290, 756, 4270, 7357, 115, 351, 1507, 1213, 410, 3382, 317, 497, 4740, 2784, 7700]]\n", + "2\n", + "Embedder Tensor: torch.Size([2, 16, 256])\n", "Values:\n", - "tensor([[-1.0474, 1.9119, 1.3443, ..., -1.5243, 1.2989, 0.3618],\n", - " [ 1.0083, 1.4955, 0.9479, ..., 1.7371, 0.2389, 2.1217],\n", - " [-0.7681, -1.7427, 1.7070, ..., 0.6594, -0.6083, -0.4595],\n", - " ...,\n", - " [-0.7209, -0.3639, -0.6911, ..., 3.3490, -2.7354, 1.1244],\n", - " [-0.7352, -1.6731, 0.2976, ..., 1.5605, -1.3298, 1.3615],\n", - " [-0.5377, 0.3704, -0.4427, ..., 0.4723, 0.5781, 0.2003]],\n", + "tensor([[[-0.6981, 0.0804, -2.1672, ..., 0.3919, 0.3341, 1.0794],\n", + " [ 2.5818, -0.2308, 0.6001, 
..., -0.0500, -0.0408, -0.9852],\n", + " [-0.6967, 0.8109, 1.3108, ..., 2.1693, 1.4143, -0.1236],\n", + " ...,\n", + " [ 2.1226, 2.5695, -1.6178, ..., -0.0652, -0.0802, 0.1103],\n", + " [ 0.8770, -2.4782, 0.8536, ..., 2.0471, -1.5702, 0.7387],\n", + " [-0.0495, -1.8601, 0.0405, ..., 2.3944, -0.4297, 1.1141]],\n", + "\n", + " [[-0.6981, 0.0804, -2.1672, ..., 0.3919, 0.3341, 1.0794],\n", + " [ 2.5818, -0.2308, 0.6001, ..., -0.0500, -0.0408, -0.9852],\n", + " [-0.6967, 0.8109, 1.3108, ..., 2.1693, 1.4143, -0.1236],\n", + " ...,\n", + " [ 2.1226, 2.5695, -1.6178, ..., -0.0652, -0.0802, 0.1103],\n", + " [ 0.8770, -2.4782, 0.8536, ..., 2.0471, -1.5702, 0.7387],\n", + " [-0.0495, -1.8601, 0.0405, ..., 2.3944, -0.4297, 1.1141]]],\n", " grad_fn=)\n", - "ENCODER Tensor: torch.Size([13, 256])\n", + "ENCODER Tensor: torch.Size([2, 16, 256])\n", "Values:\n", - "tensor([[-1.0270, 0.6589, -0.3154, ..., -1.3113, 0.5058, -0.0608],\n", - " [ 1.0235, 1.2011, -0.3139, ..., 0.1643, 0.6761, 0.9673],\n", - " [-0.7295, -1.5149, 0.4729, ..., 0.3185, -0.2433, -1.2669],\n", - " ...,\n", - " [-0.2189, -0.1399, -1.0049, ..., 1.8693, -2.4663, -0.3319],\n", - " [-0.1491, -0.4986, -0.7297, ..., 1.2760, -0.5654, 0.7038],\n", - " [-1.3576, 0.3478, -0.1016, ..., 0.0712, 0.3772, -0.1522]],\n", + "tensor([[[-1.6325, 0.4094, -2.1403, ..., 0.4654, 0.5993, 0.9683],\n", + " [ 1.8236, 0.4025, -0.6972, ..., 0.2430, 0.2536, -1.0889],\n", + " [-0.0587, 0.1618, -0.2335, ..., 1.7609, 1.2664, -0.4452],\n", + " ...,\n", + " [ 2.0337, 1.3184, -1.3165, ..., -0.3303, 0.6572, 0.0884],\n", + " [ 0.5752, -2.5594, -0.2393, ..., 1.3318, -1.4236, 0.4686],\n", + " [ 1.0075, -2.4273, -0.4593, ..., 1.6660, 0.0359, 0.2927]],\n", + "\n", + " [[-1.8300, -0.3079, -1.6585, ..., 0.4859, 0.5652, 0.8072],\n", + " [ 1.5461, -0.5666, -0.0330, ..., 0.5651, 0.2974, -1.0879],\n", + " [-0.9060, 0.2700, -0.4585, ..., 2.0363, 1.2657, -0.7060],\n", + " ...,\n", + " [ 1.6688, 1.7038, -1.9549, ..., -0.2052, 0.6270, 0.4598],\n", + " [ 0.0482, -2.3951, -0.4351, ..., 1.6230, -1.3662, -0.0390],\n", + " [ 0.8146, -2.6169, -0.6188, ..., 1.4525, 0.0507, 0.5177]]],\n", " grad_fn=)\n" ] } ], "source": [ + "import random\n", "import torch\n", "from pathlib import Path\n", "import Project_Model.Libs.Embedder as Embedder\n", "import Project_Model.Libs.BPE as BPE\n", "import Project_Model.Libs.Transformer as Transformer\n", "\n", + "# set a fixed seed\n", + "torch.manual_seed(0)\n", + "random.seed(0)\n", + "\n", "TEXT = \"The Dark Knight is a 2008 superhero film directed by Christopher Nolan,\"\n", "\n", "VOCABULARY_PATH = Path(\"Assets/Model/toy_10/toy_dictionary.json\")\n", @@ -53,7 +74,7 @@ " SPECIAL_VOC\n", ")\n", "\n", - "TOKENIZATION = TOKENANO.encode(TEXT)\n", + "TOKENIZATION = [TOKENANO.encode(TEXT), TOKENANO.encode(TEXT)]\n", "print(TOKENIZATION)\n", "\n", "TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size\n", @@ -61,22 +82,22 @@ "FEED_FORWARD_DIM = EMBEDDED_SIZE * 4\n", "\n", "EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)\n", - "tensor: torch.Tensor = EMBEDDER(TOKENIZATION[3:])\n", + "tensor: torch.Tensor = EMBEDDER(TOKENIZATION)\n", "ENCODER = torch.nn.Sequential(\n", " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", - " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4)\n", + " 
Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", ")\n", "print(len(TOKENIZATION))\n", "print(f\"Embedder Tensor: {tensor.shape}\")\n", "print(f\"Values:\\n{tensor}\")\n", "\n", - "TOKENS, DIMENSIONS = tensor.shape\n", - "\n", - "tensor = ENCODER(tensor)\n", + "BATCH_SIZE, TOKENS, DIMENSIONS = tensor.shape\n", + "PAD_MASK = torch.tensor([[True] * TOKENS] * BATCH_SIZE, dtype=torch.bool)\n", + "tensor, _ = ENCODER((tensor, PAD_MASK))\n", "print(f\"ENCODER Tensor: {tensor.shape}\")\n", "print(f\"Values:\\n{tensor}\")\n", "\n", diff --git a/Playgrounds/nanosocrates-sanity-check.ipynb b/Playgrounds/nanosocrates-sanity-check.ipynb new file mode 100644 index 0000000..7fd1ca7 --- /dev/null +++ b/Playgrounds/nanosocrates-sanity-check.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "f5762da9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([3, 17, 7714])\n", + "torch.Size([3, 17])\n", + "tensor([[2034, 6523, 5406, 3985, 5406, 6523, 2034, 2034, 5745, 643, 5406, 7405,\n", + " 6523, 6230, 6419, 5745, 657],\n", + " [2458, 830, 5745, 5745, 5406, 3741, 2034, 5745, 6302, 6419, 5406, 2411,\n", + " 719, 830, 5745, 3189, 2775],\n", + " [2034, 5745, 5327, 4696, 6523, 643, 6419, 1671, 6302, 4406, 5745, 643,\n", + " 643, 1901, 1914, 1914, 719]])\n" + ] + } + ], + "source": [ + "import random\n", + "import torch\n", + "from pathlib import Path\n", + "import Project_Model.Libs.Embedder as Embedder\n", + "import Project_Model.Libs.BPE as BPE\n", + "import Project_Model.Libs.Transformer as Transformer\n", + "\n", + "# set a fixed seed\n", + "torch.manual_seed(0)\n", + "random.seed(0)\n", + "\n", + "# BPE Init\n", + "VOCABULARY_PATH = Path(\"Assets/Model/toy_10/toy_dictionary.json\")\n", + "SPECIAL_VOC = BPE.default_special_tokens()\n", + "\n", + "VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n", + "TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n", + "\n", + "\n", + "# Constants\n", + "TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1\n", + "EMBEDDED_SIZE = 256\n", + "FEED_FORWARD_DIM = EMBEDDED_SIZE * 4\n", + "\n", + "\n", + "# Model Init\n", + "ENCODER_EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)\n", + "DECODER_EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)\n", + "\n", + "ENCODER = torch.nn.Sequential(\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + ")\n", + "\n", + "DECODER = torch.nn.Sequential(\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + " Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, 4),\n", + ")\n", + "\n", + "DETOKENER = Transformer.DeToken(\n", + " EMBEDDED_SIZE,\n", + " TOKEN_SPACE_SIZE\n", + ")\n", + "\n", + "\n", + "# Data\n", + "TEXT = (\n", + " \"The Dark Knight is a 2008 superhero film directed by Christopher Nolan,\"\n", + ")\n", + "OUT_TEXT = \"\"\n", + "\n", + "PAD_TOKEN = 
TOKENANO.encode(\"\")[0]\n", + "END_TOKEN = TOKENANO.encode(\"\")[0]\n", + "\n", + "ENCODER_INPUT = TOKENANO.encode(TEXT)\n", + "DECODER_INPUT = TOKENANO.encode(OUT_TEXT)\n", + "MAX_LEN = len(ENCODER_INPUT) + 1\n", + "\n", + "EN_IN, PAD_MASK = Transformer.normalize_sequence(ENCODER_INPUT, MAX_LEN, PAD_TOKEN, END_TOKEN)\n", + "DEC_IN, _ = Transformer.normalize_sequence(DECODER_INPUT, MAX_LEN, PAD_TOKEN, END_TOKEN)\n", + "\n", + "BATCH_LEN = 3\n", + "\n", + "INPUT_TOKENIZATION = [\n", + " EN_IN\n", + "] * BATCH_LEN\n", + "OUTPUT_TOKENIZATION = [\n", + " DEC_IN\n", + "] * BATCH_LEN\n", + "\n", + "encoder_tensor_input = ENCODER_EMBEDDER(INPUT_TOKENIZATION)\n", + "encoder_padding_mask = torch.tensor([PAD_MASK] * BATCH_LEN)\n", + "\n", + "encoder_output, _ = ENCODER((encoder_tensor_input, encoder_padding_mask))\n", + "\n", + "decoder_tensor_input = DECODER_EMBEDDER(OUTPUT_TOKENIZATION)\n", + "decoder_padding_mask = torch.tensor([[False] * MAX_LEN] * BATCH_LEN)\n", + "\n", + "decoder_output, _, _, _ = DECODER((decoder_tensor_input, encoder_output, encoder_output, None))\n", + "\n", + "logits: torch.Tensor = DETOKENER(decoder_output)\n", + "\n", + "print(logits.shape)\n", + "\n", + "# print(logits)\n", + "\n", + "most_probable_tokens = torch.argmax(logits, 2)\n", + "\n", + "print(most_probable_tokens.shape)\n", + "print(most_probable_tokens)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deep_learning", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Playgrounds/sanity-check-pytorch.ipynb b/Playgrounds/sanity-check-pytorch.ipynb index 2ca5105..fa4e14a 100644 --- a/Playgrounds/sanity-check-pytorch.ipynb +++ b/Playgrounds/sanity-check-pytorch.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "dd23cc94", "metadata": {}, "outputs": [ @@ -22,6 +22,18 @@ "\n", "print(f\"Current detected architecture is: {DEVICE.type}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6584882e", + "metadata": {}, + "outputs": [], + "source": [ + "import Project_Model.Libs.Transformer as Transformer\n", + "DECODER = Transformer.Decoder(256, 1024, 4)\n", + "print()" + ] } ], "metadata": { diff --git a/Project_Model/Libs/BPE/Enums/SpecialToken.py b/Project_Model/Libs/BPE/Enums/SpecialToken.py index 382c032..b3c42c1 100644 --- a/Project_Model/Libs/BPE/Enums/SpecialToken.py +++ b/Project_Model/Libs/BPE/Enums/SpecialToken.py @@ -10,7 +10,6 @@ class SpecialToken(Enum): RELATIONSHIP = "" OBJECT = "" ABSTRACT = "" - CORPUS_END = "" ## Tasks' Token RDF_TO_TEXT = "" @@ -20,4 +19,6 @@ class SpecialToken(Enum): # BPE Training: # NanoSocrates - START = "" \ No newline at end of file + START = "" + CORPUS_END = "" + PAD = "" \ No newline at end of file diff --git a/Project_Model/Libs/Transformer/Classes/DeToken.py b/Project_Model/Libs/Transformer/Classes/DeToken.py new file mode 100644 index 0000000..c0b961e --- /dev/null +++ b/Project_Model/Libs/Transformer/Classes/DeToken.py @@ -0,0 +1,19 @@ +import torch + + +class DeToken(torch.nn.Module): + + def __init__(self, embedding_size: int, vocabulary_size: int) -> None: + super().__init__() + + self.__linear = 
torch.nn.Linear(embedding_size, vocabulary_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # 1) Go from latent space to vocabulary space (the raw logits) + x = self.__linear(x) + + # 2) Turn the logits into probabilities + x = torch.softmax(x, 2) + + return x diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py index 0a818ee..a9c7907 100644 --- a/Project_Model/Libs/Transformer/Classes/Decoder.py +++ b/Project_Model/Libs/Transformer/Classes/Decoder.py @@ -35,22 +35,28 @@ class Decoder(nn.Module): ) self.__layer_norm_3 = nn.LayerNorm(embedding_dimension) - - - def forward(self, x, k_x, v_x, padding_mask = None): #-> list[torch.Tensor]: # k_x = v_x . While x_q = x + def forward( + self, + args: tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor + ] + ): # -> list[torch.Tensor]: # k_x = v_x (both come from the encoder), while the query is x + # NOTE: a single tuple argument is required so the block can run inside nn.Sequential + x, k_x, v_x, padding_mask = args # build of attention mask attention_mask = get_causal_attention_mask(x.size(1)) # 1) Masked Attention MASKED_ATTENTION = self.__masked_attention( - x, x, x, key_padding_mask=padding_mask, attn_mask=attention_mask + x, x, x, key_padding_mask=padding_mask, attention_mask=attention_mask ) # 2) Dropout - DROPPED_MASKED_ATTENTION = self.__dropout( - MASKED_ATTENTION - ) + DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION) del MASKED_ATTENTION # 3) Residual Connection @@ -61,7 +67,9 @@ class Decoder(nn.Module): x = self.__layer_norm_1(x) # 5) Encoder–decoder (cross) attention - CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask) + CROSS_ATTENTION = self.__cross_attention( + x, k_x, v_x, key_padding_mask=padding_mask + ) # 6) Dropout DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION) @@ -88,7 +96,7 @@ class Decoder(nn.Module): # 12) Layer Normalization x = self.__layer_norm_3(x) - return x, k_x, v_x, padding_mask + return (x, k_x, v_x, padding_mask) # use eval to disable dropout ecc diff --git a/Project_Model/Libs/Transformer/Classes/Encoder.py b/Project_Model/Libs/Transformer/Classes/Encoder.py index 0c46fe0..e232a18 100644 --- a/Project_Model/Libs/Transformer/Classes/Encoder.py +++ b/Project_Model/Libs/Transformer/Classes/Encoder.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from Project_Model.Libs.Transformer.Classes.FeedForwardNetwork import FeedForwardNetwork from Project_Model.Libs.Transformer.Classes.TorchMultiHeadAttention import ( @@ -29,14 +30,17 @@ class Encoder( embedding_dimension ) # norm of second "Add and Normalize" self.__dropout = nn.Dropout(0.1) # ...
- pass - def forward(self, x, padding_mask = None): + + def forward(self, args: tuple[torch.Tensor, torch.Tensor]): + # NOTE: a single tuple argument is required so the block can run inside nn.Sequential + x, padding_mask = args + # -> ATTENTION -> dropout -> add and normalize -> FF -> dropout -> add and normalize -> # Attention with Residual Connection [ input + self-attention] # 1) Multi Head Attention - ATTENTION = self.__attention(x, x, x,key_padding_mask= padding_mask) + ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask) # 2) Dropout DROPPED_ATTENTION = self.__dropout(ATTENTION) @@ -62,7 +66,7 @@ class Encoder( # 8) Layer Normalization x = self.__layer_norm_2(x) - return x,padding_mask + return (x, padding_mask) # use eval to disable dropout ecc diff --git a/Project_Model/Libs/Transformer/Classes/MultiHeadAttention.py b/Project_Model/Libs/Transformer/Classes/MultiHeadAttention.py deleted file mode 100644 index 63c9a6f..0000000 --- a/Project_Model/Libs/Transformer/Classes/MultiHeadAttention.py +++ /dev/null @@ -1,24 +0,0 @@ -# multi-head attention -> (then to) ff -# attention: qkv -> score = qk -> divide -> softamx -# multihead -> QKV diferent in each head ( built by : X*[WQ/QK/WV]) -# z = soft(Q*K'/sqr(d))*V -# recombine Z: 1) concatenate. 2) [z01234] * W = Z -# we expect later to have padding token -######################## -# WIP -######################## - -import torch.nn as nn - -embed_dim = 256 -num_heads = 8 -multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - - -class MultiheadAttention: - - def __init__( - self, - num_heads=8, - ) -> None: - pass diff --git a/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py b/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py index bb2d971..fca307a 100644 --- a/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py +++ b/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py @@ -4,55 +4,109 @@ from .Encoder import Encoder from ....Libs.Embedder import NanoSocratesEmbedder import torch + class NanoSocratesCore(torch.nn.Module): - def __init__(self, - embedded_size: int, - feed_forward_dim: int, - encoder_layers: int, - decoder_layers:int, - attention_heads: int, - vocab_size: int) -> None: - + def __init__( + self, + sentence_length: int, + vocab_size: int, + embedding_size: int = 256, + feed_forward_multiplier: int = 4, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + num_attention_heads: int = 4, + ) -> None: + + super().__init__() + feed_forward_dim = embedding_size * feed_forward_multiplier + + self.__sentence_length = sentence_length self.__encoder_sequence = torch.nn.Sequential( - *[Encoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(encoder_layers)] - ) - - #* unpack the list so that each encoder has its own weights - + *[ + Encoder(embedding_size, feed_forward_dim, num_attention_heads) + for _ in range(num_encoder_layers) + ] + ) + + # * unpack the list so that each encoder has its own weights + self.__decoder_sequence = torch.nn.Sequential( - *[Decoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(decoder_layers)] + *[ + Decoder(embedding_size, feed_forward_dim, num_attention_heads) + for _ in range(num_decoder_layers) + ] + ) + + self.__linear = torch.nn.Linear(embedding_size, vocab_size) + + self.__input_embedder = NanoSocratesEmbedder(vocab_size, embedding_size) + self.__output_embedder = NanoSocratesEmbedder(vocab_size, embedding_size) + + def forward( + self, + encoder_input: list[list[int]], + decoder_input: list[list[int]], + encoder_padding_mask: list[list[int]], + ): + + if
len(encoder_padding_mask) != len(encoder_input): + raise Exception("Mismatch between encoder input and padding mask dimensions") + + # TODO: check for tensor in input to embedder + # 1) Embed User-Input for encoders + ENCODER_INPUT = self.__input_embedder(encoder_input) + + # 2) Encode User-Input + ENCODER_OUTPUT, _ = self.__encoder_sequence((ENCODER_INPUT, encoder_padding_mask)) + del ENCODER_INPUT + + exit_loop = False + decoder_token_list = decoder_input[:] + decoder_phase = 0 + + LOGITS_HISTORY: list[torch.Tensor] = [] + + # 3) Autoregressive Output + while not exit_loop: + + # 3.0) Increment Counter + decoder_phase += 1 + + # 3.1) Embed Decoder Input + decoder_embedding = self.__output_embedder(decoder_token_list) + + # 3.2) Decode Decoder Input + DECODER_OUTPUT, _, _, _ = self.__decoder_sequence( + (decoder_embedding, ENCODER_OUTPUT, ENCODER_OUTPUT, None) ) - - self.__linear = torch.nn.Linear(embedded_size, vocab_size, bias=False) - self.__input_embeder = NanoSocratesEmbedder(vocab_size,embedded_size) - self.__output_embedder = NanoSocratesEmbedder(vocab_size,embedded_size) + # 3.3) Go back to Token space + # TODO: change name + LOGITS = self.__linear(DECODER_OUTPUT) + del DECODER_OUTPUT + # 3.4) Transform into probabilities + # TODO: change name + TOKEN_PROBABILITIES = torch.softmax(LOGITS, dim=-1) + del LOGITS - def forward(self, token_list, padding_mask = None): - x = self.__input_embeder(token_list) - x = self.__encoder_sequence(x, padding_mask)[0] + LOGITS_HISTORY.append(TOKEN_PROBABILITIES) + # 3.5) Take most probable tokens + TOKEN_IDS = torch.argmax(TOKEN_PROBABILITIES, -1) - # do while - x = self.__decoder_sequence(x,x,x, padding_mask)[0] - logits = self.__linear(x) - log_prob = torch.softmax(logits, dim=-1) - output = torch.argmax(log_prob) - while self.keep_going(log_prob): - # from log_prob again into x + # TODO: check for dimensions and for efficiency + DECODER_TOKEN_TENSOR = torch.tensor(decoder_token_list) + DECODER_TOKEN_TENSOR[:, decoder_phase] = TOKEN_IDS[:, decoder_phase - 1] + decoder_token_list = DECODER_TOKEN_TENSOR.tolist() + del TOKEN_IDS + del DECODER_TOKEN_TENSOR - x = self.__decoder_sequence(x,x,x, padding_mask)[0] - logits = self.__linear(x) - log_prob = torch.softmax(logits, dim=-1) - # argmax - - return log_prob - - - - def keep_going(self, x: ) -> bool: + # 3.6) Check if we generated all tokens + if decoder_phase == self.__sentence_length - 1: + exit_loop = True + return LOGITS_HISTORY diff --git a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py index 52c0cc5..38aeb6d 100644 --- a/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py +++ b/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py @@ -8,17 +8,16 @@ class TorchMultiHeadAttention(nn.Module): self, embedding_dimension: int, number_of_attention_heads: int, - dropout: float = 0.0, + dropout: float = 0.0 ): super().__init__() - self.attention = nn.MultiheadAttention( + self.attention = torch.nn.MultiheadAttention( embedding_dimension, - number_of_attention_heads, + num_heads=number_of_attention_heads, dropout=dropout, batch_first=True, ) - def forward( self, x_q: torch.Tensor, diff --git a/Project_Model/Libs/Transformer/Classes/__init__.py b/Project_Model/Libs/Transformer/Classes/__init__.py index 837ea82..b4507f1 100644 --- a/Project_Model/Libs/Transformer/Classes/__init__.py +++ b/Project_Model/Libs/Transformer/Classes/__init__.py @@ -1,15 +1,16 @@ from .Decoder import Decoder from .Encoder import Encoder from .FeedForwardNetwork import FeedForwardNetwork
-from .MultiHeadAttention import MultiheadAttention +# from .MultiHeadAttention import MultiheadAttention from .TorchMultiHeadAttention import TorchMultiHeadAttention from .SpannedMasker import SpannedMasker +from .DeToken import DeToken __all__ = [ "Decoder", "Encoder", "FeedForwardNetwork", - "MultiheadAttention", "TorchMultiHeadAttention", - "SpannedMasker" + "SpannedMasker", + "DeToken" ] \ No newline at end of file diff --git a/Project_Model/Libs/Transformer/Utils/post_tokenization.py b/Project_Model/Libs/Transformer/Utils/post_tokenization.py index fc68363..23d5e4d 100644 --- a/Project_Model/Libs/Transformer/Utils/post_tokenization.py +++ b/Project_Model/Libs/Transformer/Utils/post_tokenization.py @@ -47,7 +47,7 @@ def normalize_sequence( ) -> tuple[list[int], list[bool]]: new_sequence = pad_sequence(sequence, max_length, pad_token) - new_sequence = truncate_sequence(sequence, max_length, end_token) - PADDING_MASK = create_padding_mask(sequence, pad_token) + new_sequence = truncate_sequence(new_sequence, max_length, end_token) + PADDING_MASK = create_padding_mask(new_sequence, pad_token) return (new_sequence, PADDING_MASK)
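The normalize_sequence hunk above matters because the three helpers must be chained on the same buffer: pad first, then truncate, and only then build the padding mask, so the mask actually sees the pad tokens. A minimal sketch of the intended pipeline, assuming plausible bodies for the three helpers (the diff does not show them, so pad_sequence, truncate_sequence, and create_padding_mask below are illustrative guesses, as is the True-marks-padding convention):

def pad_sequence(seq: list[int], max_length: int, pad_token: int) -> list[int]:
    # Right-pad up to max_length; a longer sequence is returned unchanged.
    return seq + [pad_token] * (max_length - len(seq))

def truncate_sequence(seq: list[int], max_length: int, end_token: int) -> list[int]:
    # Cut to max_length, keeping room for a closing end token.
    if len(seq) <= max_length:
        return seq
    return seq[: max_length - 1] + [end_token]

def create_padding_mask(seq: list[int], pad_token: int) -> list[bool]:
    # True marks a padding position.
    return [token == pad_token for token in seq]

sequence = [7706, 290, 756]
padded = pad_sequence(sequence, 5, pad_token=0)        # [7706, 290, 756, 0, 0]
truncated = truncate_sequence(padded, 5, end_token=1)  # unchanged, already short enough
mask = create_padding_mask(truncated, pad_token=0)     # [False, False, False, True, True]

Computing the mask on the raw sequence instead would yield [False, False, False] and silently let the encoder attend to pad positions.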
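The tuple-passing signatures in Encoder.forward and Decoder.forward exist because torch.nn.Sequential hands exactly one value from module to module, so (x, padding_mask) has to travel as a single tuple that each block unpacks and re-packs. A minimal, self-contained sketch of the pattern (TupleBlock is a hypothetical stand-in for the real attention blocks):

import torch

class TupleBlock(torch.nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        # Unpack the single argument that nn.Sequential forwards.
        x, padding_mask = args
        x = self.norm(x)  # stand-in for attention + feed-forward
        # Re-pack so the next block receives the same tuple shape.
        return (x, padding_mask)

stack = torch.nn.Sequential(TupleBlock(256), TupleBlock(256))
x = torch.randn(2, 16, 256)
mask = torch.zeros(2, 16, dtype=torch.bool)
out, _ = stack((x, mask))
print(out.shape)  # torch.Size([2, 16, 256])

This is why the notebook call sites pass one tuple, e.g. ENCODER((encoder_tensor, PAD_MASK)), rather than separate positional arguments.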
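NanoSocratesCore.forward decodes greedily: each pass runs the whole decoder stack, the token predicted at position t-1 is written into position t of the decoder input, and the loop stops after sentence_length - 1 steps. A minimal sketch of that index bookkeeping, with a random stand-in for the embedder/decoder/linear pipeline (fake_model below is hypothetical):

import torch

torch.manual_seed(0)
BATCH, SEQ_LEN, VOCAB = 2, 5, 10
tokens = torch.zeros(BATCH, SEQ_LEN, dtype=torch.long)  # position 0 holds the start token

def fake_model(tokens: torch.Tensor) -> torch.Tensor:
    # Stand-in for embedder -> decoder stack -> linear head.
    return torch.randn(BATCH, SEQ_LEN, VOCAB)

for step in range(1, SEQ_LEN):
    logits = fake_model(tokens)              # [batch, seq_len, vocab]
    next_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]
    # Shift: the prediction made at position step-1 becomes the input at position step.
    tokens[:, step] = next_ids[:, step - 1]

print(tokens)

Re-running the full decoder on every step is quadratic in sequence length; caching past states would avoid that, which the # TODO: check for dimensions and for efficiency comment appears to hint at.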