Fixes for evaluation
@@ -83,7 +83,14 @@ class NanoSocratesCore(torch.nn.Module):
         x, padding = args

         encoder_tensor = self.__encoder_embedder(x)

-        BATCH, SEQ_LEN, _ = x.shape
+        BATCH: int
+
+        if len(x.shape) > 2:
+            BATCH, SEQ_LEN, _ = x.shape
+        else:
+            _, SEQ_LEN = x.shape
+            BATCH = 1
+

         encoder_output, _ = self.__encoder((encoder_tensor, padding))
@@ -95,25 +102,32 @@ class NanoSocratesCore(torch.nn.Module):

         while continue_generating:

-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_in_x = self.__decoder_embedder(decoder_in)

             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
+                (decoder_in_x, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
             )

             logits: torch.Tensor = self.__detokener(decoder_output)

             logits = torch.softmax(logits, 2)

-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
+
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)

             if token_idx == self.__sentence_len - 1:
                 continue_generating = False
                 continue

             if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
                 continue_generating = False
                 continue

             if token_idx < self.__sentence_len - 1:
                 decoder_in[:,token_idx + 1] = tokens[:,token_idx]
                 decoder_in_pad_mask = decoder_in.eq(self.__pad)

             token_idx += 1

         return decoder_in
@@ -130,7 +144,7 @@ class NanoSocratesCore(torch.nn.Module):

         logits = torch.softmax(logits, 2)

-        tokens = torch.argmax(logits)
+        tokens = torch.argmax(logits, 2)

         return tokens
@@ -146,31 +160,56 @@ class NanoSocratesCore(torch.nn.Module):

         while continue_generating:

-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_x = self.__decoder_embedder(decoder_in)

             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, False)
+                (decoder_x, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, True)
             )

             logits: torch.Tensor = self.__detokener(decoder_output)

             logits = torch.softmax(logits, 2)

-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
+
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)

             if token_idx == self.__sentence_len - 1:
                 continue_generating = False
                 continue

             if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
                 continue_generating = False
                 continue

             if token_idx < self.__sentence_len - 1:
                 decoder_in[:,token_idx + 1] = tokens[:,token_idx]
                 decoder_in_pad_mask = decoder_in.eq(self.__pad)
             token_idx += 1

         return decoder_in

     def take_pieces(self):

         return (
-            (self.__encoder_embedder, self.__encoder),
+            (self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
             (self.__decoder_embedder, self.__decoder, self.__detokener)
         )

     def load_pieces(
         self,
         encoder_embedder: Embedder.NanoSocratesEmbedder,
         decoder_embedder: Embedder.NanoSocratesEmbedder,
         encoder: torch.nn.Sequential,
         decoder: torch.nn.Sequential,
         encoder_detokener: DeToken,
         decoder_detokener: DeToken
     ):
         self.__encoder_embedder = encoder_embedder
         self.__decoder_embedder = decoder_embedder
         self.__encoder = encoder
         self.__decoder = decoder
         self.__encoder_detokener = encoder_detokener
         self.__detokener = decoder_detokener
@@ -1,9 +1,11 @@
 from .TrainingModel import TrainingModel
 from .NanoSocratEncoder import NanoSocratEncoder
 from .NanoSocraDecoder import NanoSocraDecoder
+from .NanoSocrates import NanoSocratesCore

 __all__ = [
     "TrainingModel",
     "NanoSocratEncoder",
-    "NanoSocraDecoder"
+    "NanoSocraDecoder",
+    "NanoSocratesCore"
 ]
@@ -1,6 +1,7 @@
 from enum import Enum, auto

 class TaskType(Enum):
     TEXT2RDF = auto()
     RDF2TEXT = auto()
     MASK = auto()
+    COMPLETATION = auto()