Quick fix to architecture
This commit is contained in:
@@ -37,12 +37,11 @@ class TrainingModel(torch.nn.Module):
|
||||
|
||||
self.__detokener = DeToken(latent_space, vocabulary_size)
|
||||
|
||||
def forward(self, args: tuple[list[list[int]], list[list[bool]], list[list[int]]]):
|
||||
|
||||
encoder_embedder_input, padding_input, decoder_embedder_input = args
|
||||
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
|
||||
encoder_embedder_input, padding_tensor, decoder_embedder_input = args
|
||||
|
||||
encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
|
||||
padding_tensor = torch.tensor(padding_input, dtype=torch.bool)
|
||||
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
||||
|
||||
encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))
|
||||
|
||||
5
Project_Model/Libs/Transformer/Models/__init__.py
Normal file
5
Project_Model/Libs/Transformer/Models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .TrainingModel import TrainingModel
|
||||
|
||||
__all__ = [
|
||||
"TrainingModel"
|
||||
]
|
||||
Reference in New Issue
Block a user