import random
import sys

import torch
import pandas as pd
from pathlib import Path

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch

# set a fixed seed
torch.manual_seed(0)
random.seed(0)

# set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

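# With a default device set (PyTorch >= 2.0), factory calls such as
# torch.tensor(...) below allocate directly on DEVICE, so the batches never
# need an explicit .to(DEVICE).
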
# Get paths
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")

# BPE Init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(1e3)
PRETRAIN_EPOCHS = 10
WARMUP_EPOCHS = int(4e3)
MINI_BATCH_SIZE = 100
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0

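# Note: MASK_EXTRA_SPACE reserves ids above the real vocabulary, presumably
# for the mask sentinels that SpannedMasker emits (T5-style span corruption),
# which is why the model's output layer is sized TOKEN_SPACE_SIZE rather than
# REAL_TOKEN_SPACE_SIZE.
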
SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<PRED>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS

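# The masker may still corrupt the RDF structure markers (<SUBJ>, <PRED>,
# <OBJ>), but must never select the remaining special tokens, hence
# FORBIDDEN_TOKENS.
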
# Spanned_Masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)

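# Each Batcher yields (src_x, tgt_y, pad_x, pad_y, tasktype) tuples: token-id
# matrices of shape [batch, SENTENCE_LENGTH], matching boolean padding masks,
# and the task the mini-batch was built for, as unpacked in the training loop
# below.
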
# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)

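# decompose_nano_socrates is assumed to return modules that share weights with
# NANOSOCRATES, so updates made through ENCODER_ONLY / DECODER_ONLY also move
# the full encoder-decoder model.
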
# Training constants
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters())
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters())
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters())

nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)

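# WarmupLR is assumed to implement the "Attention Is All You Need" schedule,
# roughly lr = EMBEDDED_SIZE**-0.5 * min(step**-0.5, step * WARMUP_EPOCHS**-1.5).
# Note the schedulers are stepped once per epoch below, so with
# WARMUP_EPOCHS = 4000 and MAX_EPOCHS = 1000 the learning rate never leaves
# the warmup ramp.
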
current_epoch = CURRENT_EPOCH
patience = 0

# best validation losses seen so far (lower is better)
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}

while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x_pad = dec_x.eq(PAD_TOKEN)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        print(f"\tBATCH {batch_counter} Starting")

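        # Tasks 1 and 2 run the full encoder-decoder with teacher forcing:
        # each position gets its own forward pass and optimiser step, and the
        # ground-truth token is then fed to the next decoder position, so one
        # mini-batch costs SENTENCE_LENGTH forward/backward passes.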
        # Task 1 and Task 2
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            BATCH_LOSS = []

            for token_idx in range(0, SENTENCE_LENGTH):

                nano_optim.zero_grad()

                pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))

                # score only the current position
                pred_logits = pred_logits[:, token_idx, :]

                loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt[:, token_idx])

                loss.backward()
                nano_optim.step()

                BATCH_LOSS.append(loss.item())

                # teacher forcing: reveal the ground-truth token to the next position
                if token_idx < SENTENCE_LENGTH - 1:
                    dec_x[:, token_idx + 1] = tgt[:, token_idx]

            MIN_BATCH_LOSS = min(BATCH_LOSS)
            MAX_BATCH_LOSS = max(BATCH_LOSS)
            # BATCH_LOSS holds one entry per position, not per sample, so
            # average over its actual length
            AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

            text_batch_losses.append([MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS])
            continue

        # Pretrain first: only the two text<->RDF tasks run during pretraining
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3
        if tasktype == Batch.TaskType.MASKING:

            print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            # CrossEntropyLoss expects [batch, classes, positions] for sequences
            pred_logits = pred_logits.permute(0, 2, 1)
            print(torch.max(tgt))  # debug: largest target id, must stay below TOKEN_SPACE_SIZE
            loss: torch.Tensor = encoder_ce(pred_logits, tgt)

            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())

            continue

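        # The masking task trains the encoder alone: the targets are presumably
        # the uncorrupted sequences, and encoder_ce's ignore_index keeps PAD
        # positions out of the loss.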
        # Task 4
        if tasktype == Batch.TaskType.COMPLETATION:

            print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            BATCH_LOSS = []

            for token_idx in range(0, SENTENCE_LENGTH):

                decoder_only_optim.zero_grad()

                pred_logits = DECODER_ONLY((enc_x, enc_x_pad))

                pred_logits = pred_logits[:, token_idx, :]

                loss: torch.Tensor = decoder_ce(pred_logits, tgt[:, token_idx])

                loss.backward()
                decoder_only_optim.step()

                BATCH_LOSS.append(loss.item())

                # NOTE: DECODER_ONLY consumes enc_x, so this teacher-forcing
                # write to dec_x has no effect on this task
                if token_idx < SENTENCE_LENGTH - 1:
                    dec_x[:, token_idx + 1] = tgt[:, token_idx]

            MIN_BATCH_LOSS = min(BATCH_LOSS)
            MAX_BATCH_LOSS = max(BATCH_LOSS)
            AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

            decoder_batch_losses.append(
                [MIN_BATCH_LOSS, AVG_BATCH_LOSS, MAX_BATCH_LOSS]
            )

            continue

    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

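        # NOTE: the validation passes below still build autograd graphs;
        # wrapping this block in torch.no_grad() would avoid that without
        # changing the metrics.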
        txt_avg_batch_losses = []
        enc_avg_batch_losses = []
        dec_avg_batch_losses = []

        for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

            src_x, tgt_y, pad_x, pad_y, tasktype = batch

            enc_x = torch.tensor(src_x)
            # enc_x is [batch, SENTENCE_LENGTH], exactly as in the training loop
            ACTUAL_BATCH_SIZE, _ = enc_x.shape
            enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
            dec_x = Transformer.get_decoder_input(
                ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
            )
            dec_x_pad = dec_x.eq(PAD_TOKEN)
            tgt = torch.tensor(tgt_y)
            tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

            # Task 1 and Task 2
            if (
                tasktype == Batch.TaskType.RDF2TXT
                or tasktype == Batch.TaskType.TEXT2RDF
            ):

                BATCH_LOSS = []

                for token_idx in range(0, SENTENCE_LENGTH):

                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits[:, token_idx, :]
                    loss: torch.Tensor = nano_cross_entropy(
                        pred_logits, tgt[:, token_idx]
                    )
                    BATCH_LOSS.append(loss.item())

                    if token_idx < SENTENCE_LENGTH - 1:
                        dec_x[:, token_idx + 1] = tgt[:, token_idx]

                AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)
                txt_avg_batch_losses.append(AVG_BATCH_LOSS)

                continue

            # Pretrain first: nothing else to validate during pretraining
            if current_epoch < PRETRAIN_EPOCHS:
                continue

            # Task 3
            if tasktype == Batch.TaskType.MASKING:

                pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                pred_logits = pred_logits.permute(0, 2, 1)

                loss: torch.Tensor = encoder_ce(pred_logits, tgt)

                enc_avg_batch_losses.append(loss.item())

                continue

            # Task 4
            if tasktype == Batch.TaskType.COMPLETATION:

                BATCH_LOSS = []

                for token_idx in range(0, SENTENCE_LENGTH):

                    pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits[:, token_idx, :]
                    loss: torch.Tensor = decoder_ce(pred_logits, tgt[:, token_idx])
                    BATCH_LOSS.append(loss.item())

                    if token_idx < SENTENCE_LENGTH - 1:
                        dec_x[:, token_idx + 1] = tgt[:, token_idx]

                AVG_BATCH_LOSS = sum(BATCH_LOSS) / len(BATCH_LOSS)

                dec_avg_batch_losses.append(AVG_BATCH_LOSS)

                continue

        # validation averages (encoder/decoder-only stay inf until pretraining ends)
        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
        enc_avg_loss = float("inf")
        dec_avg_loss = float("inf")

        if current_epoch >= PRETRAIN_EPOCHS:
            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

        # early stopping: track the best validation loss per task and lose
        # patience when results get worse
        if current_epoch < PRETRAIN_EPOCHS:

            if txt_avg_loss < average_loss_validation["txt"]:
                average_loss_validation["txt"] = txt_avg_loss
            else:
                patience += 1

        else:

            counter = 0

            # compare each task's validation loss with its own best so far
            if txt_avg_loss > average_loss_validation["txt"]:
                counter += 1
            else:
                average_loss_validation["txt"] = txt_avg_loss

            if enc_avg_loss > average_loss_validation["encoder_only"]:
                counter += 1
            else:
                average_loss_validation["encoder_only"] = enc_avg_loss

            if dec_avg_loss > average_loss_validation["decoder_only"]:
                counter += 1
            else:
                average_loss_validation["decoder_only"] = dec_avg_loss

            if counter > 1:
                patience += 1

        # per-epoch training statistics, kept separate from the validation
        # averages above
        txt_min_train_losses = [row[0] for row in text_batch_losses]
        txt_avg_train_losses = [row[1] for row in text_batch_losses]
        txt_max_train_losses = [row[2] for row in text_batch_losses]

        txt_min_loss = min(txt_min_train_losses)
        txt_avg_min_loss = sum(txt_min_train_losses) / len(txt_min_train_losses)
        txt_max_loss = max(txt_max_train_losses)
        txt_avg_max_loss = sum(txt_max_train_losses) / len(txt_max_train_losses)
        txt_avg_train_loss = sum(txt_avg_train_losses) / len(txt_avg_train_losses)

        enc_avg_train_loss = float("inf")

        dec_min_loss = float("inf")
        dec_avg_min_loss = float("inf")
        dec_max_loss = float("inf")
        dec_avg_max_loss = float("inf")
        dec_avg_train_loss = float("inf")

        if current_epoch >= PRETRAIN_EPOCHS:
            enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)

            dec_min_train_losses = [row[0] for row in decoder_batch_losses]
            dec_avg_train_losses = [row[1] for row in decoder_batch_losses]
            dec_max_train_losses = [row[2] for row in decoder_batch_losses]

            dec_min_loss = min(dec_min_train_losses)
            dec_avg_min_loss = sum(dec_min_train_losses) / len(dec_min_train_losses)
            dec_max_loss = max(dec_max_train_losses)
            dec_avg_max_loss = sum(dec_max_train_losses) / len(dec_max_train_losses)
            dec_avg_train_loss = sum(dec_avg_train_losses) / len(dec_avg_train_losses)

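        # The min/avg/max figures below summarise the per-position losses
        # collected across this epoch's training batches.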
        SEPARATOR = "================================================================================================================"
        DEBUG_TEXT = "".join(
            [
                f"{SEPARATOR}\n",
                f"EPOCH {current_epoch}\n",
                f"{SEPARATOR}\n",
                f"Train Losses:\n",
                f"\tMin Losses:\n",
                f"\t\tmin_txt: {txt_min_loss} - avg_txt: {txt_avg_min_loss}\n",
                f"\t\tmin_dec: {dec_min_loss} - avg_dec: {dec_avg_min_loss}\n",
                f"\tMax Losses:\n",
                f"\t\tmax_txt: {txt_max_loss} - avg_txt: {txt_avg_max_loss}\n",
                f"\t\tmax_dec: {dec_max_loss} - avg_dec: {dec_avg_max_loss}\n",
                f"\tAvg Losses:\n",
                f"\t\tavg_txt: {txt_avg_train_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
                f"{SEPARATOR}\n",
                f"Validation Losses:\n",
                f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
                f"{SEPARATOR}\n",
            ]
        )

        print(DEBUG_TEXT)

        # Warn about patience
        if patience == PATIENCE:
            print("Model is likely overfitting, so let's stop here")

        # SAVE MODEL
        if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
            print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
            torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)

        # stop once patience runs out, right after the final checkpoint
        if patience == PATIENCE:
            break