import torch

import Project_Model.Libs.Embedder as Embedder

from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):
    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        # Hidden width of each feed-forward block, as a multiple of the model width.
        feed_forward_latent_space = latent_space * feed_forward_multiplier

        # Separate token embedders for the source (encoder) and target (decoder) streams.
        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # Build one distinct module per layer. A list-multiply such as
        # [Encoder(...)] * layer_number would repeat the *same* instance,
        # silently sharing one set of weights across all layers.
        TMP_ENCODERS = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        TMP_DECODERS = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        # Stack the layers; each layer hands its tuple state to the next.
        self.__encoder = torch.nn.Sequential(*TMP_ENCODERS)
        self.__decoder = torch.nn.Sequential(*TMP_DECODERS)

        # Final projection from the latent space back to vocabulary logits.
        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(
        self,
        args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:

        encoder_embedder_input, src_padding, decoder_embedder_input, tgt_padding = args

        # Embed source and target token ids into the latent space.
        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

        # Each Encoder layer consumes and returns a (tensor, padding) pair,
        # so the padding mask threads through the Sequential stack unchanged.
        encoder_output, _ = self.__encoder((encoder_tensor, src_padding))

        # Decoder layers thread a 5-tuple: the target stream, the encoder
        # output twice (key and value streams), and both padding masks.
        # Only the decoded tensor is needed here.
        decoder_output, _, _, _, _ = self.__decoder(
            (decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding)
        )

        # Project decoder states to unnormalised per-token vocabulary logits.
        logits: torch.Tensor = self.__detokener(decoder_output)

        return logits
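

# A minimal usage sketch, not part of the model. It assumes the padding
# masks are boolean (batch, sequence_length) tensors, which is a guess at
# the Encoder/Decoder API rather than something this file defines; adjust
# shapes and dtypes to the actual layer contract. Because of the relative
# import above, run this as a module (python -m ...), not as a script.
if __name__ == "__main__":
    VOCABULARY_SIZE = 1_000
    BATCH, SRC_LEN, TGT_LEN = 2, 16, 12

    model = TrainingModel(VOCABULARY_SIZE)

    source_tokens = torch.randint(0, VOCABULARY_SIZE, (BATCH, SRC_LEN))
    target_tokens = torch.randint(0, VOCABULARY_SIZE, (BATCH, TGT_LEN))
    src_padding = torch.zeros(BATCH, SRC_LEN, dtype=torch.bool)
    tgt_padding = torch.zeros(BATCH, TGT_LEN, dtype=torch.bool)

    logits = model((source_tokens, src_padding, target_tokens, tgt_padding))
    print(logits.shape)  # expected: (BATCH, TGT_LEN, VOCABULARY_SIZE)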