Added a way to detach models and create them standalone

Christian Risi 2025-10-10 18:43:20 +02:00
parent 15f203cad5
commit 92ae40013d
10 changed files with 164 additions and 20 deletions
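The new decompose_nano_socrates helper (added below in model_utils.py) splits a trained TrainingModel into a standalone encoder and decoder that wrap the original submodules; only the encoder branch receives a freshly initialized DeToken head. A minimal usage sketch, where `trained_model`, the sizes and the import path are assumptions, not part of this commit:

# Hypothetical sketch: `trained_model` is an already-built TrainingModel,
# and the import path is a guess based on the Project_Model.Libs.* imports in this diff.
from Project_Model.Libs.Transformer.Utils import decompose_nano_socrates

VOCABULARY_SIZE = 32_000
EMBEDDING_SIZE = 256

full_model, encoder_only, decoder_only = decompose_nano_socrates(
    trained_model, VOCABULARY_SIZE, EMBEDDING_SIZE
)

# The detached modules reuse the original embedder and layer stacks,
# so they keep the weights learned during joint training.
encoder_only.eval()
decoder_only.eval()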

View File

@@ -1,3 +1,4 @@
from typing import Optional
import torch
import torch.nn as nn
from .FeedForwardNetwork import FeedForwardNetwork
@@ -42,13 +43,17 @@ class Decoder(nn.Module):
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor
torch.Tensor,
Optional[bool]
]
): # -> tuple[torch.Tensor, ...] # k_x = v_x, while x_q = x
# WARNING: args is needed to have sequential
x, k_x, v_x, src_padding_mask, tgt_padding_mask = args
if len(args) < 6:
args = args + (False,)  # default decoder_only to False for callers passing the old 5-tuple
x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only = args
# build the attention mask
# TODO: create a prefix causal mask if needed
attention_mask = get_causal_attention_mask(x.size(1))
# 1) Masked Attention
@@ -67,6 +72,8 @@
# 4) Layer Normalization
x = self.__layer_norm_1(x)
if not decoder_only:
# 5) Encoder-decoder (cross) attention
CROSS_ATTENTION = self.__cross_attention(
x, k_x, v_x, key_padding_mask=src_padding_mask
@@ -97,7 +104,7 @@
# 12) Layer Normalization
x = self.__layer_norm_3(x)
return (x, k_x, v_x, src_padding_mask, tgt_padding_mask)
return (x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only)
# use eval() to disable dropout, etc.
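As the WARNING comment above notes, forward takes and returns a single tuple so that stacked Decoder blocks can be chained through torch.nn.Sequential, which threads exactly one positional value between layers. A minimal sketch of that call pattern; shapes are placeholders, the constructor arguments mirror the defaults in create_standalone_model below, and the import path is a guess:

import torch
from Project_Model.Libs.Transformer.Classes import Decoder  # import path is a guess

# Two independent Decoder blocks chained through Sequential.
decoder_stack = torch.nn.Sequential(
    Decoder(256, 1024, 4),
    Decoder(256, 1024, 4),
)

x = torch.randn(1, 16, 256)                     # (batch, sequence, embedding)
padding = torch.zeros(1, 16, dtype=torch.bool)  # no positions masked

# decoder_only=True makes every block skip the cross-attention branch.
out, *_ = decoder_stack((x, x, x, padding, padding, True))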

View File

@@ -0,0 +1,6 @@
from enum import Enum, auto
class ModelType(Enum):
ENCODER_ONLY = auto()
DECODER_ONLY = auto()

View File

@@ -0,0 +1 @@
from .ModelType import ModelType

View File

@@ -0,0 +1,32 @@
import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import DeToken
class NanoSocraDecoder(torch.nn.Module):
def __init__(
self,
decoder_embedder: Embedder.NanoSocratesEmbedder,
decoder_layers: torch.nn.Sequential,
detokener: DeToken
) -> None:
super().__init__()
self.__decoder_embedder = decoder_embedder
self.__decoder = decoder_layers
self.__detokener = detokener
def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
decoder_embedder_input, tgt_padding = args
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
decoder_output, _, _, _, _, _ = self.__decoder(
(decoder_tensor, decoder_tensor, decoder_tensor, tgt_padding, tgt_padding, False)
)
logits: torch.Tensor = self.__detokener(decoder_output)
return logits

View File

@@ -0,0 +1,29 @@
import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import DeToken
class NanoSocratEncoder(torch.nn.Module):
def __init__(
self,
encoder_embedder: Embedder.NanoSocratesEmbedder,
encoder_layers: torch.nn.Sequential,
detokener: DeToken
) -> None:
super().__init__()
self.__encoder_embedder = encoder_embedder
self.__encoder = encoder_layers
self.__detokener = detokener
def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
encoder_embedder_input, src_padding = args
encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
logits: torch.Tensor = self.__detokener(encoder_output)
return logits

View File

@@ -46,10 +46,17 @@ class TrainingModel(torch.nn.Module):
encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
decoder_output, _, _, _, _ = self.__decoder(
(decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding)
decoder_output, _, _, _, _, _ = self.__decoder(
(decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding, False)
)
logits: torch.Tensor = self.__detokener(decoder_output)
return logits
def take_pieces(self):
return (
(self.__encoder_embedder, self.__encoder),
(self.__decoder_embedder, self.__decoder, self.__detokener)
)

View File

@@ -1,5 +1,9 @@
from .TrainingModel import TrainingModel
from .NanoSocratEncoder import NanoSocratEncoder
from .NanoSocraDecoder import NanoSocraDecoder
__all__ = [
"TrainingModel"
"TrainingModel",
"NanoSocratEncoder",
"NanoSocraDecoder"
]

View File

@@ -4,6 +4,7 @@ from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequen
from .inference_masking import inference_masking
from .truncate_rdf_list import truncate_rdf_list
from .decode_out import tensor2token
from .model_utils import decompose_nano_socrates, create_standalone_model
__all__ = [
"TaskType",
@@ -15,5 +16,7 @@ __all__ = [
"normalize_sequence",
"inference_masking",
"truncate_rdf_list",
"tensor2token"
"tensor2token",
"decompose_nano_socrates",
"create_standalone_model"
]

View File

@@ -0,0 +1,53 @@
import torch
from Project_Model.Libs.Embedder import NanoSocratesEmbedder
from ..Models import TrainingModel, NanoSocraDecoder, NanoSocratEncoder
from ..Classes import DeToken, Encoder, Decoder
from ..Enums import ModelType
def decompose_nano_socrates(
model: TrainingModel, vocabulary_size: int, embedding_size: int
) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
encoder_pieces, decoder_pieces = model.take_pieces()
encoder_embedder, encoder = encoder_pieces
encoder_detokener = DeToken(embedding_size, vocabulary_size)
decoder_embedder, decoder, decoder_detokener = decoder_pieces
return (
model,
NanoSocratEncoder(encoder_embedder, encoder, encoder_detokener),
NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
)
def create_standalone_model(
model_type: ModelType,
vocabulary_size: int,
latent_space: int = 256,
feed_forward_multiplier: int = 4,
attention_heads: int = 4,
layer_number: int = 2,
) -> NanoSocratEncoder | NanoSocraDecoder:
feed_forward_latent_space = latent_space * feed_forward_multiplier
embedder = NanoSocratesEmbedder(vocabulary_size, latent_space)
detokener = DeToken(latent_space, vocabulary_size)
if model_type == ModelType.ENCODER_ONLY:
TMP_ENCODERS = [
Encoder(latent_space, feed_forward_latent_space, attention_heads)
for _ in range(layer_number)  # one fresh Encoder per layer, so weights are not shared
]
encoder = torch.nn.Sequential(*TMP_ENCODERS)
return NanoSocratEncoder(embedder, encoder, detokener)
TMP_DECODERS = [
Decoder(latent_space, feed_forward_latent_space, attention_heads)
for _ in range(layer_number)  # one fresh Decoder per layer, so weights are not shared
]
decoder = torch.nn.Sequential(*TMP_DECODERS)
return NanoSocraDecoder(embedder, decoder, detokener)
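A short usage sketch for the new factory; the import paths are guesses based on the Project_Model.Libs.* imports seen elsewhere in this diff, and the sizes are placeholders:

import torch
# Hypothetical import paths; create_standalone_model and ModelType are the names added in this commit.
from Project_Model.Libs.Transformer.Utils import create_standalone_model
from Project_Model.Libs.Transformer.Enums import ModelType

VOCABULARY_SIZE = 32_000

decoder_only_model = create_standalone_model(
    ModelType.DECODER_ONLY,
    VOCABULARY_SIZE,
    latent_space=256,
    layer_number=2,
)

tokens = torch.randint(0, VOCABULARY_SIZE, (1, 16))   # (batch, sequence) token ids
padding = torch.zeros(1, 16, dtype=torch.bool)

logits = decoder_only_model((tokens, padding))        # logits over the vocabulary per position

Passing ModelType.ENCODER_ONLY instead returns a NanoSocratEncoder built the same way.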

View File

@@ -1,7 +1,9 @@
from .Classes import *
from .Enums import *
from .Utils import *
from .Models import *
from . import Classes
from . import Enums
from . import Utils
from . import Models