This commit is contained in:
Christian Risi 2025-10-09 21:53:45 +02:00
commit db0090981c
3 changed files with 157 additions and 0 deletions

156
Playgrounds/prova.py Normal file
View File

@ -0,0 +1,156 @@
import random
import torch
import pandas as pd
from pathlib import Path
import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TorchShims as torch_shims
# set a fixed seed
torch.manual_seed(0)
random.seed(0)
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)
# set a default device
# BPE Init
VOCABULARY_PATH = Path("Assets/Model/toy_10/toy_dictionary.json")
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)
# Constants
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + 1
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 2
MAX_EPOCHS = int(1e3)
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
# Load CSV
TOY_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
TOY_DATASET = pd.read_csv(TOY_DATASET_PATH)
TOY_BATCH_INPUT_LIST: list[list[int]] = []
TOY_BATCH_PADDING_LIST: list[list[bool]] = []
TOY_BATCH_TARGET_LIST: list[list[int]] = []
TOY_BATCH_DECODER_DEFAULT: list[list[int]]= []
for index, row in TOY_DATASET.iterrows():
RDFs: str = row["RDFs"]
Abstract: str = row["Abstract"]
input_tokens = TOKENANO.encode(RDFs)
output_tokens = TOKENANO.encode(Abstract)[1:]
decoder_default_tokens = TOKENANO.encode("<SOS>")
input_tokens, padding = Transformer.normalize_sequence(
input_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
)
output_tokens, _ = Transformer.normalize_sequence(
output_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN
)
decoder_default_tokens, _ = Transformer.normalize_sequence(
decoder_default_tokens, SENTENCE_LENGTH, PAD_TOKEN, END_TOKEN, False
)
TOY_BATCH_INPUT_LIST.append(input_tokens)
TOY_BATCH_PADDING_LIST.append(padding)
TOY_BATCH_TARGET_LIST.append(output_tokens)
TOY_BATCH_DECODER_DEFAULT.append(decoder_default_tokens)
# Training loop
LOSS_HISTORY = []
NANOSOCRATES = Transformer.TrainingModel(
TOKEN_SPACE_SIZE,
EMBEDDED_SIZE,
FEED_FORWARD_MULTIPLIER,
ATTENTION_HEADS,
NUMBER_OF_BLOCKS
)
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
scheduler = Transformer.WarmupLR(optimizer, 4000, EMBEDDED_SIZE)
last_loss = 0
current_epoch = 0
while current_epoch < MAX_EPOCHS:
optimizer.zero_grad()
encoder_list = torch.tensor([TOY_BATCH_INPUT_LIST[0]])
decoder_list = torch.tensor([TOY_BATCH_DECODER_DEFAULT[0]])
src_padding = torch.tensor([TOY_BATCH_PADDING_LIST[0]], dtype=torch.bool)
# Transform target into logits
target_logits = torch.tensor([TOY_BATCH_TARGET_LIST[0]])
last_loss = 0
loss_list = []
last_prediction: torch.Tensor
for i in range(0, SENTENCE_LENGTH):
optimizer.zero_grad()
tgt_padding = decoder_list.eq(PAD_TOKEN)
logits: torch.Tensor = NANOSOCRATES((encoder_list, src_padding, decoder_list, tgt_padding))
prob = torch.softmax(logits, 2)
most_probable_tokens = torch.argmax(prob, 2)
last_prediction = most_probable_tokens
logits = logits[:,:i,:]
logits = logits.permute(0, 2, 1)
loss : torch.Tensor = cross_entropy(logits, target_logits[:, :i])
# loss : torch.Tensor = cross_entropy(logits, target_logits)
last_loss = loss
loss_list.append(loss.item())
loss.backward()
optimizer.step()
scheduler.step()
if i < SENTENCE_LENGTH - 1:
decoder_list[:,i+1] = target_logits[:,i]
current_epoch += 1
if current_epoch % 1 == 0:
loss_list = loss_list[1:]
print(f"EPOCH {current_epoch}\n\tLoss: {last_loss}")
print(f"ALL LOSS HISTORY:{loss_list}")
print(f"Max loss:{max(loss_list)}, Min loss: {min(loss_list)}")
for encoded_sentence, expected_sentence in zip(
Transformer.tensor2token(last_prediction[:,:], END_TOKEN), # type: ignore
Transformer.tensor2token(target_logits[:,:], END_TOKEN)
):
decoded_sentence = TOKENANO.decode(encoded_sentence)
decoded_target = TOKENANO.decode(expected_sentence)
print(f"ACTUAL:\n\t{decoded_sentence}\nEXPECTED:\n\t{decoded_target}")

Binary file not shown.

View File

@ -16,3 +16,4 @@ urllib3==2.5.0
wheel==0.45.1
Wikipedia-API==0.8.1
SQLAlchemy
torch