diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py index 48a804c..d845fe0 100644 --- a/Playgrounds/nanosocrates-train-experiment-2.py +++ b/Playgrounds/nanosocrates-train-experiment-2.py @@ -78,6 +78,7 @@ TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKE # Model + NANOSOCRATES = Transformer.TrainingModel( TOKEN_SPACE_SIZE, EMBEDDED_SIZE, @@ -85,6 +86,11 @@ NANOSOCRATES = Transformer.TrainingModel( ATTENTION_HEADS, NUMBER_OF_BLOCKS, ) + +if CHECKPOINT_PATH.is_file(): + nanosocrates_dict = torch.load(CHECKPOINT_PATH, weights_only=True) + NANOSOCRATES.load_state_dict(nanosocrates_dict) + _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates( NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE )