WIP NanoSocratesCore
This commit is contained in:
parent 745424a978
commit 56d438f01a
@@ -1,6 +1,58 @@
 from ..Utils.task_type import TaskType
+from .Decoder import Decoder
+from .Encoder import Encoder
+from ....Libs.Embedder import NanoSocratesEmbedder
+import torch
-class NanoSocratesCore():
-    def __init__(self) -> None:
-        pass
+
+
+class NanoSocratesCore(torch.nn.Module):
+
+    def __init__(self,
+                 embedded_size: int,
+                 feed_forward_dim: int,
+                 encoder_layers: int,
+                 decoder_layers: int,
+                 attention_heads: int,
+                 vocab_size: int) -> None:
+        super().__init__()  # nn.Module must be initialised before submodules are assigned
+
+        self.__encoder_sequence = torch.nn.Sequential(
+            *[Encoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(encoder_layers)]
+        )
+        #* unpack the list so that each encoder has its own weights
+
+        self.__decoder_sequence = torch.nn.Sequential(
+            *[Decoder(embedded_size, feed_forward_dim, attention_heads) for _ in range(decoder_layers)]
+        )
+
+        self.__linear = torch.nn.Linear(embedded_size, vocab_size, bias=False)
+
+        self.__input_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)
+        self.__output_embedder = NanoSocratesEmbedder(vocab_size, embedded_size)
+
+    def forward(self, token_list, padding_mask=None):
+        x = self.__input_embedder(token_list)
+        # TODO: nn.Sequential only forwards a single positional argument, so the extra
+        # mask/memory arguments passed below will not reach the Encoder/Decoder layers as written
+        x = self.__encoder_sequence(x, padding_mask)[0]
+
+        # do while: run the decoder stack once, then repeat while keep_going() allows
+        x = self.__decoder_sequence(x, x, x, padding_mask)[0]
+        logits = self.__linear(x)
+        log_prob = torch.softmax(logits, dim=-1)  # NOTE: softmax gives probabilities, not log-probabilities
+        output = torch.argmax(log_prob)
+
+        while self.keep_going(log_prob):
+            # TODO: from log_prob again into x (feed the prediction back as the next decoder input)
+            x = self.__decoder_sequence(x, x, x, padding_mask)[0]
+            logits = self.__linear(x)
+            log_prob = torch.softmax(logits, dim=-1)
+            # argmax
+
+        return log_prob
+
+    def keep_going(self, x: torch.Tensor) -> bool:
+        # TODO: stopping criterion for the decoding loop (still WIP)
+        raise NotImplementedError
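
The #* comment above touches a point worth illustrating: unpacking the list comprehension hands nn.Sequential encoder_layers distinct Encoder instances, each with its own weights, whereas reusing a single module object would share one set of weights across all positions in the stack. A minimal, self-contained sketch of that difference, using plain nn.Linear layers as stand-ins for the repository's Encoder blocks (an assumption made purely for illustration):

import torch

d_model = 8

# Unpacking the list comprehension gives nn.Sequential three distinct layers,
# so every layer owns its own weight matrix.
independent = torch.nn.Sequential(*[torch.nn.Linear(d_model, d_model) for _ in range(3)])

# Reusing one module object three times would share a single weight matrix instead.
layer = torch.nn.Linear(d_model, d_model)
shared = torch.nn.Sequential(layer, layer, layer)

print(sum(p.numel() for p in independent.parameters()))  # 216 parameters (3 independent layers)
print(sum(p.numel() for p in shared.parameters()))       # 72 parameters (1 shared layer)

The same reasoning applies to the Decoder stack built a few lines below.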
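
keep_going is left unimplemented in this commit, so the decoding loop has no stopping rule yet. One common guard for this kind of loop, offered only as a hedged sketch and not as the NanoSocrates design, stops once the most likely token at the last position is an end-of-sequence id or once a step budget runs out; eos_id, step and max_steps are hypothetical names:

import torch

def keep_going(probs: torch.Tensor, eos_id: int, step: int, max_steps: int = 256) -> bool:
    # probs is assumed to be (sequence_length, vocab_size); continue while the
    # most likely token at the last position is not EOS and the step budget remains.
    last_token = int(torch.argmax(probs[-1]))
    return last_token != eos_id and step < max_steps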