WIP NanoSocratesEmbedder for batching
This commit is contained in:
parent
56d438f01a
commit
14b810c451
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
from ..Utils import fixed_positional_encoding
|
||||
|
||||
# WIP FOR BATCHING
|
||||
class NanoSocratesEmbedder(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -13,20 +15,20 @@ class NanoSocratesEmbedder(torch.nn.Module):
|
||||
embedding_size
|
||||
)
|
||||
|
||||
def forward(self, tokenized_sentence: list[int]) -> torch.Tensor:
|
||||
def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
|
||||
|
||||
TOKENIZED_TENSOR = torch.tensor(tokenized_sentence)
|
||||
|
||||
computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
|
||||
|
||||
SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape
|
||||
_ ,SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape # for batching
|
||||
|
||||
POSITIONAL_ENCODINGS = fixed_positional_encoding(
|
||||
SENTENCE_LENGHT,
|
||||
EMBEDDING_SIZE
|
||||
)
|
||||
|
||||
computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS
|
||||
computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0) # for batching
|
||||
return computed_embeddings
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user