Added a training model for NanoSocrates
parent 159266a603
commit 24ea4d3ba4
Project_Model/Libs/Transformer/Models/TrainingModel.py (new file, +56 lines)
@@ -0,0 +1,56 @@
import torch

import Project_Model.Libs.Embedder as Embedder
from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):

    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # Build independent layers: a list comprehension constructs a fresh
        # module per layer, whereas [Encoder(...)] * layer_number would repeat
        # one module instance and make every layer share the same weights.
        TMP_ENCODERS = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        TMP_DECODERS = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        self.__encoder = torch.nn.Sequential(*TMP_ENCODERS)
        self.__decoder = torch.nn.Sequential(*TMP_DECODERS)

        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(
        self, args: tuple[list[list[int]], list[list[bool]], list[list[int]]]
    ) -> torch.Tensor:

        encoder_embedder_input, padding_input, decoder_embedder_input = args

        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
        padding_tensor = torch.tensor(padding_input, dtype=torch.bool)
        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

        encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))

        # Cross-attend to the encoder output rather than the raw embeddings;
        # passing encoder_tensor here would bypass the encoder stack entirely.
        decoder_output, _, _, _ = self.__decoder(
            (decoder_tensor, encoder_output, encoder_output, None)
        )

        logits: torch.Tensor = self.__detokener(decoder_output)

        return logits
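For reference, a minimal usage sketch of the new class (not part of the commit). The import path is inferred from the file location; the vocabulary size, token ids, padding convention (True assumed to mark padded positions), target ids, and the cross-entropy step are all illustrative assumptions:

import torch

from Project_Model.Libs.Transformer.Models.TrainingModel import TrainingModel

VOCABULARY_SIZE = 1024  # illustrative value

model = TrainingModel(VOCABULARY_SIZE)

# One batch with a single padded source sequence and its shifted target.
encoder_ids = [[1, 5, 42, 0]]                 # hypothetical token ids
padding_mask = [[False, False, False, True]]  # True assumed to mark padding
decoder_ids = [[1, 5, 42, 7]]                 # hypothetical shifted target ids

logits = model((encoder_ids, padding_mask, decoder_ids))

# A single training step: next-token cross-entropy over the vocabulary,
# assuming logits come out shaped (batch, sequence, vocabulary).
targets = torch.tensor([[5, 42, 7, 2]])  # hypothetical gold token ids
loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, VOCABULARY_SIZE), targets.reshape(-1)
)
loss.backward()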