Fixed imports
Project_Model/Libs/TransformerUtils/ModelType.py (new file)
@@ -0,0 +1,6 @@
from enum import Enum, auto

class ModelType(Enum):

    ENCODER_ONLY = auto()
    DECODER_ONLY = auto()
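For context, a minimal sketch of how the enum is meant to be used, mirroring the dispatch in model_utils.py below; the import path assumes Project_Model is importable from the repository root:

    from Project_Model.Libs.TransformerUtils import ModelType

    # auto() gives each member a distinct opaque value; dispatch by comparing members.
    kind = ModelType.ENCODER_ONLY
    if kind == ModelType.ENCODER_ONLY:
        print("build an encoder-only model")
    else:
        print("build a decoder-only model")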
Project_Model/Libs/TransformerUtils/__init__.py (new file)
@@ -0,0 +1,8 @@
from .model_utils import decompose_nano_socrates, create_standalone_model
from .ModelType import ModelType

__all__ = [
    "ModelType",
    "decompose_nano_socrates",
    "create_standalone_model",
]
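For reference, a hedged sketch of importing the public surface this __init__.py exposes (again assuming Project_Model is on sys.path):

    from Project_Model.Libs.TransformerUtils import (
        ModelType,
        create_standalone_model,
        decompose_nano_socrates,
    )

    # Exactly the three names listed in __all__ are re-exported.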
Project_Model/Libs/TransformerUtils/model_utils.py (new file)
@@ -0,0 +1,64 @@
import torch

from Project_Model.Libs.Embedder import NanoSocratesEmbedder
from Project_Model.Libs.Transformer import (
    TrainingModel,
    NanoSocraDecoder,
    NanoSocratEncoder,
    DeToken,
    Encoder,
    Decoder,
)

from .ModelType import ModelType


def decompose_nano_socrates(
    model: TrainingModel, vocabulary_size: int, embedding_size: int
) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
    # Split a trained model into its encoder and decoder halves and wrap
    # each half as a standalone model.
    encoder_pieces, decoder_pieces = model.take_pieces()
    encoder_embedder, encoder = encoder_pieces
    # The encoder half carries no output head of its own, so build one.
    encoder_detokener = DeToken(embedding_size, vocabulary_size)
    decoder_embedder, decoder, decoder_detokener = decoder_pieces

    return (
        model,
        NanoSocratEncoder(encoder_embedder, encoder, encoder_detokener),
        NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
    )


def create_standalone_model(
    model_type: ModelType,
    vocabulary_size: int,
    latent_space: int = 256,
    feed_forward_multiplier: int = 4,
    attention_heads: int = 4,
    layer_number: int = 2,
) -> NanoSocratEncoder | NanoSocraDecoder:
    # Width of the position-wise feed-forward sublayers.
    feed_forward_latent_space = latent_space * feed_forward_multiplier

    embedder = NanoSocratesEmbedder(vocabulary_size, latent_space)
    detokener = DeToken(latent_space, vocabulary_size)

    if model_type == ModelType.ENCODER_ONLY:
        # Build the layers with a comprehension rather than list
        # multiplication: [layer] * n repeats one module object n times,
        # which would silently share weights across all layers.
        encoders = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        encoder = torch.nn.Sequential(*encoders)
        return NanoSocratEncoder(embedder, encoder, detokener)

    decoders = [
        Decoder(latent_space, feed_forward_latent_space, attention_heads)
        for _ in range(layer_number)
    ]
    decoder = torch.nn.Sequential(*decoders)
    return NanoSocraDecoder(embedder, decoder, detokener)
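A minimal usage sketch for create_standalone_model (decompose_nano_socrates additionally requires a trained TrainingModel, whose construction this commit does not show); the vocabulary size, input shape, and forward-call signature are illustrative assumptions:

    import torch

    from Project_Model.Libs.TransformerUtils import ModelType, create_standalone_model

    # Decoder-only model with the defaults: width 256, 4 heads, 2 layers.
    model = create_standalone_model(ModelType.DECODER_ONLY, vocabulary_size=32000)

    tokens = torch.randint(0, 32000, (1, 16))  # (batch, sequence) token ids
    logits = model(tokens)  # assumed forward signature; output shape depends on DeToken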