Fixes for evaluation

Author: Christian Risi
Date:   2025-10-16 19:20:23 +02:00
parent 9ff117f437
commit 892f91aad7
10 changed files with 492 additions and 31 deletions


@@ -83,7 +83,14 @@ class NanoSocratesCore(torch.nn.Module):
         x, padding = args
         encoder_tensor = self.__encoder_embedder(x)
-        BATCH, SEQ_LEN, _ = x.shape
+        BATCH: int
+        if len(x.shape) > 2:
+            BATCH, SEQ_LEN, _ = x.shape
+        else:
+            _, SEQ_LEN = x.shape
+            BATCH = 1
         encoder_output, _ = self.__encoder((encoder_tensor, padding))
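Note: this first hunk guards against unbatched input. The old code unconditionally unpacked three dimensions from x.shape and would raise a ValueError on a 2-D tensor. A minimal standalone sketch of the same guard (function name is illustrative, not from the repo):

    import torch

    def batch_and_len(x: torch.Tensor) -> tuple[int, int]:
        # 3-D input is assumed to be (BATCH, SEQ_LEN, EMB);
        # anything flatter is treated as a single sample
        if x.dim() > 2:
            batch, seq_len, _ = x.shape
        else:
            _, seq_len = x.shape
            batch = 1
        return batch, seq_len

    batch_and_len(torch.zeros(4, 16, 32))  # -> (4, 16)
    batch_and_len(torch.zeros(1, 16))      # -> (1, 16)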
@@ -95,25 +102,32 @@ class NanoSocratesCore(torch.nn.Module):
         while continue_generating:
-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_in_x = self.__decoder_embedder(decoder_in)
             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
+                (decoder_in_x, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
             )
             logits: torch.Tensor = self.__detokener(decoder_output)
             logits = torch.softmax(logits, 2)
-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
-            if token_idx < self.__sentence_len - 1:
-                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
-                decoder_in_pad_mask = decoder_in.eq(self.__pad)
             if token_idx == self.__sentence_len - 1:
                 continue_generating = False
                 continue
             if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
                 continue_generating = False
                 continue
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)
             token_idx += 1
         return decoder_in
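Two distinct bugs are fixed in this loop. First, the old code overwrote decoder_in with its own embedding, so the token buffer was destroyed after the first iteration; the new decoder_in_x keeps the raw token ids intact. Second, the next token was written before the termination checks, so a token could be appended past EOS. A condensed sketch of the corrected ordering (helper names are illustrative, not the repo's API):

    import torch

    def greedy_fill(embed_and_decode, detokener, decoder_in, sentence_len, eos_id):
        token_idx = 0
        while True:
            # embed a copy; decoder_in itself keeps raw token ids
            hidden = embed_and_decode(decoder_in)
            tokens = torch.argmax(torch.softmax(detokener(hidden), 2), 2)
            if token_idx == sentence_len - 1:
                break                # buffer full
            if tokens.shape[0] == 1 and tokens[0, token_idx] == eos_id:
                break                # EOS checked BEFORE writing the next slot
            decoder_in[:, token_idx + 1] = tokens[:, token_idx]
            token_idx += 1
        return decoder_in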
@@ -130,7 +144,7 @@ class NanoSocratesCore(torch.nn.Module):
         logits = torch.softmax(logits, 2)
-        tokens = torch.argmax(logits)
+        tokens = torch.argmax(logits, 2)
         return tokens
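The recurring one-line fix deserves a note: without a dim argument, torch.argmax flattens its input and returns a single scalar index, so the old code produced one number for the entire batch instead of a token id per position. (The softmax before it is harmless but redundant for argmax, since softmax is monotonic and doesn't change the winner.) For example:

    import torch

    logits = torch.rand(2, 5, 100)   # (BATCH, SEQ_LEN, VOCAB)
    torch.argmax(logits).shape       # torch.Size([])    -- index into the flattened tensor
    torch.argmax(logits, 2).shape    # torch.Size([2, 5]) -- one token id per (batch, position)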
@@ -146,31 +160,56 @@ class NanoSocratesCore(torch.nn.Module):
         while continue_generating:
-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_x = self.__decoder_embedder(decoder_in)
             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, False)
+                (decoder_x, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, True)
             )
             logits: torch.Tensor = self.__detokener(decoder_output)
             logits = torch.softmax(logits, 2)
-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
-            if token_idx < self.__sentence_len - 1:
-                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
-                decoder_in_pad_mask = decoder_in.eq(self.__pad)
             if token_idx == self.__sentence_len - 1:
                 continue_generating = False
                 continue
             if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
                 continue_generating = False
                 continue
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:,token_idx + 1] = tokens[:,token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)
             token_idx += 1
         return decoder_in
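This decoder-only path gets the same three fixes, plus the final flag of the decoder input tuple flips from False to True. The flag's semantics aren't visible in this diff; assuming it enables causal self-attention masking (which the surrounding prefix-mask plumbing suggests), a standard causal mask is built like this:

    import torch

    SEQ_LEN = 8
    # Upper-triangular True entries mark future positions to be hidden
    causal_mask = torch.triu(torch.ones(SEQ_LEN, SEQ_LEN, dtype=torch.bool), diagonal=1)
    # Typically combined with a padding mask before attention, e.g.:
    # attn_scores.masked_fill_(causal_mask, float("-inf"))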
     def take_pieces(self):
         return (
-            (self.__encoder_embedder, self.__encoder),
+            (self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
             (self.__decoder_embedder, self.__decoder, self.__detokener)
         )

+    def load_pieces(
+        self,
+        encoder_embedder: Embedder.NanoSocratesEmbedder,
+        decoder_embedder: Embedder.NanoSocratesEmbedder,
+        encoder: torch.nn.Sequential,
+        decoder: torch.nn.Sequential,
+        encoder_detokener: DeToken,
+        decoder_detokener: DeToken
+    ):
+        self.__encoder_embedder = encoder_embedder
+        self.__decoder_embedder = decoder_embedder
+        self.__encoder = encoder
+        self.__decoder = decoder
+        self.__encoder_detokener = encoder_detokener
+        self.__detokener = decoder_detokener
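take_pieces now returns a detokener with each half, and the new load_pieces is its inverse, so the encoder and decoder stacks can be checkpointed or trained separately and reassembled. A hypothetical round trip, assuming a and b are two compatibly configured NanoSocratesCore instances:

    # split one model into its pieces...
    (enc_embedder, encoder, enc_detokener), (dec_embedder, decoder, dec_detokener) = a.take_pieces()

    # ...and install them into another (note the argument order of load_pieces)
    b.load_pieces(
        enc_embedder, dec_embedder,
        encoder, decoder,
        enc_detokener, dec_detokener,
    )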