# Snapshot metadata (file-viewer paste artifact, kept as a comment):
# 2025-10-08 22:51:36 +02:00 — 126 lines — 5.4 KiB — Python

import random
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TorchShims as torch_shims
from Project_Model.Libs.Training.learning_rade_shedulers import Custom_lr
from Project_Model.Libs.Training.logistic_collector import LogitsCollector # import the external collector
# set a fixed seed so runs are reproducible (both torch and python RNGs)
torch.manual_seed(0)
random.seed(0)
# route all new tensors to the project's preferred device by default
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)
# BPE Init: load the toy vocabulary and build the tokenizer around it
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)
# Constants
# NOTE(review): +1 presumably reserves one extra id beyond the vocabulary — confirm
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1
EMBEDDED_SIZE = 256            # model embedding width
FEED_FORWARD_MULTIPLIER = 4    # FFN hidden size = multiplier * EMBEDDED_SIZE
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256          # fixed sequence length after padding/trimming
NUMBER_OF_BLOCKS = 2           # encoder/decoder block count
MAX_EPOCHS = int(1e3)
# special-token ids; encode() returns a list, first element is the id
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
# Load CSV with the toy (RDF, abstract) training pairs
TOY_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)
# per-example token/mask rows, filled by the dataset loop below
TOY_BATCH_INPUT_LIST: list[list[int]] = []
TOY_BATCH_PADDING_LIST: list[list[bool]] = []
TOY_BATCH_TARGET_LIST: list[list[int]] = []
TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []
# Tokenize every (RDFs, Abstract) pair and normalize it to fixed-length
# sequences; appends one row per example to the four batch lists above.
for _, row in TOY_DATASET.iterrows():
    RDFs: str = row["RDFs"]
    Abstract: str = row["Abstract"]
    input_tokens = TOKENANO.encode(RDFs)  # encoder input ids
    output_tokens = TOKENANO.encode(Abstract)[1:]  # decoder target ids (shifted left)
    decoder_default_tokens = TOKENANO.encode("<SOS>")  # decoder input starts with <SOS>
    # pad/trim to SENTENCE_LENGTH and append the end token; also returns the pad mask
    input_tokens, padding = Transformer.normalize_sequence(
        input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )
    output_tokens, _mask = Transformer.normalize_sequence(
        output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )
    # decoder prefix buffer: <SOS> followed by PAD up to SENTENCE_LENGTH
    decoder_default_tokens = Transformer.pad_sequence(
        decoder_default_tokens, SENTENCE_LENGTH, PAD_TOKEN
    )
    TOY_BATCH_INPUT_LIST.append(input_tokens)
    TOY_BATCH_PADDING_LIST.append(padding)
    TOY_BATCH_TARGET_LIST.append(output_tokens)
    TOY_BATCH_DECODER_DEFAULT.append(decoder_default_tokens)
# Training loop: stepwise teacher forcing — one backward/optimizer update per
# decoder time-step, then the ground-truth token is revealed for the next step.
LOSS_HISTORY = []
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
collector = LogitsCollector(PAD_TOKEN, END_TOKEN, TOKENANO)  # collects logits and decodes
NANOSOCRATES.train()
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)  # PAD targets do not contribute
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
scheduler = Custom_lr(EMBEDDED_SIZE, 4000)  # warmup schedule; stepped once per optimizer step
current_epoch = 0
BATCH_SIZE = min(32, len(TOY_BATCH_INPUT_LIST))  # small batch to stabilize

# The encoder inputs, encoder pad mask, targets and sequence length never change
# between epochs, so build them once instead of re-tensorizing every iteration.
enc = torch.tensor(TOY_BATCH_INPUT_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] encoder token ids
pad = torch.tensor(TOY_BATCH_PADDING_LIST[:BATCH_SIZE], dtype=torch.bool)  # [B,T] True where encoder PAD
tgt = torch.tensor(TOY_BATCH_TARGET_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] decoder targets
T = tgt.size(1)  # sequence length

while current_epoch < MAX_EPOCHS:
    # decoder prefix buffer: <SOS> at pos 0, PAD elsewhere. Teacher forcing
    # writes into it below, so it must be rebuilt fresh each epoch.
    dec = torch.tensor(TOY_BATCH_DECODER_DEFAULT[:BATCH_SIZE], dtype=torch.long)  # [B,T]
    total_loss = 0.0
    collector.reset()  # start fresh for this epoch
    for t in range(T):
        optimizer.zero_grad(set_to_none=True)  # clear grads for this token step
        prefix = dec[:, : t + 1]  # [B, t+1] current decoder prefix
        dec_pad_mask = prefix.eq(PAD_TOKEN)  # [B, t+1] True where PAD inside prefix
        # one-step logits given prefix (trainer model expects a 4-tuple)
        logits_t: torch.Tensor = NANOSOCRATES((enc, pad, prefix, dec_pad_mask))  # [B,V] logits for step t
        collector.add(logits_t)  # store logits for decoding later
        loss_t = cross_entropy(logits_t, tgt[:, t])  # CE expects raw logits; PAD ignored
        loss_t.backward()  # backprop for this step
        optimizer.step()  # update params
        scheduler.step()  # warmup: one scheduler step per optimizer step
        # NOTE: keeps only the LAST step's loss for logging (not a sum/average)
        total_loss = float(loss_t.detach())
        # teacher forcing: reveal the correct token for next position
        if t < T - 1:
            dec[:, t + 1] = tgt[:, t]  # write ground-truth into next slot
    current_epoch += 1
    LOSS_HISTORY.append(total_loss)  # was declared but never filled; track per-epoch loss
    print(f"EPOCH {current_epoch}\n\tLoss: {total_loss:.6f}")  # simple log
    collector.print_decoded()  # print decoded predictions for the batch