NanoSocrates/Project_Model/Libs/Embedder/Classes/NanoSocratesEmbedder.py


import torch
from ..Utils import fixed_positional_encoding
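
# --- Illustrative sketch (not part of the original file) -------------------
# fixed_positional_encoding lives in ..Utils and is not shown here. Under the
# assumption that it computes the standard sinusoidal encoding from
# "Attention Is All You Need", a minimal reference version might look like
# the hypothetical helper below (named differently so it does not shadow the
# real import above; assumes an even embedding size).
import math


def _sinusoidal_encoding_sketch(length: int, size: int) -> torch.Tensor:
    position = torch.arange(length).unsqueeze(1)  # (length, 1)
    div_term = torch.exp(
        torch.arange(0, size, 2) * (-math.log(10000.0) / size)
    )  # (size / 2,)
    encoding = torch.zeros(length, size)
    encoding[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    encoding[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return encoding
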
# WIP FOR BATCHING
class NanoSocratesEmbedder(torch.nn.Module):
    def __init__(self, vocabulary_size: int, embedding_size: int) -> None:
        super().__init__()

        self.__embedder = torch.nn.Embedding(vocabulary_size, embedding_size)
    def forward(self, tokenized_sentence: list[list[int]]) -> torch.Tensor:
        # torch.tensor() requires a rectangular batch: every inner list of
        # token ids must have the same length.
        TOKENIZED_TENSOR = torch.tensor(tokenized_sentence)
        computed_embeddings: torch.Tensor = self.__embedder(TOKENIZED_TENSOR)

        # Shape is (batch, sentence_length, embedding_size); the batch
        # dimension is discarded because the encoding is shared across rows.
        _, SENTENCE_LENGTH, EMBEDDING_SIZE = computed_embeddings.shape
        POSITIONAL_ENCODINGS = fixed_positional_encoding(
            SENTENCE_LENGTH, EMBEDDING_SIZE
        )

        # The (sentence_length, embedding_size) encoding broadcasts over the
        # batch dimension.
        computed_embeddings = computed_embeddings + POSITIONAL_ENCODINGS
        return computed_embeddings
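

# --- Usage sketch (illustrative, not part of the original file) ------------
# How this module might be driven, assuming a hypothetical vocabulary of
# 1000 tokens and 64-dimensional embeddings. The batch must be rectangular
# (equal-length token lists) for torch.tensor() to accept it.
if __name__ == "__main__":
    embedder = NanoSocratesEmbedder(vocabulary_size=1000, embedding_size=64)
    batch = [[1, 5, 9, 2], [4, 7, 3, 0]]  # two sequences of equal length
    embeddings = embedder(batch)
    print(embeddings.shape)  # torch.Size([2, 4, 64])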