import torch


# ---------------------------------------------------------------------------
# Origin: Project_Model/Libs/Embedder/Utils/fixed_positional_encoding.py
# ---------------------------------------------------------------------------
def fixed_positional_encoding(
    sentence_dimension: int,
    embedding_dimension: int,
    *,
    device=None,
    dtype=None,
) -> torch.Tensor:
    """Build the fixed sinusoidal positional-encoding table.

    Entry ``[pos, k]`` of the result is ``sin(pos / 10000 ** (2 * (k // 2) / d))``
    when ``k`` is even and ``cos`` of the same angle when ``k`` is odd, where
    ``d`` is ``embedding_dimension``.

    Args:
        sentence_dimension: Number of positions (rows) to encode.
        embedding_dimension: Number of embedding features (columns).
        device: Optional device on which to create the table. ``None`` keeps
            torch's default placement.
        dtype: Optional floating dtype of the table. ``None`` uses
            ``torch.get_default_dtype()``.

    Returns:
        A tensor of shape ``(sentence_dimension, embedding_dimension)``.
        Either dimension may be zero, in which case the tensor is empty.

    Raises:
        ValueError: If either dimension is negative.
    """
    if sentence_dimension < 0 or embedding_dimension < 0:
        raise ValueError(
            "sentence_dimension and embedding_dimension must be non-negative"
        )

    BIG_CONST = 10_000
    if dtype is None:
        dtype = torch.get_default_dtype()

    # Column vector of positions 0..sentence_dimension-1, shape (S, 1).
    positions = torch.arange(
        sentence_dimension, dtype=dtype, device=device
    ).unsqueeze(1)

    # The paper writes the exponent as 2i / d, where ``i`` indexes a
    # (sin, cos) *pair* of columns. For a plain column index ``k`` this means
    # i == k // 2 (integer division), so columns 2i and 2i + 1 share the same
    # frequency. The divisors are computed with Python floats and only then
    # converted, so they are rounded to ``dtype`` exactly once.
    divisors = torch.tensor(
        [
            BIG_CONST ** ((2 * (column // 2)) / embedding_dimension)
            for column in range(embedding_dimension)
        ],
        dtype=dtype,
        device=device,
    )

    # Broadcasting (S, 1) / (E,) yields the full (S, E) grid of angles.
    angles = positions / divisors

    encodings = torch.empty_like(angles)
    encodings[:, 0::2] = torch.sin(angles[:, 0::2])  # even columns
    encodings[:, 1::2] = torch.cos(angles[:, 1::2])  # odd columns
    return encodings


# ---------------------------------------------------------------------------
# Origin: Project_Model/Libs/Embedder/Classes/NanoSocratesEmbedder.py
# In the package layout this module obtains the helper above through
# ``from ..Utils import fixed_positional_encoding``.
# ---------------------------------------------------------------------------
class NanoSocratesEmbedder(torch.nn.Module):
    """Token embedding layer that adds fixed sinusoidal positional encodings."""

    def __init__(
        self,
        vocabulary_size: int,
        embedding_size: int
    ) -> None:
        """Create the learnable embedding table.

        Args:
            vocabulary_size: Number of distinct token ids.
            embedding_size: Size of each embedding vector.
        """
        super().__init__()
        # The attribute name is kept as-is: it determines the state_dict keys.
        self.__embedder = torch.nn.Embedding(
            vocabulary_size,
            embedding_size
        )

    def forward(
        self,
        tokenized_sentence: "list[int] | list[list[int]] | torch.Tensor",
    ) -> torch.Tensor:
        """Embed token ids and add positional encodings.

        Args:
            tokenized_sentence: Token ids as a flat list (one sentence), a
                nested list of equal-length sentences (a batch), or an
                integer tensor whose last dimension is the sentence length.

        Returns:
            A tensor of shape ``(*leading_dims, sentence_length,
            embedding_size)``; for a flat list this is
            ``(sentence_length, embedding_size)``.

        Raises:
            ValueError: If the input is a single scalar token id rather than
                a sequence.
        """
        # Build the index tensor on the same device as the embedding weights
        # so the lookup also works after the module has been moved with
        # ``.to(...)``. Forcing ``long`` keeps an empty list valid as indices.
        target_device = self.__embedder.weight.device
        token_ids = torch.as_tensor(
            tokenized_sentence, dtype=torch.long, device=target_device
        )

        computed_embeddings: torch.Tensor = self.__embedder(token_ids)
        if computed_embeddings.dim() < 2:
            raise ValueError(
                "tokenized_sentence must be a sequence of token ids, "
                "not a single scalar"
            )

        # Only the two trailing dims matter; leading (batch) dims broadcast.
        sentence_length, embedding_size = computed_embeddings.shape[-2:]

        positional_encodings = fixed_positional_encoding(
            sentence_length,
            embedding_size,
            device=computed_embeddings.device,
        )

        return computed_embeddings + positional_encodings