WIP NanoSocratesCore

This commit is contained in:
GassiGiuseppe 2025-10-06 18:21:27 +02:00
parent 745424a978
commit 56d438f01a

@@ -1,6 +1,58 @@
from ..Utils.task_type import TaskType
from .Decoder import Decoder
from .Encoder import Encoder
from ....Libs.Embedder import NanoSocratesEmbedder
import torch


class NanoSocratesCore(torch.nn.Module):
    def __init__(self,
                 embedded_size: int,
                 feed_forward_dim: int,
                 encoder_layers: int,
                 decoder_layers: int,
                 attention_heads: int,
                 vocab_size: int) -> None:
        super().__init__()  # required so sub-modules register their parameters

        #* ModuleList instead of Sequential: each layer takes an extra
        #* padding-mask argument, which Sequential cannot pass through.
        #* Building from a comprehension gives each layer its own weights.
        self.__encoder_sequence = torch.nn.ModuleList(
            [Encoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(encoder_layers)]
        )
        self.__decoder_sequence = torch.nn.ModuleList(
            [Decoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(decoder_layers)]
        )
        self.__linear = torch.nn.Linear(embedded_size, vocab_size, bias=False)
        self.__input_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)
        self.__output_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)

    def forward(self, token_list, padding_mask=None):
        x = self.__input_embedder(token_list)

        # encoder stack: each layer returns a tuple, keep the hidden states
        for encoder in self.__encoder_sequence:
            x = encoder(x, padding_mask)[0]

        # do-while: run the decoder stack at least once, then keep going
        # while keep_going() asks for another pass
        while True:
            for decoder in self.__decoder_sequence:
                x = decoder(x, x, x, padding_mask)[0]
            logits = self.__linear(x)
            log_prob = torch.log_softmax(logits, dim=-1)
            # greedy pick would be log_prob.argmax(dim=-1); unused while WIP
            # TODO: map log_prob back into x via __output_embedder before
            # the next pass
            if not self.keep_going(log_prob):
                break

        return log_prob

    def keep_going(self, log_prob: torch.Tensor) -> bool:
        # WIP stopping criterion, e.g. stop once an end-of-sequence token
        # dominates; for now always stop after a single pass
        return False
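
A quick smoke test of the new module could look like the sketch below. The hyperparameter values are illustrative only, and it assumes Encoder, Decoder and NanoSocratesEmbedder accept the constructor and call signatures used above.

    # minimal smoke test (illustrative values, not part of the commit)
    import torch

    model = NanoSocratesCore(
        embedded_size=128,
        feed_forward_dim=512,
        encoder_layers=2,
        decoder_layers=2,
        attention_heads=4,
        vocab_size=1000,
    )

    tokens = torch.randint(0, 1000, (1, 16))  # batch of 1, sequence of 16 token ids
    log_prob = model(tokens)                  # expected (1, 16, vocab_size) log-probabilities
    print(log_prob.shape)

Since keep_going() currently returns False, the decoder runs exactly one pass, so this only exercises the wiring, not the iterative decoding loop.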