Added support for batches
This commit is contained in:
parent
14b810c451
commit
9b5bb6d5f8
File diff suppressed because one or more lines are too long
@ -1,19 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
from ..Utils import fixed_positional_encoding
|
from ..Utils import fixed_positional_encoding
|
||||||
|
|
||||||
|
|
||||||
# WIP FOR BATCHING
|
# WIP FOR BATCHING
|
||||||
class NanoSocratesEmbedder(torch.nn.Module):
|
class NanoSocratesEmbedder(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, vocabulary_size: int, embedding_size: int) -> None:
|
||||||
self,
|
|
||||||
vocabulary_size: int,
|
|
||||||
embedding_size: int
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__embedder = torch.nn.Embedding(
|
self.__embedder = torch.nn.Embedding(vocabulary_size, embedding_size)
|
||||||
vocabulary_size,
|
|
||||||
embedding_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
|
def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
|
||||||
|
|
||||||
@ -21,14 +15,11 @@ class NanoSocratesEmbedder(torch.nn.Module):
|
|||||||
|
|
||||||
computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
|
computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)
|
||||||
|
|
||||||
_ ,SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape # for batching
|
_, SENTENCE_LENGHT, EMBEDDING_SIZE = computed_embeddings.shape # for batching
|
||||||
|
|
||||||
POSITIONAL_ENCODINGS = fixed_positional_encoding(
|
POSITIONAL_ENCODINGS = fixed_positional_encoding(
|
||||||
SENTENCE_LENGHT,
|
SENTENCE_LENGHT, EMBEDDING_SIZE
|
||||||
EMBEDDING_SIZE
|
|
||||||
)
|
)
|
||||||
|
|
||||||
computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0) # for batching
|
computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS # for batching
|
||||||
return computed_embeddings
|
return computed_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user