"""Toy training script for the NanoSocrates encoder-decoder Transformer.

Loads ``Assets/Dataset/1-hop/toy/rdf_text.csv`` (columns ``RDFs`` and
``Abstract``), tokenizes it with the toy BPE vocabulary, and trains
``Transformer.TrainingModel`` on a fixed mini-batch with per-token teacher
forcing.  For each target position ``t`` the decoder sees the ground-truth
prefix.  An optimizer step is taken on that position's loss only when at
least one sample has a non-PAD target.  The true token is then revealed for
position ``t + 1``.
"""

import random
from pathlib import Path

import pandas as pd
import torch

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TorchShims as torch_shims
from Project_Model.Libs.Training.learning_rade_shedulers import Custom_lr
from Project_Model.Libs.Training.logistic_collector import LogitsCollector  # import the external collector

# set a fixed seed
torch.manual_seed(0)
random.seed(0)
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# BPE Init
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = int(1e3)
# NOTE(review): PAD_TOKEN and END_TOKEN are built from the same literal (""),
# so they resolve to the same id, and encode("") may not return any token at
# all. Presumably these literals should be the PAD / END special-token
# strings from BPE.default_special_tokens(); confirm and restore them.
PAD_TOKEN = TOKENANO.encode("")[0]
END_TOKEN = TOKENANO.encode("")[0]

# Load CSV
TOY_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)

TOY_BATCH_INPUT_LIST: list[list[int]] = []
TOY_BATCH_PADDING_LIST: list[list[bool]] = []
TOY_BATCH_TARGET_LIST: list[list[int]] = []
TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []

# Decoder input starts with this token sequence; it is the same for every
# row, so it is encoded once.
# NOTE(review): like PAD/END above, the literal looks like a lost start-token
# string; confirm.
DECODER_START_TOKENS = TOKENANO.encode("")

for rdfs, abstract in zip(TOY_DATASET["RDFs"], TOY_DATASET["Abstract"]):
    input_tokens = TOKENANO.encode(rdfs)  # encoder input ids
    # decoder target ids: the abstract shifted left (first token dropped)
    output_tokens = TOKENANO.encode(abstract)[1:]

    input_tokens, padding = Transformer.normalize_sequence(
        input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )  # pad/trim + end token
    output_tokens, _ = Transformer.normalize_sequence(
        output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )  # pad/trim + end token
    # Copy so each row gets its own list, as when it was encoded per row.
    decoder_default_tokens = Transformer.pad_sequence(
        list(DECODER_START_TOKENS), SENTENCE_LENGTH, PAD_TOKEN
    )  # pad with PAD up to SENTENCE_LENGTH

    TOY_BATCH_INPUT_LIST.append(input_tokens)
    TOY_BATCH_PADDING_LIST.append(padding)
    TOY_BATCH_TARGET_LIST.append(output_tokens)
    TOY_BATCH_DECODER_DEFAULT.append(decoder_default_tokens)

if not TOY_BATCH_INPUT_LIST:
    raise ValueError(f"No training rows found in {TOY_DATASET_PATH}")

# Training loop
LOSS_HISTORY: list[float] = []  # mean per-token training loss of each epoch
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
collector = LogitsCollector(PAD_TOKEN, END_TOKEN, TOKENANO)  # collects logits and decodes
NANOSOCRATES.train()
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
# NOTE(review): Custom_lr is not given the optimizer, so it is unclear from
# here how scheduler.step() changes the optimizer's learning rate; confirm.
scheduler = Custom_lr(EMBEDDED_SIZE, 4000)  # step each optimizer step
current_epoch = 0
BATCH_SIZE = min(32, len(TOY_BATCH_INPUT_LIST))  # small batch to stabilize

# Simple fixed mini-batch from the top (later you can shuffle/slice). It is
# identical every epoch, so the tensors are built once, outside the loop.
enc = torch.tensor(TOY_BATCH_INPUT_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] encoder token ids
pad = torch.tensor(TOY_BATCH_PADDING_LIST[:BATCH_SIZE], dtype=torch.bool)  # [B,T] encoder padding mask from normalize_sequence
tgt = torch.tensor(TOY_BATCH_TARGET_LIST[:BATCH_SIZE], dtype=torch.long)  # [B,T] decoder targets (ground-truth)
# Pristine decoder buffer (start token(s), then PAD). Teacher forcing writes
# into the buffer in place, so each epoch works on a fresh clone.
dec_template = torch.tensor(TOY_BATCH_DECODER_DEFAULT[:BATCH_SIZE], dtype=torch.long)  # [B,T]
T = tgt.size(1)  # sequence length

while current_epoch < MAX_EPOCHS:
    dec = dec_template.clone()  # [B,T] filled step by step below
    epoch_loss_sum = 0.0
    trained_steps = 0
    collector.reset()  # start fresh for this epoch

    for t in range(T):
        step_target = tgt[:, t]  # [B] ground-truth token for position t
        # If every target at this position is PAD, CrossEntropyLoss (mean
        # reduction, ignore_index=PAD_TOKEN) has nothing to average and
        # returns NaN. Backpropagating it gives no useful gradient, yet
        # AdamW would still apply momentum/weight decay and the warmup
        # schedule would advance. Only update on positions with a target.
        has_target = bool(step_target.ne(PAD_TOKEN).any())

        prefix = dec[:, : t + 1]  # [B, t+1] current decoder prefix
        dec_pad_mask = prefix.eq(PAD_TOKEN)  # [B, t+1] True where PAD inside prefix

        # Build an autograd graph only when it will be backpropagated.
        with torch.set_grad_enabled(has_target):
            # One-step logits given the prefix; the model takes a 4-tuple.
            # Shape [B,V], as required by the CE call against [B] targets.
            logits_t: torch.Tensor = NANOSOCRATES((enc, pad, prefix, dec_pad_mask))
        collector.add(logits_t)  # store logits for decoding later

        if has_target:
            optimizer.zero_grad(set_to_none=True)  # clear grads for this token step
            loss_t = cross_entropy(logits_t, step_target)  # CE expects raw logits; PAD ignored
            loss_t.backward()
            optimizer.step()
            scheduler.step()  # Noam/warmup: step per optimizer step
            epoch_loss_sum += loss_t.item()
            trained_steps += 1

        # teacher forcing: reveal the correct token for the next position
        if t < T - 1:
            dec[:, t + 1] = step_target

    current_epoch += 1
    # Mean loss over the positions that were actually trained this epoch.
    total_loss = epoch_loss_sum / trained_steps if trained_steps else float("nan")
    LOSS_HISTORY.append(total_loss)
    print(f"EPOCH {current_epoch}\n\tLoss: {total_loss:.6f}")  # simple log
    collector.print_decoded()  # print decoded predictions for the batch