import torch

from ..Utils import fixed_positional_encoding


# WIP FOR BATCHING
class NanoSocratesEmbedder(torch.nn.Module):
    """Token embedding layer that adds fixed positional encodings to the looked-up embeddings."""

    def __init__(self, vocabulary_size: int, embedding_size: int) -> None:
        super().__init__()
        self.__embedder = torch.nn.Embedding(vocabulary_size, embedding_size)

    def forward(self, tokenized_sentence: torch.Tensor) -> torch.Tensor:
        # Look up the learned embedding for each token id:
        # (batch, sentence_length) -> (batch, sentence_length, embedding_size)
        computed_embeddings: torch.Tensor = self.__embedder(tokenized_sentence)

        # Read the sentence length and embedding size from the batched shape
        _, SENTENCE_LENGTH, EMBEDDING_SIZE = computed_embeddings.shape  # for batching

        # Fixed positional encodings for every position in the sentence
        POSITIONAL_ENCODINGS = fixed_positional_encoding(
            SENTENCE_LENGTH, EMBEDDING_SIZE
        )

        # Add the positional encodings, broadcasting across the batch dimension
        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS  # for batching

        return computed_embeddings
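

# Minimal usage sketch (not part of the original module): it assumes that
# fixed_positional_encoding(sentence_length, embedding_size) returns a tensor
# broadcastable over the batch dimension, and that this file is executed as
# part of its package so the relative import above resolves.
if __name__ == "__main__":
    VOCABULARY_SIZE = 1_000  # hypothetical vocabulary size
    EMBEDDING_SIZE = 64      # hypothetical embedding dimension

    embedder = NanoSocratesEmbedder(VOCABULARY_SIZE, EMBEDDING_SIZE)

    # Batch of 2 tokenized sentences, 5 token ids each
    TOKENIZED_SENTENCES = torch.randint(0, VOCABULARY_SIZE, (2, 5))

    EMBEDDINGS = embedder(TOKENIZED_SENTENCES)
    print(EMBEDDINGS.shape)  # expected: torch.Size([2, 5, 64])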