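"""Toy training run for the NanoSocrates encoder-decoder Transformer.

Loads the toy RDF-to-abstract CSV, tokenizes it with the TokeNano BPE
vocabulary, and trains the model one target token at a time with teacher
forcing, printing the loss and the decoded predictions after every epoch.
"""
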
import random
import torch
import pandas as pd
from pathlib import Path

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TorchShims as torch_shims
from Project_Model.Libs.Training.learning_rade_shedulers import Custom_lr
from Project_Model.Libs.Training.logistic_collector import LogitsCollector  # import the external collector

# set a fixed seed for reproducibility
torch.manual_seed(0)
random.seed(0)

DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# BPE init: load the vocabulary and build the tokenizer
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
SPECIAL_VOC = BPE.default_special_tokens()

VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1  # one extra id beyond the BPE vocabulary
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256  # every sequence is padded/trimmed to this length
NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = int(1e3)

PAD_TOKEN = TOKENANO.encode("<PAD>")[0]  # id used for padding
END_TOKEN = TOKENANO.encode("<END>")[0]  # id that marks the end of a sequence

# Load the toy CSV dataset
TOY_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)

TOY_BATCH_INPUT_LIST: list[list[int]] = []       # encoder token ids
TOY_BATCH_PADDING_LIST: list[list[bool]] = []    # encoder padding masks
TOY_BATCH_TARGET_LIST: list[list[int]] = []      # decoder target ids
TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []  # initial decoder input (<SOS> then PAD)

for _, row in TOY_DATASET.iterrows():
    RDFs: str = row["RDFs"]
    Abstract: str = row["Abstract"]

    input_tokens = TOKENANO.encode(RDFs)  # encoder input ids
    output_tokens = TOKENANO.encode(Abstract)[1:]  # decoder target ids (shifted left by one)
    decoder_default_tokens = TOKENANO.encode("<SOS>")  # decoder input starts with <SOS>

    input_tokens, padding = Transformer.normalize_sequence(
        input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )  # pad/trim to SENTENCE_LENGTH and append the end token
    output_tokens, _ = Transformer.normalize_sequence(
        output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )  # pad/trim to SENTENCE_LENGTH and append the end token
    decoder_default_tokens = Transformer.pad_sequence(
        decoder_default_tokens, SENTENCE_LENGTH, PAD_TOKEN
    )  # pad with PAD up to SENTENCE_LENGTH
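
    # Illustration of the intended result (assuming normalize_sequence appends <END> and then
    # pads, and that the returned mask is True at PAD positions):
    #   [t0, t1, ..., tk] -> ([t0, t1, ..., tk, END, PAD, ..., PAD],
    #                         [False, ..., False, True, ..., True])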

    TOY_BATCH_INPUT_LIST.append(input_tokens)
    TOY_BATCH_PADDING_LIST.append(padding)
    TOY_BATCH_TARGET_LIST.append(output_tokens)
    TOY_BATCH_DECODER_DEFAULT.append(decoder_default_tokens)

# Model and training setup
LOSS_HISTORY: list[float] = []
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

collector = LogitsCollector(PAD_TOKEN, END_TOKEN, TOKENANO)  # collects logits and decodes them

NANOSOCRATES.train()
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)  # PAD targets are ignored by the loss
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
scheduler = Custom_lr(EMBEDDED_SIZE, 4000)  # warm-up scheduler, stepped once per optimizer step
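# Assumption: if Custom_lr follows the standard Noam warm-up schedule, the learning rate
# per optimizer step would be roughly
#   lr(step) = EMBEDDED_SIZE ** -0.5 * min(step ** -0.5, step * 4000 ** -1.5)
# i.e. linear warm-up over the first 4000 steps, then inverse-square-root decay.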

current_epoch = 0
BATCH_SIZE = min(32, len(TOY_BATCH_INPUT_LIST))  # small fixed batch to keep the toy run stable

# Training loop
while current_epoch < MAX_EPOCHS:
    # simple fixed mini-batch from the top; later this can be shuffled/sliced
    enc = torch.tensor(TOY_BATCH_INPUT_LIST[:BATCH_SIZE], dtype=torch.long)    # [B,T] encoder token ids
    pad = torch.tensor(TOY_BATCH_PADDING_LIST[:BATCH_SIZE], dtype=torch.bool)  # [B,T] True where the encoder input is PAD
    tgt = torch.tensor(TOY_BATCH_TARGET_LIST[:BATCH_SIZE], dtype=torch.long)   # [B,T] decoder targets (ground truth)

    # decoder prefix buffer: <SOS> at position 0, PAD elsewhere; filled step by step below
    dec = torch.tensor(TOY_BATCH_DECODER_DEFAULT[:BATCH_SIZE], dtype=torch.long)  # [B,T]

    total_loss = 0.0
    collector.reset()  # start fresh for this epoch

    T = tgt.size(1)  # sequence length
    for t in range(T):
        optimizer.zero_grad(set_to_none=True)  # clear grads for this token step

        prefix = dec[:, : t + 1]             # [B, t+1] current decoder prefix
        dec_pad_mask = prefix.eq(PAD_TOKEN)  # [B, t+1] True where the prefix is PAD

        # one-step logits given the prefix (the training model takes all four tensors)
        logits_t: torch.Tensor = NANOSOCRATES((enc, pad, prefix, dec_pad_mask))  # [B,V] logits for step t
        collector.add(logits_t)  # store logits for decoding later

        loss_t = cross_entropy(logits_t, tgt[:, t])  # CE expects raw logits; PAD targets ignored

        # Skip the update when every target at this position is PAD: with all targets
        # ignored the mean loss is NaN and would poison the gradients.
        if not torch.isnan(loss_t):
            loss_t.backward()   # backprop for this step
            optimizer.step()    # update parameters
            scheduler.step()    # warm-up schedule: one step per optimizer step

            total_loss = float(loss_t.detach())  # keep the last valid step loss for logging

        # teacher forcing: reveal the correct token for the next position
        if t < T - 1:
            dec[:, t + 1] = tgt[:, t]  # write the ground-truth token into the next slot

    current_epoch += 1
    LOSS_HISTORY.append(total_loss)  # keep the loss curve for later inspection
    print(f"EPOCH {current_epoch}\n\tLoss: {total_loss:.6f}")  # simple log
    collector.print_decoded()  # print decoded predictions for the batch