import random
import sys

import torch
import pandas as pd
from pathlib import Path

import Project_Model.Libs.Embedder as Embedder
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch

# Set a fixed seed for reproducibility
torch.manual_seed(0)
random.seed(0)

# Set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# Paths
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
CHECKPOINT_PATH = Path("Assets/Dataset/Tmp/NanoSocrates.zip")

# BPE init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)

# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 8
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 4
MAX_EPOCHS = int(3e3)
PRETRAIN_EPOCHS = 300
WARMUP_EPOCHS = int(1e3)
MINI_BATCH_SIZE = 80
VALIDATION_STEPS = 5
CHECKPOINT_STEPS = VALIDATION_STEPS * 4
PATIENCE = 4
CURRENT_EPOCH = 0
VERBOSE = True
LEARNING_RATE = 1.5
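
# Note: TOKEN_SPACE_SIZE reserves MASK_EXTRA_SPACE ids above the real BPE
# vocabulary. The assumption is that SpannedMasker draws its sentinel mask
# tokens from that extra range, which is why the model's output layer must
# cover TOKEN_SPACE_SIZE rather than REAL_TOKEN_SPACE_SIZE.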

SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<END>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<PRED>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = {SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN}
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS

# Spanned masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

# Batchers
TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
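
# Each Batcher yields mini-batches of the form
#   (src_tokens, tgt_tokens, src_padding, tgt_padding, tasktype)
# as unpacked in the loops below; how the RDF2TXT, TEXT2RDF, MASKING and
# COMPLETATION task types are interleaved is left to Batch.Batcher.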

# Model
NANOSOCRATES = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)
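
# Assumption: decompose_nano_socrates returns encoder-only and decoder-only
# views that share parameters with NANOSOCRATES, so stepping any of the three
# optimizers below also advances the shared weights. This would explain why
# only NANOSOCRATES.state_dict() is checkpointed at the end of the loop.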

# Training constants
nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)

nano_scheduler = Transformer.WarmupLR(nano_optim, WARMUP_EPOCHS, EMBEDDED_SIZE)
encoder_only_scheduler = Transformer.WarmupLR(
    encoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
decoder_only_scheduler = Transformer.WarmupLR(
    decoder_only_optim, WARMUP_EPOCHS, EMBEDDED_SIZE
)
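
# LEARNING_RATE = 1.5 is not an absolute step size: WarmupLR is assumed to
# implement the Noam schedule from "Attention Is All You Need", roughly
#
#   lr(step) = LEARNING_RATE * EMBEDDED_SIZE**-0.5
#              * min(step**-0.5, step * WARMUP_EPOCHS**-1.5)
#
# which for EMBEDDED_SIZE = 256 and WARMUP_EPOCHS = 1000 peaks near 3e-3.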

current_epoch = CURRENT_EPOCH
patience = 0

# Best validation losses observed so far, per task family
average_loss_validation = {
    "txt": float("inf"),
    "encoder_only": float("inf"),
    "decoder_only": float("inf"),
}
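
# Main loop: one pass over TRAIN_BATCHER per epoch, validation every
# VALIDATION_STEPS epochs, checkpointing every CHECKPOINT_STEPS epochs,
# and a patience counter for early stopping.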
while current_epoch < MAX_EPOCHS:

    NANOSOCRATES.train()
    ENCODER_ONLY.train()
    DECODER_ONLY.train()

    text_batch_losses = []
    encoder_batch_losses = []
    decoder_batch_losses = []

    batch_counter = 0

    if VERBOSE:
        print(f"EPOCH {current_epoch} STARTING")

    for batch in TRAIN_BATCHER.batch(MINI_BATCH_SIZE):

        batch_counter += 1

        src_x, tgt_y, pad_x, pad_y, tasktype = batch

        enc_x = torch.tensor(src_x)
        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        # Teacher forcing: the decoder input is the target shifted right by
        # one position, starting with <SOS> and padded with <PAD>
        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)

        if VERBOSE:
            print(f"\tBATCH {batch_counter} Starting")

        # Task 1 and Task 2: full encoder-decoder on RDF-to-text and text-to-RDF
        if tasktype == Batch.TaskType.RDF2TXT or tasktype == Batch.TaskType.TEXT2RDF:

            if VERBOSE:
                print(f"\tExecuting TASK 1 or 2 - BATCH {batch_counter}")

            nano_optim.zero_grad()

            pred_logits: torch.Tensor = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
            # CrossEntropyLoss expects (batch, classes, seq), so move the
            # vocabulary dimension next to the batch dimension
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)

            loss.backward()
            nano_optim.step()

            # Store a float, not the graph-holding tensor
            text_batch_losses.append(loss.item())
            continue

        # During pretraining only Tasks 1 and 2 are trained; skip the rest
        if current_epoch < PRETRAIN_EPOCHS:
            continue

        # Task 3: masked-span denoising with the encoder-only model
        if tasktype == Batch.TaskType.MASKING:

            if VERBOSE:
                print(f"\tExecuting TASK 3 - BATCH {batch_counter}")

            encoder_only_optim.zero_grad()

            pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = encoder_ce(pred_logits, tgt)

            loss.backward()
            encoder_only_optim.step()

            encoder_batch_losses.append(loss.item())

            continue

        # Task 4: sequence completion with the decoder-only model
        if tasktype == Batch.TaskType.COMPLETATION:

            if VERBOSE:
                print(f"\tExecuting TASK 4 - BATCH {batch_counter}")

            decoder_only_optim.zero_grad()

            pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
            pred_logits = pred_logits.permute(0, 2, 1)

            loss: torch.Tensor = decoder_ce(pred_logits, tgt)

            loss.backward()
            decoder_only_optim.step()

            decoder_batch_losses.append(loss.item())

            continue

    # One scheduler step per epoch
    nano_scheduler.step()
    encoder_only_scheduler.step()
    decoder_only_scheduler.step()

    current_epoch += 1

    if current_epoch % VALIDATION_STEPS == 0:

        NANOSOCRATES.eval()
        ENCODER_ONLY.eval()
        DECODER_ONLY.eval()

        with torch.no_grad():

            txt_avg_batch_losses = []
            enc_avg_batch_losses = []
            dec_avg_batch_losses = []

            for batch in VALIDATION_BATCHER.batch(MINI_BATCH_SIZE):

                src_x, tgt_y, pad_x, pad_y, tasktype = batch

                enc_x = torch.tensor(src_x)
                ACTUAL_BATCH_SIZE, _ = enc_x.shape
                enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
                tgt = torch.tensor(tgt_y)
                tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

                dec_x = Transformer.get_decoder_input(
                    ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
                )
                dec_x[:, 1:] = tgt[:, :-1]
                dec_x_pad = dec_x.eq(PAD_TOKEN)

                # Task 1 and Task 2
                if (
                    tasktype == Batch.TaskType.RDF2TXT
                    or tasktype == Batch.TaskType.TEXT2RDF
                ):
                    pred_logits = NANOSOCRATES((enc_x, enc_x_pad, dec_x, dec_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = nano_cross_entropy(pred_logits, tgt)

                    txt_avg_batch_losses.append(loss.item())
                    continue

                # During pretraining only Tasks 1 and 2 are validated;
                # current_epoch was already incremented above, hence <=
                if current_epoch <= PRETRAIN_EPOCHS:
                    continue

                # Task 3
                if tasktype == Batch.TaskType.MASKING:
                    pred_logits = ENCODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = encoder_ce(pred_logits, tgt)

                    enc_avg_batch_losses.append(loss.item())
                    continue

                # Task 4
                if tasktype == Batch.TaskType.COMPLETATION:
                    pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
                    pred_logits = pred_logits.permute(0, 2, 1)

                    loss: torch.Tensor = decoder_ce(pred_logits, tgt)

                    dec_avg_batch_losses.append(loss.item())
                    continue

        txt_avg_loss = sum(txt_avg_batch_losses) / len(txt_avg_batch_losses)
        enc_avg_loss = float("inf")
        dec_avg_loss = float("inf")

        # Guard against empty loss lists (e.g. right at the pretrain boundary)
        if current_epoch > PRETRAIN_EPOCHS and enc_avg_batch_losses:
            enc_avg_loss = sum(enc_avg_batch_losses) / len(enc_avg_batch_losses)
        if current_epoch > PRETRAIN_EPOCHS and dec_avg_batch_losses:
            dec_avg_loss = sum(dec_avg_batch_losses) / len(dec_avg_batch_losses)

        if current_epoch <= PRETRAIN_EPOCHS:

            if txt_avg_loss < average_loss_validation["txt"]:
                average_loss_validation["txt"] = txt_avg_loss
            else:
                patience += 1

        else:

            counter = 0

            # Compare each task's validation loss against its best value so
            # far, updating the best on improvement
            if txt_avg_loss > average_loss_validation["txt"]:
                counter += 1
            else:
                average_loss_validation["txt"] = txt_avg_loss

            if enc_avg_loss > average_loss_validation["encoder_only"]:
                counter += 1
            else:
                average_loss_validation["encoder_only"] = enc_avg_loss

            if dec_avg_loss > average_loss_validation["decoder_only"]:
                counter += 1
            else:
                average_loss_validation["decoder_only"] = dec_avg_loss

            # Lose patience when the majority of tasks regressed
            if counter > 1:
                patience += 1

        txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
        enc_avg_train_loss = float("inf")
        dec_avg_train_loss = float("inf")

        if current_epoch > PRETRAIN_EPOCHS:
            try:
                enc_avg_train_loss = sum(encoder_batch_losses) / len(encoder_batch_losses)
                dec_avg_train_loss = sum(decoder_batch_losses) / len(decoder_batch_losses)
            except ZeroDivisionError:
                pass

        SEPARATOR = "=" * 112
        DEBUG_TEXT = "".join(
            [
                f"{SEPARATOR}\n",
                f"EPOCH {current_epoch}\n",
                f"{SEPARATOR}\n",
                "Train Losses:\n",
                "\tAvg Losses:\n",
                f"\t\tavg_txt: {txt_train_avg_loss} - avg_enc: {enc_avg_train_loss} - avg_dec: {dec_avg_train_loss}\n",
                f"{SEPARATOR}\n",
                "Validation Losses:\n",
                f"\ttxt_loss: {txt_avg_loss} - masking_loss: {enc_avg_loss} - prediction: {dec_avg_loss}\n",
                f"{SEPARATOR}\n",
            ]
        )

        print(DEBUG_TEXT)

    # Warn about patience
    if patience == PATIENCE:
        print("Model is likely overfitting, so let's stop here")

    # SAVE MODEL
    if current_epoch % CHECKPOINT_STEPS == 0 or patience == PATIENCE:
        print(f"Saving model at {CHECKPOINT_PATH.as_posix()}")
        torch.save(NANOSOCRATES.state_dict(), CHECKPOINT_PATH)

    # Stop once patience is exhausted, as announced above
    if patience == PATIENCE:
        break
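
# To resume from the checkpoint later (a sketch; assumes the model is rebuilt
# with identical hyper-parameters):
#
#   model = Transformer.TrainingModel(
#       TOKEN_SPACE_SIZE, EMBEDDED_SIZE, FEED_FORWARD_MULTIPLIER,
#       ATTENTION_HEADS, NUMBER_OF_BLOCKS,
#   )
#   model.load_state_dict(torch.load(CHECKPOINT_PATH))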