Last fixes
This commit is contained in:
@@ -14,6 +14,7 @@ class NanoSocratesCore(torch.nn.Module):
|
||||
sos: int,
|
||||
pad: int,
|
||||
eos: int,
|
||||
continuerdf: int,
|
||||
latent_space: int = 256,
|
||||
feed_forward_multiplier: int = 4,
|
||||
attention_heads: int = 4,
|
||||
@@ -24,6 +25,7 @@ class NanoSocratesCore(torch.nn.Module):
|
||||
self.__sos = sos
|
||||
self.__pad = pad
|
||||
self.__eos = eos
|
||||
self.__continuerdf = continuerdf
|
||||
self.__sentence_len = sentence_max_length
|
||||
|
||||
feed_forward_latent_space = latent_space * feed_forward_multiplier
|
||||
@@ -156,7 +158,9 @@ class NanoSocratesCore(torch.nn.Module):
|
||||
decoder_in_pad_mask = decoder_in.eq(self.__pad)
|
||||
|
||||
continue_generating = True
|
||||
token_idx = 0
|
||||
token_idx: int= int((decoder_in[0] == self.__continuerdf).nonzero()[0].item()) + 1
|
||||
|
||||
|
||||
|
||||
while continue_generating:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user