WIP NanoSocratesEmbedder for batching

GassiGiuseppe 2025-10-06 21:41:45 +02:00
parent 56d438f01a
commit 14b810c451


@@ -1,5 +1,7 @@
 import torch
 from ..Utils import fixed_positional_encoding
+# WIP FOR BATCHING
 class NanoSocratesEmbedder(torch.nn.Module):
     def __init__(
@@ -13,20 +15,20 @@ class NanoSocratesEmbedder(torch.nn.Module):
             embedding_size
         )
-    def forward(self, tokenized_sentence: list[int]) -> torch.Tensor:
+    def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
         TOKENIZED_TENSOR = torch.tensor(tokenized_sentence)
         computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
-        SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape
+        _, SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape  # for batching
         POSITIONAL_ENCODINGS = fixed_positional_encoding(
             SENTENCE_LENGHT,
             EMBEDDING_SIZE
         )
-        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS
+        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0)  # for batching
         return computed_embeddings
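
For reference, a minimal self-contained sketch of the batched flow this change introduces. It assumes a plain sinusoidal stand-in for ..Utils.fixed_positional_encoding (the real helper may differ) and made-up vocabulary/embedding sizes; the input batch must already be padded to a common length, since torch.tensor needs rectangular input.

import torch

# Hypothetical stand-in for ..Utils.fixed_positional_encoding: returns a
# (sentence_length, embedding_size) tensor of sinusoidal encodings.
def sinusoidal_positional_encoding(sentence_length: int, embedding_size: int) -> torch.Tensor:
    position = torch.arange(sentence_length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embedding_size, 2, dtype=torch.float32)
        * (-torch.log(torch.tensor(10000.0)) / embedding_size)
    )
    encoding = torch.zeros(sentence_length, embedding_size)
    encoding[:, 0::2] = torch.sin(position * div_term)
    encoding[:, 1::2] = torch.cos(position * div_term)
    return encoding

VOCAB_SIZE, EMBEDDING_SIZE = 100, 16                # assumed sizes
embedder = torch.nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)

# A batch of two already-padded, equal-length tokenized sentences.
tokenized_sentences = [[5, 12, 7, 0], [3, 9, 1, 0]]
tokens = torch.tensor(tokenized_sentences)          # shape: (2, 4)

embeddings = embedder(tokens)                       # shape: (2, 4, 16)
_, sentence_length, embedding_size = embeddings.shape   # leading batch dim, as in the diff

positional = sinusoidal_positional_encoding(sentence_length, embedding_size)  # (4, 16)
# unsqueeze(0) turns the encodings into (1, 4, 16) so they broadcast over the batch dim.
embeddings = embeddings + positional.unsqueeze(0)
print(embeddings.shape)                             # torch.Size([2, 4, 16])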