Quick fix to architecture

Christian Risi 2025-10-08 12:34:09 +02:00
parent 14c3914571
commit c2e13bc9c6
8 changed files with 89 additions and 126 deletions

File diff suppressed because one or more lines are too long

View File

@@ -9,11 +9,9 @@ class NanoSocratesEmbedder(torch.nn.Module):
         super().__init__()
         self.__embedder = torch.nn.Embedding(vocabulary_size, embedding_size)

-    def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
-        TOKENIZED_TENSOR = torch.tensor(tokenized_sentence)
-        computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
+    def forward(self, tokenized_sentence: torch.Tensor) -> torch.Tensor:
+        computed_embeddings: torch.Tensor = self.__embedder(tokenized_sentence)

         _, SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape  # for batching
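
The embedder's forward now takes an already-built torch.Tensor of token IDs instead of a nested Python list, so tensor construction moves to the caller. A minimal sketch of the new calling convention, using a bare torch.nn.Embedding as a stand-in for NanoSocratesEmbedder, with made-up sizes and IDs:

    import torch

    # Stand-in for the nn.Embedding wrapped by NanoSocratesEmbedder;
    # vocabulary and embedding sizes are made up for illustration.
    embedder = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)

    # The caller now builds the index tensor itself instead of passing list[list[int]].
    token_tensor = torch.tensor([[5, 42, 7], [9, 3, 0]], dtype=torch.long)

    embeddings = embedder(token_tensor)
    print(embeddings.shape)  # torch.Size([2, 3, 64])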

View File

@@ -56,12 +56,12 @@ class Decoder(nn.Module):
         )

         # 2) Dropout
-        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
-        del MASKED_ATTENTION
+        # DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
+        # del MASKED_ATTENTION

         # 3) Residual Connection
-        x = x + DROPPED_MASKED_ATTENTION
-        del DROPPED_MASKED_ATTENTION
+        x = x + MASKED_ATTENTION
+        del MASKED_ATTENTION

         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
@@ -72,12 +72,12 @@ class Decoder(nn.Module):
         )

         # 6) Dropout
-        DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
-        del CROSS_ATTENTION
+        # DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
+        # del CROSS_ATTENTION

         # 7) Residual Connection
-        x = x + DROPPED_CROSS_ATTENTION
-        del DROPPED_CROSS_ATTENTION
+        x = x + CROSS_ATTENTION
+        del CROSS_ATTENTION

         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
@@ -86,12 +86,12 @@ class Decoder(nn.Module):
         FEED_FORWARD = self.__feed_forward_network(x)

         # 10) Dropout
-        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-        del FEED_FORWARD
+        # DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+        # del FEED_FORWARD

         # 11) Residual Connection
-        x = x + DROPPED_FEED_FORWARD
-        del DROPPED_FEED_FORWARD
+        x = x + FEED_FORWARD
+        del FEED_FORWARD

         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
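
All three decoder sub-layers now skip the dropout step and feed the raw sub-layer output straight into the residual connection before layer normalization. A rough sketch of the resulting post-norm attention sub-layer, built from stock torch.nn modules with illustrative sizes (not the repo's Decoder class):

    import torch
    import torch.nn as nn

    # Illustrative post-norm sub-layer; sizes and module wiring are assumptions,
    # but the step order mirrors the diff above.
    embed_dim, num_heads = 64, 4
    attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    layer_norm = nn.LayerNorm(embed_dim)
    # dropout = nn.Dropout(p=0.1)          # step 2, bypassed by this commit

    x = torch.randn(2, 10, embed_dim)      # (batch, sequence, embedding)
    attn_out, _ = attention(x, x, x)       # 1) self-attention (stand-in)
    # attn_out = dropout(attn_out)         # 2) dropout (disabled)
    x = x + attn_out                       # 3) residual connection
    x = layer_norm(x)                      # 4) layer normalization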

View File

@@ -43,11 +43,12 @@ class Encoder(
         ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)

         # 2) Dropout
-        DROPPED_ATTENTION = self.__dropout(ATTENTION)
-        del ATTENTION
+        # DROPPED_ATTENTION = self.__dropout(ATTENTION)
+        # del ATTENTION

         # 3) Residual Connection
-        x = x + DROPPED_ATTENTION
+        x = x + ATTENTION
+        del ATTENTION

         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
@@ -56,12 +57,12 @@ class Encoder(
         FEED_FORWARD = self.__feed_forward(x)

         # 6) Dropout
-        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-        del FEED_FORWARD
+        # DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+        # del FEED_FORWARD

         # 7) Residual Connection
-        x = x + DROPPED_FEED_FORWARD
-        del DROPPED_FEED_FORWARD
+        x = x + FEED_FORWARD
+        del FEED_FORWARD

         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
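
The encoder gets the same treatment, plus an explicit del ATTENTION after the residual add. Commenting out dropout only changes training-time behaviour; in eval mode nn.Dropout is already the identity. A quick check with made-up shapes:

    import torch
    import torch.nn as nn

    dropout = nn.Dropout(p=0.1)
    x = torch.randn(2, 10, 64)             # made-up (batch, sequence, embedding)

    dropout.train()                        # training mode: mask + rescale by 1/(1-p)
    print(torch.equal(dropout(x), x))      # False: activations are perturbed

    dropout.eval()                         # eval mode: identity
    print(torch.equal(dropout(x), x))      # True: dropout is a no-op at inference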

View File

@@ -37,12 +37,11 @@ class TrainingModel(torch.nn.Module):
         self.__detokener = DeToken(latent_space, vocabulary_size)

-    def forward(self, args: tuple[list[list[int]], list[list[bool]], list[list[int]]]):
-        encoder_embedder_input, padding_input, decoder_embedder_input = args
+    def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
+        encoder_embedder_input, padding_tensor, decoder_embedder_input = args

         encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
-        padding_tensor = torch.tensor(padding_input, dtype=torch.bool)
         decoder_tensor = self.__decoder_embedder(decoder_embedder_input)

         encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))
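
Since forward now receives tensors directly, the data pipeline is expected to build the batch, including the boolean padding mask, before calling the model. A sketch of what such a batch tuple could look like; the token IDs, padding layout, and shapes are invented, and the mask convention (True = padded) is assumed from the encoder's key_padding_mask usage:

    import torch

    # Hypothetical batch of two sequences, length 4, padded with ID 0.
    encoder_ids = torch.tensor([[12, 7, 99, 0], [4, 5, 0, 0]], dtype=torch.long)
    padding_mask = torch.tensor(
        [[False, False, False, True],
         [False, False, True, True]],
        dtype=torch.bool,                  # True marks padded positions
    )
    decoder_ids = torch.tensor([[1, 12, 7, 99], [1, 4, 5, 0]], dtype=torch.long)

    batch = (encoder_ids, padding_mask, decoder_ids)
    # logits = training_model(batch)       # constructing TrainingModel depends on the repo's config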

View File

@@ -0,0 +1,5 @@
+from .TrainingModel import TrainingModel
+
+__all__ = [
+    "TrainingModel"
+]

View File

@@ -2,6 +2,10 @@ def truncate_sequence(
     sequence: list[int], truncate_at: int, end_token: int
 ) -> list[int]:

+    if len(sequence) < truncate_at - 1:
+        sequence.append(end_token)
+        return sequence
+
     if len(sequence) < truncate_at:
         sequence[-1] = end_token
         return sequence
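
The new early branch appends the end token while there is still room below the limit, and the existing branch keeps overwriting the last element when the sequence sits one slot short of truncate_at. A small worked example covering only the two branches shown here (behaviour at or above the limit lies outside this hunk):

    END_TOKEN = 2  # made-up end-of-sequence ID

    def truncate_sketch(sequence: list[int], truncate_at: int, end_token: int) -> list[int]:
        if len(sequence) < truncate_at - 1:   # new branch: room left, append the end token
            sequence.append(end_token)
            return sequence
        if len(sequence) < truncate_at:       # one slot short of the limit: overwrite the tail
            sequence[-1] = end_token
            return sequence
        raise NotImplementedError("sequences at/above the limit are handled outside this hunk")

    print(truncate_sketch([5, 6], truncate_at=5, end_token=END_TOKEN))        # [5, 6, 2]
    print(truncate_sketch([5, 6, 7, 8], truncate_at=5, end_token=END_TOKEN))  # [5, 6, 7, 2]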

View File

@@ -1,5 +1,7 @@
 from .Classes import *
 from .Utils import *
+from .Models import *

 from . import Classes
 from . import Utils
+from . import Models
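
Together with the new Models/__init__.py above, re-exporting Models from the package root makes TrainingModel importable in two ways. A sketch of the resulting import surface, assuming a top-level package named nanosocrates (hypothetical; the real package name is not shown in this diff):

    # Hypothetical package name; adjust to the repo's actual top-level package.
    from nanosocrates import Models          # sub-package, via "from . import Models"
    from nanosocrates import TrainingModel   # re-exported via "from .Models import *"

    assert Models.TrainingModel is TrainingModel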