222 lines
9.5 KiB
Plaintext
222 lines
9.5 KiB
Plaintext
|
|
{
|
||
|
|
"cells": [
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"id": "c8741a8f",
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"EPOCH 1\n",
|
||
|
|
"\tLoss: 7.424792\n",
|
||
|
|
"[0] \n",
|
||
|
|
"[1] \n",
|
||
|
|
"[2] \n",
|
||
|
|
"[3] \n",
|
||
|
|
"[4] \n",
|
||
|
|
"[5] \n",
|
||
|
|
"[6] \n",
|
||
|
|
"[7] \n",
|
||
|
|
"[8] \n",
|
||
|
|
"[9] \n"
|
||
|
|
]
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"import random\n",
|
||
|
|
"import torch\n",
|
||
|
|
"import pandas as pd\n",
|
||
|
|
"from pathlib import Path\n",
|
||
|
|
"import Project_Model.Libs.Embedder as Embedder\n",
|
||
|
|
"import Project_Model.Libs.BPE as BPE\n",
|
||
|
|
"import Project_Model.Libs.Transformer as Transformer\n",
|
||
|
|
"import Project_Model.Libs.TorchShims as torch_shims\n",
|
||
|
|
"from Project_Model.Libs.Training.learning_rade_shedulers import Custom_lr\n",
|
||
|
|
"\n",
|
||
|
|
"import torch\n",
|
||
|
|
"\n",
|
||
|
|
"class LogitsCollector:\n",
|
||
|
|
" def __init__(self, pad_token: int, end_token: int, tokenizer) -> None:\n",
|
||
|
|
" self.__pad_token = pad_token # used to skip PAD\n",
|
||
|
|
" self.__end_token = end_token # used to stop at END\n",
|
||
|
|
" self.__tokenizer = tokenizer # exposes .decode(list[int]) -> str\n",
|
||
|
|
" self.__steps: list[torch.Tensor] = [] # list of per-step logits [B,V]\n",
|
||
|
|
"\n",
|
||
|
|
" def reset(self) -> None:\n",
|
||
|
|
" self.__steps.clear() # clear history\n",
|
||
|
|
"\n",
|
||
|
|
" def add(self, logits_step: torch.Tensor) -> None:\n",
|
||
|
|
" if logits_step.dim() == 3: # handle [B,T,V] (including [B,1,V]): keep only the last position\n",
|
||
|
|
" logits_step = logits_step[:, -1, :] # -> [B,V]\n",
|
||
|
|
" self.__steps.append(logits_step.detach()) # store raw logits (detached)\n",
|
||
|
|
"\n",
|
||
|
|
" def tokens(self) -> list[list[int]]:\n",
|
||
|
|
" if not self.__steps:\n",
|
||
|
|
" return []\n",
|
||
|
|
" stack = torch.stack(self.__steps, dim=0) # [T,B,V]\n",
|
||
|
|
" probs = torch.softmax(stack, dim=-1) # softmax over vocab -> [T,B,V]\n",
|
||
|
|
" ids = probs.argmax(dim=-1).transpose(0, 1) # greedy ids -> [B,T]\n",
|
||
|
|
" out: list[list[int]] = []\n",
|
||
|
|
" for row in ids.tolist():\n",
|
||
|
|
" seq: list[int] = []\n",
|
||
|
|
" for tok in row:\n",
|
||
|
|
" if tok == self.__end_token: # stop on END\n",
|
||
|
|
" break\n",
|
||
|
|
" if tok == self.__pad_token: # skip PAD\n",
|
||
|
|
" continue\n",
|
||
|
|
" seq.append(tok)\n",
|
||
|
|
" out.append(seq)\n",
|
||
|
|
" return out\n",
|
||
|
|
"\n",
|
||
|
|
" def print_decoded(self) -> None:\n",
|
||
|
|
" for i, seq in enumerate(self.tokens()):\n",
|
||
|
|
" try:\n",
|
||
|
|
" text = self.__tokenizer.decode(seq) # decode tokens to string\n",
|
||
|
|
" except Exception:\n",
|
||
|
|
" text = str(seq) # fallback to ids\n",
|
||
|
|
" print(f\"[{i}] {text}\") # simple print\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"# set a fixed seed\n",
|
||
|
|
"torch.manual_seed(0)\n",
|
||
|
|
"random.seed(0)\n",
|
||
|
|
"DEVICE = torch_shims.get_default_device()\n",
|
||
|
|
"torch.set_default_device(DEVICE)\n",
|
||
|
|
"\n",
|
||
|
|
"# BPE Init\n",
|
||
|
|
"VOCABULARY_PATH = Path(\"Assets/Model/toy_10/toy_dictionary.json\")\n",
|
||
|
|
"SPECIAL_VOC = BPE.default_special_tokens()\n",
|
||
|
|
"\n",
|
||
|
|
"VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)\n",
|
||
|
|
"TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)\n",
|
||
|
|
"\n",
|
||
|
|
"# Constants\n",
|
||
|
|
"TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1\n",
|
||
|
|
"EMBEDDED_SIZE = 256\n",
|
||
|
|
"FEED_FORWARD_MULTIPLIER = 4\n",
|
||
|
|
"ATTENTION_HEADS = 4\n",
|
||
|
|
"SENTENCE_LENGTH = 256\n",
|
||
|
|
"NUMBER_OF_BLOCKS = 2\n",
|
||
|
|
"MAX_EPOCHS = int(1e3)\n",
|
||
|
|
"\n",
|
||
|
|
"PAD_TOKEN = TOKENANO.encode(\"<PAD>\")[0]\n",
|
||
|
|
"END_TOKEN = TOKENANO.encode(\"<END>\")[0]\n",
|
||
|
|
"\n",
|
||
|
|
"# Load CSV\n",
|
||
|
|
"TOY_DATASET_PATH = Path(\"Assets/Dataset/1-hop/toy/rdf_text.csv\")\n",
|
||
|
|
"TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)\n",
|
||
|
|
"\n",
|
||
|
|
"TOY_BATCH_INPUT_LIST: list[list[int]] = []\n",
|
||
|
|
"TOY_BATCH_PADDING_LIST: list[list[bool]] = []\n",
|
||
|
|
"TOY_BATCH_TARGET_LIST: list[list[int]] = []\n",
|
||
|
|
"TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []\n",
|
||
|
|
"\n",
|
||
|
|
"for index, row in TOY_DATASET.iterrows():\n",
|
||
|
|
" RDFs: str = row[\"RDFs\"]\n",
|
||
|
|
" Abstract: str = row[\"Abstract\"]\n",
|
||
|
|
"\n",
|
||
|
|
" input_tokens = TOKENANO.encode(RDFs) # encoder input ids\n",
|
||
|
|
" output_tokens = TOKENANO.encode(Abstract)[1:] # decoder target ids (shifted left)\n",
|
||
|
|
" decoder_default_tokens = TOKENANO.encode(\"<SOS>\") # decoder input starts with <SOS>\n",
|
||
|
|
"\n",
|
||
|
|
" input_tokens, padding = Transformer.normalize_sequence(\n",
|
||
|
|
" input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN\n",
|
||
|
|
" ) # pad/trim + end token\n",
|
||
|
|
" output_tokens, _ = Transformer.normalize_sequence(\n",
|
||
|
|
" output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN\n",
|
||
|
|
" ) # pad/trim + end token\n",
|
||
|
|
" decoder_default_tokens = Transformer.pad_sequence(\n",
|
||
|
|
" decoder_default_tokens, SENTENCE_LENGTH, PAD_TOKEN\n",
|
||
|
|
" ) # pad with PAD up to SENTENCE_LENGTH\n",
|
||
|
|
"\n",
|
||
|
|
" TOY_BATCH_INPUT_LIST.append(input_tokens)\n",
|
||
|
|
" TOY_BATCH_PADDING_LIST.append(padding)\n",
|
||
|
|
" TOY_BATCH_TARGET_LIST.append(output_tokens)\n",
|
||
|
|
" TOY_BATCH_DECODER_DEFAULT.append(decoder_default_tokens)\n",
|
||
|
|
"\n",
|
||
|
|
"# Training loop\n",
|
||
|
|
"LOSS_HISTORY = []\n",
|
||
|
|
"NANOSOCRATES = Transformer.TrainingModel(\n",
|
||
|
|
" TOKEN_SPACE_SIZE,\n",
|
||
|
|
" EMBEDDED_SIZE,\n",
|
||
|
|
" FEED_FORWARD_MULTIPLIER,\n",
|
||
|
|
" ATTENTION_HEADS,\n",
|
||
|
|
" NUMBER_OF_BLOCKS,\n",
|
||
|
|
")\n",
|
||
|
|
"\n",
|
||
|
|
"collector = LogitsCollector(PAD_TOKEN, END_TOKEN, TOKENANO) # collects logits and decodes\n",
|
||
|
|
"\n",
|
||
|
|
"NANOSOCRATES.train()\n",
|
||
|
|
"cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)\n",
|
||
|
|
"optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())\n",
|
||
|
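|
"scheduler = Custom_lr(EMBEDDED_SIZE, 4000) # stepped once per optimizer step; NOTE(review): constructed without the optimizer - confirm it actually updates the optimizer's learning rate\n",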
|
||
|
|
"\n",
|
||
|
|
"current_epoch = 0\n",
|
||
|
|
"BATCH_SIZE = min(32, len(TOY_BATCH_INPUT_LIST)) # small batch to stabilize\n",
|
||
|
|
"\n",
|
||
|
|
"while current_epoch < MAX_EPOCHS:\n",
|
||
|
|
" # simple fixed mini-batch from the top; later you can shuffle/slice\n",
|
||
|
|
" enc = torch.tensor(TOY_BATCH_INPUT_LIST[:BATCH_SIZE], dtype=torch.long) # [B,T] encoder token ids\n",
|
||
|
|
" pad = torch.tensor(TOY_BATCH_PADDING_LIST[:BATCH_SIZE], dtype=torch.bool) # [B,T] True where encoder PAD is present\n",
|
||
|
|
" tgt = torch.tensor(TOY_BATCH_TARGET_LIST[:BATCH_SIZE], dtype=torch.long) # [B,T] decoder targets (ground-truth)\n",
|
||
|
|
"\n",
|
||
|
|
" # decoder prefix buffer: <SOS> at pos 0, PAD elsewhere (no shift here); it is filled in step by step below\n",
|
||
|
|
" dec = torch.tensor(TOY_BATCH_DECODER_DEFAULT[:BATCH_SIZE], dtype=torch.long) # [B,T]\n",
|
||
|
|
"\n",
|
||
|
|
" total_loss = 0.0\n",
|
||
|
|
" collector.reset() # start fresh for this epoch\n",
|
||
|
|
"\n",
|
||
|
|
" T = tgt.size(1) # sequence length\n",
|
||
|
|
" for t in range(T):\n",
|
||
|
|
" optimizer.zero_grad(set_to_none=True) # clear grads for this token step\n",
|
||
|
|
"\n",
|
||
|
|
" prefix = dec[:, : t + 1] # [B, t+1] current decoder prefix\n",
|
||
|
|
" dec_pad_mask = prefix.eq(PAD_TOKEN) # [B, t+1] True where PAD inside prefix\n",
|
||
|
|
"\n",
|
||
|
|
" # one-step logits given prefix (trainer model takes one 4-tuple: encoder ids, encoder pad mask, decoder prefix, decoder pad mask)\n",
|
||
|
|
" logits_t: torch.Tensor = NANOSOCRATES((enc, pad, prefix, dec_pad_mask)) # [B,V] logits for step t\n",
|
||
|
|
" collector.add(logits_t) # store logits for decoding later\n",
|
||
|
|
"\n",
|
||
|
|
" loss_t = cross_entropy(logits_t, tgt[:, t]) # CE expects raw logits; PAD ignored\n",
|
||
|
|
" loss_t.backward() # backprop for this step\n",
|
||
|
|
" optimizer.step() # update params\n",
|
||
|
|
" scheduler.step() # Noam/warmup: step per optimizer step\n",
|
||
|
|
"\n",
|
||
|
|
" total_loss = float(loss_t.detach()) # overwritten every step, not accumulated: holds only the last step's loss for logging\n",
|
||
|
|
"\n",
|
||
|
|
" # teacher forcing: reveal the correct token for next position\n",
|
||
|
|
" if t < T - 1:\n",
|
||
|
|
" dec[:, t + 1] = tgt[:, t] # write ground-truth into next slot\n",
|
||
|
|
"\n",
|
||
|
|
" current_epoch += 1\n",
|
||
|
|
" print(f\"EPOCH {current_epoch}\\n\\tLoss: {total_loss:.6f}\") # simple log\n",
|
||
|
|
" collector.print_decoded() # print decoded predictions for the batch\n"
|
||
|
|
]
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"metadata": {
|
||
|
|
"kernelspec": {
|
||
|
|
"display_name": "deep_learning",
|
||
|
|
"language": "python",
|
||
|
|
"name": "python3"
|
||
|
|
},
|
||
|
|
"language_info": {
|
||
|
|
"codemirror_mode": {
|
||
|
|
"name": "ipython",
|
||
|
|
"version": 3
|
||
|
|
},
|
||
|
|
"file_extension": ".py",
|
||
|
|
"mimetype": "text/x-python",
|
||
|
|
"name": "python",
|
||
|
|
"nbconvert_exporter": "python",
|
||
|
|
"pygments_lexer": "ipython3",
|
||
|
|
"version": "3.13.7"
|
||
|
|
}
|
||
|
|
},
|
||
|
|
"nbformat": 4,
|
||
|
|
"nbformat_minor": 5
|
||
|
|
}
|