import torch
from ..Utils import fixed_positional_encoding


# WIP FOR BATCHING
class NanoSocratesEmbedder(torch.nn.Module):
    """Token embedder that adds fixed positional encodings to token embeddings."""

    def __init__(
        self,
        vocabulary_size: int,
        embedding_size: int
    ) -> None:
        super().__init__()

        # Learnable lookup table mapping token ids to embedding vectors
        self.__embedder = torch.nn.Embedding(
            vocabulary_size,
            embedding_size
        )

    def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
        # Batch of already-tokenized sentences -> tensor of shape (batch_size, sentence_length)
        TOKENIZED_TENSOR = torch.tensor(tokenized_sentence)

        # Token embeddings: (batch_size, sentence_length, embedding_size)
        computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)

        _, SENTENCE_LENGTH, EMBEDDING_SIZE = computed_embeddings.shape  # for batching

        # Fixed positional encodings for this sequence length: (sentence_length, embedding_size)
        POSITIONAL_ENCODINGS = fixed_positional_encoding(
            SENTENCE_LENGTH,
            EMBEDDING_SIZE
        )

        # Broadcast the positional encodings over the batch dimension and add them
        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS.unsqueeze(0)  # for batching

        return computed_embeddings
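

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the vocabulary size, embedding size,
    # and token ids below are made-up values. Because of the relative import above,
    # run this with `python -m <package>.<module>` rather than directly as a script.
    embedder = NanoSocratesEmbedder(vocabulary_size=1000, embedding_size=64)

    # Two tokenized sentences of four token ids each (already padded to equal length)
    token_batch = [[5, 42, 7, 0], [13, 8, 99, 1]]

    embeddings = embedder(token_batch)
    # Expected shape (2, 4, 64), assuming fixed_positional_encoding returns a
    # (sentence_length, embedding_size) tensor as the forward pass expects.
    print(embeddings.shape)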