WIP NanoSocratesCore
This commit is contained in:
parent
745424a978
commit
56d438f01a
@ -1,6 +1,58 @@
|
||||
from ..Utils.task_type import TaskType
|
||||
from .Decoder import Decoder
|
||||
from .Encoder import Encoder
|
||||
from ....Libs.Embedder import NanoSocratesEmbedder
|
||||
import torch
|
||||
|
||||
class NanoSocratesCore():
|
||||
class NanoSocratesCore(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
embedded_size: int,
|
||||
feed_forward_dim: int,
|
||||
encoder_layers: int,
|
||||
decoder_layers:int,
|
||||
attention_heads: int,
|
||||
vocab_size: int) -> None:
|
||||
|
||||
|
||||
self.__encoder_sequence = torch.nn.Sequential(
|
||||
*[Encoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(encoder_layers)]
|
||||
)
|
||||
|
||||
#* unpack the list so that each encoder has its own weights
|
||||
|
||||
self.__decoder_sequence = torch.nn.Sequential(
|
||||
*[Decoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(decoder_layers)]
|
||||
)
|
||||
|
||||
self.__linear = torch.nn.Linear(embedded_size, vocab_size, bias=False)
|
||||
|
||||
self.__input_embeder = NanoSocratesEmbedder(vocab_size,embedded_size)
|
||||
self.__output_embedder = NanoSocratesEmbedder(vocab_size,embedded_size)
|
||||
|
||||
|
||||
def forward(self, token_list, padding_mask = None):
|
||||
x = self.__input_embeder(token_list)
|
||||
x = self.__encoder_sequence(x, padding_mask)[0]
|
||||
|
||||
|
||||
# do while
|
||||
x = self.__decoder_sequence(x,x,x, padding_mask)[0]
|
||||
logits = self.__linear(x)
|
||||
log_prob = torch.softmax(logits, dim=-1)
|
||||
output = torch.argmax(log_prob)
|
||||
while self.keep_going(log_prob):
|
||||
# from log_prob again into x
|
||||
|
||||
|
||||
x = self.__decoder_sequence(x,x,x, padding_mask)[0]
|
||||
logits = self.__linear(x)
|
||||
log_prob = torch.softmax(logits, dim=-1)
|
||||
# argmax
|
||||
|
||||
return log_prob
|
||||
|
||||
|
||||
|
||||
def keep_going(self, x: ) -> bool:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
Loading…
x
Reference in New Issue
Block a user