Added a way to load checkpoints
This commit is contained in:
parent
4ca1d0a189
commit
37a2501a79
@ -78,6 +78,7 @@ TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKE
|
|||||||
|
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
|
|
||||||
NANOSOCRATES = Transformer.TrainingModel(
|
NANOSOCRATES = Transformer.TrainingModel(
|
||||||
TOKEN_SPACE_SIZE,
|
TOKEN_SPACE_SIZE,
|
||||||
EMBEDDED_SIZE,
|
EMBEDDED_SIZE,
|
||||||
@ -85,6 +86,11 @@ NANOSOCRATES = Transformer.TrainingModel(
|
|||||||
ATTENTION_HEADS,
|
ATTENTION_HEADS,
|
||||||
NUMBER_OF_BLOCKS,
|
NUMBER_OF_BLOCKS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if CHECKPOINT_PATH.is_file():
|
||||||
|
nanosocrates_dict = torch.load(CHECKPOINT_PATH, weights_only=True)
|
||||||
|
NANOSOCRATES.load_state_dict(nanosocrates_dict)
|
||||||
|
|
||||||
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
|
_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
|
||||||
NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
|
NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user