From 56d438f01ad74cefc1c5bcaba9b644f63a75b64e Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Mon, 6 Oct 2025 18:21:27 +0200
Subject: [PATCH] WIP NanoSocratesCore

---
 .../Transformer/Classes/NanoSocratesCore.py | 69 ++++++++++++++++++++++-
 1 file changed, 66 insertions(+), 3 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py b/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py
index 7b2a9b0..bb2d971 100644
--- a/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py
+++ b/Project_Model/Libs/Transformer/Classes/NanoSocratesCore.py
@@ -1,6 +1,69 @@
 from ..Utils.task_type import TaskType
+from .Decoder import Decoder
+from .Encoder import Encoder
+from ....Libs.Embedder import NanoSocratesEmbedder
+import torch
 
-class NanoSocratesCore():
+class NanoSocratesCore(torch.nn.Module):
 
-    def __init__(self) -> None:
-        pass
\ No newline at end of file
+    def __init__(self,
+                 embedded_size: int,
+                 feed_forward_dim: int,
+                 encoder_layers: int,
+                 decoder_layers: int,
+                 attention_heads: int,
+                 vocab_size: int) -> None:
+        super().__init__()
+
+        # ModuleList instead of Sequential so the padding mask can be passed
+        # to every layer (nn.Sequential only forwards a single argument).
+        # The comprehension builds one instance per layer, so every encoder
+        # and decoder has its own weights.
+        self.__encoder_layers = torch.nn.ModuleList(
+            [Encoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(encoder_layers)]
+        )
+        self.__decoder_layers = torch.nn.ModuleList(
+            [Decoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(decoder_layers)]
+        )
+
+        self.__linear = torch.nn.Linear(embedded_size, vocab_size, bias=False)
+
+        self.__input_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)
+        self.__output_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)
+
+    def forward(self, token_list, padding_mask=None):
+        # Encode the source tokens once; the result is reused as the
+        # cross-attention memory at every decoding step. Each layer is
+        # assumed to return a tuple whose first element is the hidden
+        # states, hence the [0].
+        memory = self.__input_embedder(token_list)
+        for encoder in self.__encoder_layers:
+            memory = encoder(memory, padding_mask)[0]
+
+        # Greedy autoregressive decoding (the "do while" of the draft):
+        # re-embed the running output, run the decoder stack against the
+        # encoder memory, take the most likely next token and append it.
+        # The source tokens serve as the initial decoder input until a
+        # BOS convention is decided (WIP).
+        generated = token_list
+        while True:
+            x = self.__output_embedder(generated)
+            # Decoder layers are assumed to take (target, memory, memory, mask),
+            # mirroring the (x, x, x, mask) call of the earlier draft.
+            for decoder in self.__decoder_layers:
+                x = decoder(x, memory, memory, padding_mask)[0]
+
+            logits = self.__linear(x)
+            log_prob = torch.log_softmax(logits, dim=-1)
+            next_token = torch.argmax(log_prob[:, -1, :], dim=-1, keepdim=True)
+            generated = torch.cat([generated, next_token], dim=-1)
+
+            if not self.keep_going(log_prob):
+                break
+
+        return log_prob
+
+    def keep_going(self, x: torch.Tensor) -> bool:
+        # TODO (WIP): stop on an end-of-sequence token or a maximum length;
+        # for now a single decoding step is performed.
+        return False
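
Usage note: a minimal sketch of how the new module could be driven, assuming the repository root is on PYTHONPATH, that NanoSocratesEmbedder maps integer token ids of shape (batch, seq_len) to embeddings of size embedded_size, and that Encoder/Decoder accept the call signatures used in the patch. Every hyperparameter value below is a placeholder, not a project default.

    import torch
    from Project_Model.Libs.Transformer.Classes.NanoSocratesCore import NanoSocratesCore

    # Hypothetical hyperparameters, chosen only for illustration.
    model = NanoSocratesCore(
        embedded_size=256,
        feed_forward_dim=1024,
        encoder_layers=4,
        decoder_layers=4,
        attention_heads=8,
        vocab_size=32000,
    )
    model.eval()

    # Fake batch: 2 source sequences of 16 token ids in [0, vocab_size).
    token_list = torch.randint(0, 32000, (2, 16))

    with torch.no_grad():
        log_prob = model(token_list)          # (batch, seq_len, vocab_size) log-probabilities
        next_tokens = log_prob[:, -1, :].argmax(dim=-1)

    print(next_tokens.shape)                  # torch.Size([2])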
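
Stopping-criterion note: keep_going is still a TODO in the patch. One possible shape for it, sketched as a free function over the log_prob tensor that forward() already passes in; eos_token_id and max_length are hypothetical parameters the class does not define yet, so this is an illustration of the idea rather than the project's rule.

    import torch

    def keep_going(log_prob: torch.Tensor, eos_token_id: int, max_length: int) -> bool:
        """Continue while not every sequence has just emitted EOS and the length cap is not hit."""
        # log_prob: (batch, seq_len, vocab_size), as computed inside forward().
        last_token = log_prob[:, -1, :].argmax(dim=-1)           # latest greedy prediction per sequence
        all_finished = bool((last_token == eos_token_id).all())  # only inspects the newest position
        too_long = log_prob.size(1) >= max_length
        return not (all_finished or too_long)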