2025-10-12 16:30:30 +02:00
|
|
|
import torch
|
|
|
|
|
import Project_Model.Libs.Embedder as Embedder
|
|
|
|
|
from ..Classes import Encoder, Decoder, DeToken
|
|
|
|
|
from ..Utils import get_decoder_input
|
|
|
|
|
from Project_Model.Libs.Batch import TaskType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NanoSocratesCore(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
vocabulary_size: int,
|
|
|
|
|
sentence_max_length: int,
|
|
|
|
|
sos: int,
|
|
|
|
|
pad: int,
|
|
|
|
|
eos: int,
|
2025-10-17 22:17:24 +02:00
|
|
|
continuerdf: int,
|
2025-10-12 16:30:30 +02:00
|
|
|
latent_space: int = 256,
|
|
|
|
|
feed_forward_multiplier: int = 4,
|
|
|
|
|
attention_heads: int = 4,
|
|
|
|
|
layer_number: int = 2,
|
|
|
|
|
) -> None:
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
self.__sos = sos
|
|
|
|
|
self.__pad = pad
|
|
|
|
|
self.__eos = eos
|
2025-10-17 22:17:24 +02:00
|
|
|
self.__continuerdf = continuerdf
|
2025-10-12 16:30:30 +02:00
|
|
|
self.__sentence_len = sentence_max_length
|
|
|
|
|
|
|
|
|
|
feed_forward_latent_space = latent_space * feed_forward_multiplier
|
|
|
|
|
|
|
|
|
|
self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
|
|
|
|
|
vocabulary_size, latent_space
|
|
|
|
|
)
|
|
|
|
|
self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
|
|
|
|
|
vocabulary_size, latent_space
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
TMP_ENCODERS = [
|
|
|
|
|
Encoder(latent_space, feed_forward_latent_space, attention_heads)
|
|
|
|
|
] * layer_number
|
|
|
|
|
|
|
|
|
|
TMP_DECODERS = [
|
|
|
|
|
Decoder(latent_space, feed_forward_latent_space, attention_heads)
|
|
|
|
|
] * layer_number
|
|
|
|
|
|
|
|
|
|
self.__encoder = torch.nn.Sequential(*TMP_ENCODERS)
|
|
|
|
|
self.__decoder = torch.nn.Sequential(*TMP_DECODERS)
|
|
|
|
|
|
|
|
|
|
self.__detokener = DeToken(latent_space, vocabulary_size)
|
|
|
|
|
self.__encoder_detokener = DeToken(latent_space, vocabulary_size)
|
|
|
|
|
|
|
|
|
|
def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
|
|
|
|
|
|
|
|
|
|
encoder_embedder_input, src_padding, decoder_embedder_input, tgt_padding = args
|
|
|
|
|
|
|
|
|
|
encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
|
|
|
|
|
decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
|
|
|
|
|
|
|
|
|
|
encoder_output, _ = self.__encoder((encoder_tensor, src_padding))
|
|
|
|
|
|
|
|
|
|
decoder_output, _, _, _, _, _ = self.__decoder(
|
|
|
|
|
(decoder_tensor, encoder_output, encoder_output, src_padding, tgt_padding, False)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logits: torch.Tensor = self.__detokener(decoder_output)
|
|
|
|
|
|
|
|
|
|
return logits
|
|
|
|
|
|
|
|
|
|
def inference(self, input: tuple[torch.Tensor, torch.Tensor], task_type: TaskType) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
if task_type == TaskType.MASKING:
|
|
|
|
|
return self.__masking(input)
|
|
|
|
|
|
|
|
|
|
if task_type == TaskType.COMPLETATION:
|
|
|
|
|
return self.__continue_rdf(input)
|
|
|
|
|
|
|
|
|
|
return self.__text_generation(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __text_generation(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
x, padding = args
|
|
|
|
|
|
|
|
|
|
encoder_tensor = self.__encoder_embedder(x)
|
2025-10-16 19:20:23 +02:00
|
|
|
|
|
|
|
|
BATCH: int
|
|
|
|
|
|
|
|
|
|
if len(x.shape) > 2:
|
|
|
|
|
BATCH, SEQ_LEN, _ = x.shape
|
|
|
|
|
else:
|
|
|
|
|
_, SEQ_LEN = x.shape
|
|
|
|
|
BATCH = 1
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
encoder_output, _ = self.__encoder((encoder_tensor, padding))
|
|
|
|
|
|
|
|
|
|
decoder_in = get_decoder_input(BATCH, self.__sos, self.__pad, SEQ_LEN)
|
|
|
|
|
decoder_in_pad_mask = decoder_in.eq(self.__pad)
|
|
|
|
|
|
|
|
|
|
continue_generating = True
|
|
|
|
|
token_idx = 0
|
|
|
|
|
|
|
|
|
|
while continue_generating:
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
decoder_in_x = self.__decoder_embedder(decoder_in)
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
decoder_output, _, _, _, _, _ = self.__decoder(
|
2025-10-16 19:20:23 +02:00
|
|
|
(decoder_in_x, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
|
2025-10-12 16:30:30 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logits: torch.Tensor = self.__detokener(decoder_output)
|
|
|
|
|
|
|
|
|
|
logits = torch.softmax(logits, 2)
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
tokens = torch.argmax(logits, 2)
|
|
|
|
|
|
|
|
|
|
if token_idx < self.__sentence_len - 1:
|
|
|
|
|
decoder_in[:,token_idx + 1] = tokens[:,token_idx]
|
|
|
|
|
decoder_in_pad_mask = decoder_in.eq(self.__pad)
|
|
|
|
|
|
|
|
|
|
if token_idx == self.__sentence_len - 1:
|
|
|
|
|
continue_generating = False
|
|
|
|
|
continue
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
|
|
|
|
|
continue_generating = False
|
|
|
|
|
continue
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
|
|
|
|
|
token_idx += 1
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
return decoder_in
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __masking(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
x, padding = args
|
|
|
|
|
|
|
|
|
|
encoder_tensor = self.__encoder_embedder(x)
|
|
|
|
|
x, _ = self.__encoder((encoder_tensor, padding))
|
|
|
|
|
|
|
|
|
|
logits: torch.Tensor = self.__encoder_detokener(x)
|
|
|
|
|
del x
|
|
|
|
|
|
|
|
|
|
logits = torch.softmax(logits, 2)
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
tokens = torch.argmax(logits, 2)
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
return tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __continue_rdf(self, args: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
decoder_in, _ = args
|
|
|
|
|
decoder_in_prefix_mask = decoder_in.eq(self.__pad)
|
|
|
|
|
decoder_in_pad_mask = decoder_in.eq(self.__pad)
|
|
|
|
|
|
|
|
|
|
continue_generating = True
|
2025-10-17 22:17:24 +02:00
|
|
|
token_idx: int= int((decoder_in[0] == self.__continuerdf).nonzero()[0].item()) + 1
|
|
|
|
|
|
|
|
|
|
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
while continue_generating:
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
decoder_x = self.__decoder_embedder(decoder_in)
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
decoder_output, _, _, _, _, _ = self.__decoder(
|
2025-10-16 19:20:23 +02:00
|
|
|
(decoder_x, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, True)
|
2025-10-12 16:30:30 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logits: torch.Tensor = self.__detokener(decoder_output)
|
|
|
|
|
|
|
|
|
|
logits = torch.softmax(logits, 2)
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
tokens = torch.argmax(logits, 2)
|
|
|
|
|
|
|
|
|
|
if token_idx < self.__sentence_len - 1:
|
|
|
|
|
decoder_in[:,token_idx + 1] = tokens[:,token_idx]
|
|
|
|
|
decoder_in_pad_mask = decoder_in.eq(self.__pad)
|
|
|
|
|
|
|
|
|
|
if token_idx == self.__sentence_len - 1:
|
|
|
|
|
continue_generating = False
|
|
|
|
|
continue
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
if tokens.shape[0] == 1 and tokens[0,token_idx] == self.__eos:
|
|
|
|
|
continue_generating = False
|
|
|
|
|
continue
|
|
|
|
|
|
2025-10-16 19:20:23 +02:00
|
|
|
token_idx += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-12 16:30:30 +02:00
|
|
|
|
|
|
|
|
return decoder_in
|
|
|
|
|
|
|
|
|
|
def take_pieces(self):
|
|
|
|
|
|
|
|
|
|
return (
|
2025-10-16 19:20:23 +02:00
|
|
|
(self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
|
2025-10-12 16:30:30 +02:00
|
|
|
(self.__decoder_embedder, self.__decoder, self.__detokener)
|
2025-10-16 19:20:23 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def load_pieces(
|
|
|
|
|
self,
|
|
|
|
|
encoder_embedder: Embedder.NanoSocratesEmbedder,
|
|
|
|
|
decoder_embedder: Embedder.NanoSocratesEmbedder,
|
|
|
|
|
encoder: torch.nn.Sequential,
|
|
|
|
|
decoder: torch.nn.Sequential,
|
|
|
|
|
encoder_detokener: DeToken,
|
|
|
|
|
decoder_detokener: DeToken
|
|
|
|
|
):
|
|
|
|
|
self.__encoder_embedder = encoder_embedder
|
|
|
|
|
self.__decoder_embedder = decoder_embedder
|
|
|
|
|
self.__encoder = encoder
|
|
|
|
|
self.__decoder = decoder
|
|
|
|
|
self.__encoder_detokener = encoder_detokener
|
|
|
|
|
self.__detokener = decoder_detokener
|