import torch

from ..Utils import fixed_positional_encoding


# WIP FOR BATCHING
class NanoSocratesEmbedder(torch.nn.Module):
    """Token embedder that adds fixed (non-learned) positional encodings."""

    def __init__(
        self,
        vocabulary_size: int,
        embedding_size: int,
    ) -> None:
        super().__init__()

        # Learned token-embedding table: vocabulary_size rows of size embedding_size.
        self.__embedder = torch.nn.Embedding(
            vocabulary_size,
            embedding_size,
        )

    def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
        # Token ids as a (batch_size, sentence_length) LongTensor, created on the
        # same device as the embedding weights so the lookup also works on GPU.
        TOKENIZED_TENSOR = torch.tensor(
            tokenized_sentence,
            dtype=torch.long,
            device=self.__embedder.weight.device,
        )

        # (batch_size, sentence_length, embedding_size)
        computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)

        _, SENTENCE_LENGTH, EMBEDDING_SIZE = computed_embeddings.shape  # for batching

        # Fixed positional encodings, expected shape (sentence_length, embedding_size).
        POSITIONAL_ENCODINGS = fixed_positional_encoding(
            SENTENCE_LENGTH,
            EMBEDDING_SIZE,
        )

        # Broadcast the encodings over the batch dimension.  # for batching
        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0)

        return computed_embeddings
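

# ---------------------------------------------------------------------------
# Usage sketch (assumptions labelled below).
# The class above relies on `fixed_positional_encoding(seq_len, emb_size)` from
# `..Utils` returning a tensor of shape (seq_len, emb_size); its implementation
# lives in that module and is not reproduced here. The example only illustrates
# how the embedder could be called on an already-padded batch of token ids; the
# vocabulary size, embedding size, and token ids are made up for illustration.
# Because of the relative import, this file has to run as part of its package
# (e.g. `python -m <package>.<module>`), so the example sits behind a __main__
# guard.
if __name__ == "__main__":
    VOCABULARY_SIZE = 1000  # hypothetical vocabulary size
    EMBEDDING_SIZE = 64     # hypothetical embedding size

    embedder = NanoSocratesEmbedder(VOCABULARY_SIZE, EMBEDDING_SIZE)

    # Two sentences of equal length (already padded); torch.tensor() requires
    # the inner lists to have the same length.
    batch = [
        [1, 5, 42, 7],
        [1, 9, 13, 2],
    ]

    embeddings = embedder(batch)
    print(embeddings.shape)  # expected: torch.Size([2, 4, 64])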