In [None]:
import random
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TorchShims as torch_shims

# Reproducibility: seed Python's and torch's random number generators
random.seed(0)
torch.manual_seed(0)

# Default device: take whatever torch_shims reports and make it torch's default
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# Tokenizer setup: load the toy vocabulary from disk, then build the BPE
# tokenizer from that vocabulary plus the library's default special tokens
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)

SPECIAL_VOC = BPE.default_special_tokens()
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)


# Hyperparameters
# Model width and the feed-forward width derived from it
EMBEDDED_SIZE = 256
FEED_FORWARD_DIM = 4 * EMBEDDED_SIZE
ATTENTION_HEADS = 4
NUMBER_OF_BLOCKS = 2

# Every token sequence is normalized to this length
SENTENCE_LENGTH = 256
MAX_EPOCHS = 1_000

# One more than the tokenizer's reported vocabulary size
# (presumably reserves an extra id -- TODO confirm what the +1 is for)
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1


# Model Init
# Separate embedder instances for the encoder side and the decoder side,
# both sized for the same token space.
ENCODER_EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)
DECODER_EMBEDDER = Embedder.NanoSocratesEmbedder(TOKEN_SPACE_SIZE, EMBEDDED_SIZE)

# Build one *distinct* module per block. `[module] * n` repeats the very same
# instance n times, so every block in the Sequential would share a single set
# of weights instead of being an independent layer.
ENCODERS = [
    Transformer.Encoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, ATTENTION_HEADS)
    for _ in range(NUMBER_OF_BLOCKS)
]
ENCODER = torch.nn.Sequential(*ENCODERS)

DECODERS = [
    Transformer.Decoder(EMBEDDED_SIZE, FEED_FORWARD_DIM, ATTENTION_HEADS)
    for _ in range(NUMBER_OF_BLOCKS)
]
DECODER = torch.nn.Sequential(*DECODERS)

# Applied to the decoder output to obtain per-token scores
# (presumably projects EMBEDDED_SIZE -> TOKEN_SPACE_SIZE -- TODO confirm).
DETOKENER = Transformer.DeToken(EMBEDDED_SIZE, TOKEN_SPACE_SIZE)

# Ids of the special tokens: first id produced by encoding each marker string.
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]


# Dataset: read the toy RDF/abstract CSV and turn every row into
# fixed-length token lists for the encoder, the target and the decoder seed
TOY_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)

TOY_BATCH_INPUT_LIST: list[list[int]] = []
TOY_BATCH_PADDING_LIST: list[list[bool]] = []
TOY_BATCH_TARGET_LIST: list[list[int]] = []
TOY_BATCH_DECODER_DEFAULT: list[list[int]] = []


def _fit_to_sentence(tokens):
    """Call Transformer.normalize_sequence with this notebook's length and special tokens."""
    return Transformer.normalize_sequence(
        tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
    )


for rdf_text, abstract_text in zip(TOY_DATASET["RDFs"], TOY_DATASET["Abstract"]):

    # Tokenize the source triples, the target abstract and the "<SOS>" marker
    encoded_rdf = TOKENANO.encode(rdf_text)
    encoded_abstract = TOKENANO.encode(abstract_text)
    encoded_start = TOKENANO.encode("<SOS>")

    # Only the encoder side keeps the second value returned by normalize_sequence
    fitted_rdf, rdf_padding = _fit_to_sentence(encoded_rdf)
    fitted_abstract, _ = _fit_to_sentence(encoded_abstract)
    fitted_start, _ = _fit_to_sentence(encoded_start)

    TOY_BATCH_INPUT_LIST.append(fitted_rdf)
    TOY_BATCH_PADDING_LIST.append(rdf_padding)
    TOY_BATCH_TARGET_LIST.append(fitted_abstract)
    TOY_BATCH_DECODER_DEFAULT.append(fitted_start)

# Training loop
# Per-step loss values, stored as plain Python floats.
LOSS_HISTORY = []
# Group every sub-module so one optimizer receives all of their parameters.
NANOSOCRATES = torch.nn.ModuleList([
    ENCODER_EMBEDDER,
    ENCODER,
    DECODER_EMBEDDER,
    DECODER,
    DETOKENER
])
# NOTE(review): the targets come out of normalize_sequence, which presumably
# fills them with PAD_TOKEN; if so, consider
# CrossEntropyLoss(ignore_index=PAD_TOKEN) so padding does not dominate the
# loss -- confirm before changing.
cross_entropy = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
# scheduler = torch.optim.lr_scheduler.LRScheduler(optimizer)
last_loss = 0
current_epoch = 0

# The batch is identical in every epoch, so build its views/tensors once
# instead of re-creating them at the top of each epoch.
INPUT_LIST = TOY_BATCH_INPUT_LIST[:]
TARGET_LIST = TOY_BATCH_TARGET_LIST[:]
DECODER_DEFAULT_LIST = TOY_BATCH_DECODER_DEFAULT[:]
# Despite the name, these are target token ids (class indices), not logits.
target_logits = torch.tensor(TARGET_LIST)
PADDINGS = torch.tensor(TOY_BATCH_PADDING_LIST, dtype=torch.bool)

while current_epoch < MAX_EPOCHS:

    # NOTE(review): most_probable_tokens is never fed back into the decoder
    # input, so each of these SENTENCE_LENGTH iterations trains on exactly
    # the same inputs -- confirm whether autoregressive decoding was intended.
    for _ in range(0, SENTENCE_LENGTH):

        optimizer.zero_grad()

        # Embed inside the step: the embedders are trained by the optimizer,
        # so their outputs must be recomputed after every optimizer.step().
        # Reusing outputs computed once per epoch would back-propagate through
        # an already-consumed graph and ignore the updated embedder weights.
        ENCODER_INPUTS = ENCODER_EMBEDDER(INPUT_LIST)
        DECODER_INPUTS = DECODER_EMBEDDER(DECODER_DEFAULT_LIST)

        encoder_output, _ = ENCODER((ENCODER_INPUTS, PADDINGS))

        decoder_output, _, _, _ = DECODER(
            (DECODER_INPUTS, encoder_output, encoder_output, None)
        )

        logits: torch.Tensor = DETOKENER(decoder_output)
        # CrossEntropyLoss expects the class dimension at index 1.
        logits = logits.permute(0, 2, 1)

        loss: torch.Tensor = cross_entropy(logits, target_logits)

        loss.backward()
        optimizer.step()
        # scheduler.step()

        # Keep only the float value: storing the tensor itself would keep a
        # device tensor (and its autograd node) alive for every single step.
        last_loss = loss.item()
        LOSS_HISTORY.append(last_loss)

        # After the permute the class scores sit on dim 1, so that is the
        # axis to arg-max over to get one token id per sequence position.
        most_probable_tokens = torch.argmax(logits, 1)

    if current_epoch % 10 == 0:

        print(f"EPOCH {current_epoch}\n\tCurrent Loss: {last_loss}")

    current_epoch += 1








  return func(*args, **kwargs)


EPOCH 0
	Current Loss: 8.951058387756348
EPOCH 10
	Current Loss: 8.913984298706055
EPOCH 20
	Current Loss: 8.911956787109375
EPOCH 30
	Current Loss: 8.911856651306152
EPOCH 40
	Current Loss: 8.911840438842773
EPOCH 50
	Current Loss: 8.911835670471191
EPOCH 60
	Current Loss: 8.911831855773926
EPOCH 70
	Current Loss: 8.91179084777832
EPOCH 80
	Current Loss: 8.899038314819336
EPOCH 90
	Current Loss: 8.898558616638184


KeyboardInterrupt: 