import torch

from Project_Model.Libs.Embedder import NanoSocratesEmbedder

from ..Classes import Decoder, DeToken, Encoder
from ..Enums import ModelType
from ..Models import NanoSocraDecoder, NanoSocratEncoder, TrainingModel


def decompose_nano_socrates(
    model: TrainingModel, vocabulary_size: int, embedding_size: int
) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
    """Split a combined ``TrainingModel`` into standalone encoder/decoder models.

    Args:
        model: The trained composite model to take apart via ``take_pieces()``.
        vocabulary_size: Output vocabulary size for the encoder's new DeToken head.
        embedding_size: Latent/embedding width feeding the encoder's new DeToken head.

    Returns:
        A 3-tuple ``(model, encoder_model, decoder_model)`` — the original model
        plus the two standalone wrappers built from its pieces.
    """
    encoder_pieces, decoder_pieces = model.take_pieces()

    encoder_embedder, encoder = encoder_pieces
    # The encoder pieces carry no output head of their own: attach a freshly
    # initialized (untrained) projection back to vocabulary space.
    # NOTE(review): presumably callers fine-tune or only use encoder hidden
    # states — confirm the fresh head is intended.
    encoder_detokener = DeToken(embedding_size, vocabulary_size)

    # The decoder side already ships with its trained detokener.
    decoder_embedder, decoder, decoder_detokener = decoder_pieces

    return (
        model,
        NanoSocratEncoder(encoder_embedder, encoder, encoder_detokener),
        NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
    )


def create_standalone_model(
    model_type: ModelType,
    vocabulary_size: int,
    latent_space: int = 256,
    feed_forward_multiplier: int = 4,
    attention_heads: int = 4,
    layer_number: int = 2,
) -> NanoSocratEncoder | NanoSocraDecoder:
    """Build a fresh standalone encoder-only or decoder-only model.

    Args:
        model_type: ``ModelType.ENCODER_ONLY`` yields a ``NanoSocratEncoder``;
            any other value yields a ``NanoSocraDecoder``.
        vocabulary_size: Token vocabulary size for embedder and DeToken head.
        latent_space: Model width (embedding / hidden dimension).
        feed_forward_multiplier: Feed-forward width = ``latent_space`` times this.
        attention_heads: Number of attention heads per layer.
        layer_number: Number of stacked encoder/decoder layers.

    Returns:
        The assembled standalone model (embedder -> layer stack -> DeToken).
    """
    feed_forward_latent_space = latent_space * feed_forward_multiplier

    embedder = NanoSocratesEmbedder(vocabulary_size, latent_space)
    detokener = DeToken(latent_space, vocabulary_size)

    if model_type == ModelType.ENCODER_ONLY:
        # BUG FIX: the original used `[Encoder(...)] * layer_number`, which
        # repeats the SAME module object, so every "layer" shared one set of
        # weights. Instantiate an independent layer per iteration instead.
        encoder = torch.nn.Sequential(
            *(
                Encoder(latent_space, feed_forward_latent_space, attention_heads)
                for _ in range(layer_number)
            )
        )
        return NanoSocratEncoder(embedder, encoder, detokener)

    # Same fix as above for the decoder stack: one fresh module per layer.
    decoder = torch.nn.Sequential(
        *(
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        )
    )
    return NanoSocraDecoder(embedder, decoder, detokener)