Added a way to detach models and create them standalone
parent 15f203cad5
commit 92ae40013d
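A minimal usage sketch of the two helpers this commit introduces, create_standalone_model and decompose_nano_socrates. The vocabulary size, embedding size, and the source of the trained model below are illustrative assumptions, not values taken from this commit:

from Project_Model.Libs.Transformer import (
    ModelType,
    create_standalone_model,
    decompose_nano_socrates,
)

VOCABULARY_SIZE = 8_000  # assumption: depends on the tokenizer actually used
EMBEDDING_SIZE = 256     # assumption: must match the trained model's latent space

# Build a fresh decoder-only model with the default hyperparameters.
decoder_only_model = create_standalone_model(ModelType.DECODER_ONLY, VOCABULARY_SIZE)
decoder_only_model.eval()  # disable dropout etc. for inference

# Split an already trained TrainingModel into standalone encoder/decoder models;
# `trained_model` is assumed to be loaded elsewhere, e.g. from a checkpoint.
# full_model, encoder_model, decoder_model = decompose_nano_socrates(
#     trained_model, VOCABULARY_SIZE, EMBEDDING_SIZE
# )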
@@ -1,3 +1,4 @@
+from typing import Optional
 import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
@@ -42,13 +43,17 @@ class Decoder(nn.Module):
             torch.Tensor,
             torch.Tensor,
             torch.Tensor,
-            torch.Tensor
+            torch.Tensor,
+            Optional[bool]
         ]
     ):  # -> list[torch.Tensor]: # k_x = v_x, while x_q = x
         # WARNING: args is needed to work with nn.Sequential
-        x, k_x, v_x, src_padding_mask, tgt_padding_mask = args
+        if len(args) < 6:
+            args = args + (False,)
+        x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only = args
 
         # build the attention mask
         # TODO: create a prefix causal mask if needed
         attention_mask = get_causal_attention_mask(x.size(1))
 
         # 1) Masked Attention
@@ -67,21 +72,23 @@ class Decoder(nn.Module):
         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
 
-        # 5) Encoder–decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(
-            x, k_x, v_x, key_padding_mask=src_padding_mask
-        )
-
-        # 6) Dropout
-        # DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
-        # del CROSS_ATTENTION
-
-        # 7) Residual Connection
-        x = x + CROSS_ATTENTION
-        del CROSS_ATTENTION
-
-        # 8) Layer Normalization
-        x = self.__layer_norm_2(x)
+        if not decoder_only:
+            # 5) Encoder–decoder (cross) attention
+            CROSS_ATTENTION = self.__cross_attention(
+                x, k_x, v_x, key_padding_mask=src_padding_mask
+            )
+
+            # 6) Dropout
+            # DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
+            # del CROSS_ATTENTION
+
+            # 7) Residual Connection
+            x = x + CROSS_ATTENTION
+            del CROSS_ATTENTION
+
+            # 8) Layer Normalization
+            x = self.__layer_norm_2(x)
 
         # 9) Position-wise feed-forward
         FEED_FORWARD = self.__feed_forward_network(x)
@@ -97,7 +104,7 @@ class Decoder(nn.Module):
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
 
-        return (x, k_x, v_x, src_padding_mask, tgt_padding_mask)
+        return (x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only)
 
 
 # use eval() to disable dropout etc.
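For context on the "# WARNING" comment above: torch.nn.Sequential forwards a single value from one module to the next, which is why Decoder packs everything, now including the optional decoder_only flag, into one tuple. A self-contained sketch of that pattern, independent of this repository:

import torch


class TupleBlock(torch.nn.Module):
    # Toy module mimicking how Decoder threads a tuple through nn.Sequential.
    def forward(self, args: tuple[torch.Tensor, bool]):
        x, decoder_only = args
        if not decoder_only:
            x = x + 1  # stand-in for the optional cross-attention branch
        return (x, decoder_only)


blocks = torch.nn.Sequential(TupleBlock(), TupleBlock())
out, _ = blocks((torch.zeros(2, 3), False))  # the whole tuple is passed along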
Project_Model/Libs/Transformer/Enums/ModelType.py (new file, +6 lines)
@@ -0,0 +1,6 @@
+from enum import Enum, auto
+
+class ModelType(Enum):
+
+    ENCODER_ONLY = auto()
+    DECODER_ONLY = auto()
@@ -0,0 +1 @@
+from .ModelType import ModelType
Project_Model/Libs/Transformer/Models/NanoSocraDecoder.py (new file, +32 lines)
@@ -0,0 +1,32 @@
+import torch
+import Project_Model.Libs.Embedder as Embedder
+from ..Classes import DeToken
+
+class NanoSocraDecoder(torch.nn.Module):
+
+    def __init__(
+        self,
+        decoder_embedder: Embedder.NanoSocratesEmbedder,
+        decoder_layers: torch.nn.Sequential,
+        detokener: DeToken
+    ) -> None:
+        super().__init__()
+
+
+        self.__decoder_embedder = decoder_embedder
+        self.__decoder = decoder_layers
+        self.__detokener = detokener
+
+    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
+
+        decoder_embedder_input, tgt_padding = args
+
+        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
+
+        decoder_output, _, _, _, _, _ = self.__decoder(
+            (decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False)
+        )
+
+        logits: torch.Tensor = self.__detokener(decoder_output)
+
+        return logits
Project_Model/Libs/Transformer/Models/NanoSocratEncoder.py (new file, +29 lines)
@@ -0,0 +1,29 @@
+import torch
+import Project_Model.Libs.Embedder as Embedder
+from ..Classes import DeToken
+
+class NanoSocratEncoder(torch.nn.Module):
+
+    def __init__(
+        self,
+        encoder_embedder: Embedder.NanoSocratesEmbedder,
+        encoder_layers: torch.nn.Sequential,
+        detokener: DeToken
+    ) -> None:
+        super().__init__()
+
+        self.__encoder_embedder = encoder_embedder
+        self.__encoder = encoder_layers
+        self.__detokener = detokener
+
+    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
+
+        encoder_embedder_input, src_padding = args
+
+        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
+
+        encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
+
+        logits: torch.Tensor = self.__detokener(encoder_output)
+
+        return logits
@@ -46,10 +46,17 @@ class TrainingModel(torch.nn.Module):
 
         encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
 
-        decoder_output, _, _, _, _ = self.__decoder(
-            (decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding)
+        decoder_output, _, _, _, _, _ = self.__decoder(
+            (decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding, False)
         )
 
         logits: torch.Tensor = self.__detokener(decoder_output)
 
         return logits
+
+    def take_pieces(self):
+
+        return (
+            (self.__encoder_embedder, self.__encoder),
+            (self.__decoder_embedder, self.__decoder, self.__detokener)
+        )
@@ -1,5 +1,9 @@
 from .TrainingModel import TrainingModel
+from .NanoSocratEncoder import NanoSocratEncoder
+from .NanoSocraDecoder import NanoSocraDecoder
 
 __all__ = [
-    "TrainingModel"
+    "TrainingModel",
+    "NanoSocratEncoder",
+    "NanoSocraDecoder"
 ]
@@ -4,6 +4,7 @@ from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequen
 from .inference_masking import inference_masking
 from .truncate_rdf_list import truncate_rdf_list
 from .decode_out import tensor2token
+from .model_utils import decompose_nano_socrates, create_standalone_model
 
 __all__ = [
     "TaskType",
@@ -15,5 +16,7 @@ __all__ = [
     "normalize_sequence",
     "inference_masking",
     "truncate_rdf_list",
-    "tensor2token"
+    "tensor2token",
+    "decompose_nano_socrates",
+    "create_standalone_model"
 ]
Project_Model/Libs/Transformer/Utils/model_utils.py (new file, +53 lines)
@@ -0,0 +1,53 @@
+import torch
+from Project_Model.Libs.Embedder import NanoSocratesEmbedder
+from ..Models import TrainingModel, NanoSocraDecoder, NanoSocratEncoder
+from ..Classes import DeToken, Encoder, Decoder
+from ..Enums import ModelType
+
+
+def decompose_nano_socrates(
+    model: TrainingModel, vocabulary_size: int, embedding_size: int
+) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
+
+    encoder_pieces, decoder_pieces = model.take_pieces()
+    encoder_embedder, encoder = encoder_pieces
+    encoder_detokener = DeToken(embedding_size, vocabulary_size)
+    decoder_embedder, decoder, decoder_detokener = decoder_pieces
+
+    return (
+        model,
+        NanoSocratEncoder(encoder_embedder, encoder, encoder_detokener),
+        NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
+    )
+
+
+def create_standalone_model(
+    model_type: ModelType,
+    vocabulary_size: int,
+    latent_space: int = 256,
+    feed_forward_multiplier: int = 4,
+    attention_heads: int = 4,
+    layer_number: int = 2,
+) -> NanoSocratEncoder | NanoSocraDecoder:
+
+    feed_forward_latent_space = latent_space * feed_forward_multiplier
+
+    embedder = NanoSocratesEmbedder(vocabulary_size, latent_space)
+    detokener = DeToken(latent_space, vocabulary_size)
+
+    if model_type == ModelType.ENCODER_ONLY:
+        TMP_ENCODERS = [  # one independent Encoder instance per layer
+            Encoder(latent_space, feed_forward_latent_space, attention_heads)
+            for _ in range(layer_number)]
+
+        encoder = torch.nn.Sequential(*TMP_ENCODERS)
+
+        return NanoSocratEncoder(embedder, encoder, detokener)
+
+    TMP_DECODERS = [  # one independent Decoder instance per layer
+        Decoder(latent_space, feed_forward_latent_space, attention_heads)
+        for _ in range(layer_number)]
+
+    decoder = torch.nn.Sequential(*TMP_DECODERS)
+
+    return NanoSocraDecoder(embedder, decoder, detokener)
@@ -1,7 +1,9 @@
 from .Classes import *
+from .Enums import *
 from .Utils import *
 from .Models import *
 
 from . import Classes
+from . import Enums
 from . import Utils
 from . import Models