Added support for batches

Christian Risi 2025-10-07 12:15:03 +02:00
parent 14b810c451
commit 9b5bb6d5f8
2 changed files with 56 additions and 38 deletions

File diff suppressed because one or more lines are too long


@@ -1,19 +1,13 @@
 import torch
 from ..Utils import fixed_positional_encoding
 
 # WIP FOR BATCHING
 class NanoSocratesEmbedder(torch.nn.Module):
 
-    def __init__(
-        self,
-        vocabulary_size: int,
-        embedding_size: int
-    ) -> None:
+    def __init__(self, vocabulary_size: int, embedding_size: int) -> None:
         super().__init__()
 
-        self.__embedder = torch.nn.Embedding(
-            vocabulary_size,
-            embedding_size
-        )
+        self.__embedder = torch.nn.Embedding(vocabulary_size, embedding_size)
 
     def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
@@ -21,14 +15,11 @@ class NanoSocratesEmbedder(torch.nn.Module):
         computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
 
-        _ ,SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape # for batching
+        _, SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape # for batching
         POSITIONAL_ENCODINGS = fixed_positional_encoding(
-            SENTENCE_LENGHT,
-            EMBEDDING_SIZE
+            SENTENCE_LENGHT, EMBEDDING_SIZE
         )
 
-        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0) # for batching
+        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS # for batching
         return computed_embeddings