import torch

import Project_Model.Libs.Embedder as Embedder

from ..Classes import Encoder, Decoder, DeToken


class TrainingModel(torch.nn.Module):
    """Encoder-decoder transformer whose forward pass returns next-token logits
    for the last decoder position only ([B, V])."""

    def __init__(
        self,
        vocabulary_size: int,
        latent_space: int = 256,
        feed_forward_multiplier: int = 4,
        attention_heads: int = 4,
        layer_number: int = 2,
    ) -> None:
        super().__init__()

        feed_forward_latent_space = latent_space * feed_forward_multiplier

        self.__encoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )
        self.__decoder_embedder = Embedder.NanoSocratesEmbedder(
            vocabulary_size, latent_space
        )

        # do NOT share layer weights
        enc_layers = [
            Encoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]
        dec_layers = [
            Decoder(latent_space, feed_forward_latent_space, attention_heads)
            for _ in range(layer_number)
        ]

        self.__encoder = torch.nn.Sequential(*enc_layers)
        self.__decoder = torch.nn.Sequential(*dec_layers)

        self.__detokener = DeToken(latent_space, vocabulary_size)

    def forward(
        self,
        args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        # returns logits for the LAST decoder position only -> [B, V]
        (
            encoder_embedder_input,   # [B,S]  encoder tokens
            encoder_padding_mask,     # [B,S]  True where the encoder input is PAD
            decoder_embedder_prefix,  # [B,Tp] decoder prefix (e.g., <SOS> + tokens so far)
            decoder_padding_mask,     # [B,Tp] True where the decoder prefix is PAD
        ) = args

        # 1) embeddings
        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)   # [B,S,E]
        decoder_tensor = self.__decoder_embedder(decoder_embedder_prefix)  # [B,Tp,E]

        # 2) encode
        encoder_output, _ = self.__encoder((encoder_tensor, encoder_padding_mask))  # [B,S,E], [B,S]

        # 3) decode (causal mask is built inside the decoder)
        decoder_output, _, _, _, _ = self.__decoder(
            (decoder_tensor, encoder_output, encoder_output,
             decoder_padding_mask, encoder_padding_mask)
        )  # [B,Tp,E], ...

        # 4) project only the last time step
        last_hidden = decoder_output[:, -1:, :]      # [B,1,E]
        step_logits = self.__detokener(last_hidden)  # [B,1,V]
        step_logits = step_logits[:, -1, :]          # [B,V]

        return step_logits  # logits for one token
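

# The sketch below is a minimal, illustrative smoke test of the step-wise forward
# pass above: it builds a tiny model, feeds random token ids plus padding masks,
# and greedily picks one next token. The vocabulary size, sequence lengths, PAD id,
# and <SOS> id are assumptions for illustration only, not values taken from the
# project. Run it via `python -m` from the package root so the relative imports
# resolve.
if __name__ == "__main__":
    VOCAB = 100   # assumed toy vocabulary size
    PAD_ID = 0    # assumed PAD token id
    SOS_ID = 1    # assumed start-of-sequence token id
    B, S = 2, 7   # batch size and encoder sequence length (arbitrary)

    model = TrainingModel(vocabulary_size=VOCAB, latent_space=64, layer_number=1)
    model.eval()

    encoder_tokens = torch.randint(2, VOCAB, (B, S))                 # [B,S] token ids
    decoder_prefix = torch.full((B, 1), SOS_ID, dtype=torch.long)    # [B,1] <SOS> only
    encoder_pad_mask = encoder_tokens.eq(PAD_ID)                     # [B,S] True where PAD (none here)
    decoder_pad_mask = decoder_prefix.eq(PAD_ID)                     # [B,1] True where PAD (none here)

    with torch.no_grad():
        # one autoregressive step: logits for the token after the current prefix
        logits = model((encoder_tokens, encoder_pad_mask,
                        decoder_prefix, decoder_pad_mask))           # [B,V]
        next_token = logits.argmax(dim=-1, keepdim=True)             # [B,1] greedy pick

    print(next_token.shape)  # expected: torch.Size([2, 1])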