import torch
from Project_Model.Libs.Embedder import NanoSocratesEmbedder
from Project_Model.Libs.Transformer import TrainingModel, NanoSocraDecoder, NanoSocratEncoder, DeToken, Encoder, Decoder
from .ModelType import ModelType


def decompose_nano_socrates(
    model: TrainingModel, vocabulary_size: int, embedding_size: int
) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
    # NOTE: vocabulary_size and embedding_size are currently unused.
    encoder_pieces, decoder_pieces = model.take_pieces()
    encoder_embedder, encoder, encoder_detokener = encoder_pieces
    decoder_embedder, decoder, decoder_detokener = decoder_pieces

    return (
        model,
        NanoSocratEncoder(encoder_embedder, encoder, encoder_detokener),
        NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
    )
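

# Usage sketch (hypothetical: assumes `trained` is a TrainingModel produced by
# the training loop elsewhere in the project, and that 32_000 / 256 match the
# sizes it was built with):
#
#     _, encoder_only, decoder_only = decompose_nano_socrates(trained, 32_000, 256)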


def create_standalone_model(
    model_type: ModelType,
    vocabulary_size: int,
    latent_space: int = 256,
    feed_forward_multiplier: int = 4,
    attention_heads: int = 4,
    layer_number: int = 2,
) -> NanoSocratEncoder | NanoSocraDecoder:
    feed_forward_latent_space = latent_space * feed_forward_multiplier

    embedder = NanoSocratesEmbedder(vocabulary_size, latent_space)
    detokener = DeToken(latent_space, vocabulary_size)
    if model_type == ModelType.ENCODER_ONLY:
        # Build layer_number independent Encoder blocks. A comprehension is
        # required here: `[Encoder(...)] * layer_number` would repeat one
        # block object, so every "layer" would share the same weights.
        encoder_layers = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        encoder = torch.nn.Sequential(*encoder_layers)
        return NanoSocratEncoder(embedder, encoder, detokener)

    # Otherwise build a decoder-only model the same way, again with
    # independent Decoder blocks.
    decoder_layers = [
        Decoder(latent_space, feed_forward_latent_space, attention_heads)
        for _ in range(layer_number)
    ]
    decoder = torch.nn.Sequential(*decoder_layers)
    return NanoSocraDecoder(embedder, decoder, detokener)
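

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming NanoSocratEncoder is a regular
    # torch.nn.Module; run with `python -m` since this module uses a relative
    # import): build an encoder-only model and report its parameter count.
    demo = create_standalone_model(ModelType.ENCODER_ONLY, vocabulary_size=32_000)
    parameter_count = sum(p.numel() for p in demo.parameters())
    print(f"{type(demo).__name__}: {parameter_count:,} parameters")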