Added a way to load checkpoints

This commit is contained in:
Christian Risi 2025-10-12 12:28:24 +02:00
parent 4ca1d0a189
commit 37a2501a79

View File

@ -78,6 +78,7 @@ TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKE
# Model
NANOSOCRATES = Transformer.TrainingModel(
TOKEN_SPACE_SIZE,
EMBEDDED_SIZE,
@ -85,6 +86,11 @@ NANOSOCRATES = Transformer.TrainingModel(
ATTENTION_HEADS,
NUMBER_OF_BLOCKS,
)
if CHECKPOINT_PATH.is_file():
nanosocrates_dict = torch.load(CHECKPOINT_PATH, weights_only=True)
NANOSOCRATES.load_state_dict(nanosocrates_dict)
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)