doctor and model test

GassiGiuseppe
2025-10-08 22:51:36 +02:00
parent b805dc538e
commit 1de2cc59db
13 changed files with 902 additions and 63 deletions

@@ -24,32 +24,49 @@ class TrainingModel(torch.nn.Module):
             vocabulary_size, latent_space
         )
-        TMP_ENCODERS = [
+        # do NOT share layer weights
+        enc_layers = [
             Encoder(latent_space, feed_forward_latent_space, attention_heads)
-        ] * layer_number
-        TMP_DECODERS = [
+            for _ in range(layer_number)
+        ]
+        dec_layers = [
             Decoder(latent_space, feed_forward_latent_space, attention_heads)
-        ] * layer_number
+            for _ in range(layer_number)
+        ]
-        self.__encoder = torch.nn.Sequential(*TMP_ENCODERS)
-        self.__decoder = torch.nn.Sequential(*TMP_DECODERS)
+        self.__encoder = torch.nn.Sequential(*enc_layers)
+        self.__decoder = torch.nn.Sequential(*dec_layers)
         self.__detokener = DeToken(latent_space, vocabulary_size)
-    def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
-        encoder_embedder_input, padding_tensor, decoder_embedder_input = args
+    def forward(
+        self,
+        args: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+    ):
+        # returns logits for the LAST decoder position only -> [B, V]
+        (
+            encoder_embedder_input,   # [B,S] encoder tokens
+            encoder_padding_mask,     # [B,S] True where encoder is PAD
+            decoder_embedder_prefix,  # [B,Tp] decoder prefix (e.g., <SOS> + tokens so far)
+            decoder_padding_mask,     # [B,Tp] True where decoder prefix has PAD
+        ) = args
-        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)
-        decoder_tensor = self.__decoder_embedder(decoder_embedder_input)
+        # 1) embeddings
+        encoder_tensor = self.__encoder_embedder(encoder_embedder_input)    # [B,S,E]
+        decoder_tensor = self.__decoder_embedder(decoder_embedder_prefix)   # [B,Tp,E]
-        encoder_output, _ = self.__encoder((encoder_tensor, padding_tensor))
+        # 2) encode
+        encoder_output, _ = self.__encoder((encoder_tensor, encoder_padding_mask))  # [B,S,E], [B,S]
-        decoder_output, _, _, _ = self.__decoder(
-            (decoder_tensor, encoder_tensor, encoder_tensor, None)
-        )
+        # 3) decode (causal mask is built inside the decoder)
+        decoder_output, _, _, _, _ = self.__decoder(
+            (decoder_tensor, encoder_output, encoder_output,
+             decoder_padding_mask, encoder_padding_mask)
+        )  # [B,Tp,E], ...
-        logits: torch.Tensor = self.__detokener(decoder_output)
+        # 4) project only the last time step
+        last_hidden = decoder_output[:, -1:, :]      # [B,1,E]
+        step_logits = self.__detokener(last_hidden)  # [B,1,V]
+        step_logits = step_logits[:, -1, :]          # [B,V]
-        return logits
+        return step_logits  # logits for one token
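Why the constructor change matters: in PyTorch, [Encoder(...)] * layer_number repeats a single module instance, so every entry in the Sequential is the same object and all "layers" share one set of weights (which is what the "do NOT share layer weights" comment refers to). The list comprehension calls the constructor once per iteration and yields independent layers. A minimal sketch of the difference, using torch.nn.Linear as a stand-in for Encoder:

import torch

layer_number = 4

# Buggy: one Linear repeated 4 times -> every entry is the same object
shared = [torch.nn.Linear(8, 8)] * layer_number
assert shared[0] is shared[3]
# parameters() deduplicates, so the whole stack has just one weight and one bias
assert len(list(torch.nn.Sequential(*shared).parameters())) == 2

# Fixed: the constructor runs once per iteration -> independent layers
independent = [torch.nn.Linear(8, 8) for _ in range(layer_number)]
assert independent[0] is not independent[3]
assert len(list(torch.nn.Sequential(*independent).parameters())) == 2 * layer_number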
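The reworked forward also routes the decoder's cross-attention keys and values from encoder_output instead of the raw encoder embeddings, and it returns logits for only the last decoder position ([B, V]), the shape a token-at-a-time generation loop consumes. A minimal greedy-decoding sketch of how such a loop might drive this forward, assuming hypothetical PAD/SOS/EOS token ids and the mask convention from the diff (True marks PAD positions); model is a constructed TrainingModel:

import torch

PAD, SOS, EOS = 0, 1, 2  # hypothetical special-token ids
MAX_LEN = 32

def greedy_decode(model: torch.nn.Module, src: torch.Tensor) -> torch.Tensor:
    """src: [B, S] encoder token ids; returns generated ids without the <SOS>."""
    model.eval()
    src_pad_mask = src.eq(PAD)                                    # [B, S], True where PAD
    prefix = torch.full((src.size(0), 1), SOS, dtype=torch.long)  # [B, 1], starts at <SOS>
    with torch.no_grad():
        for _ in range(MAX_LEN):
            prefix_pad_mask = prefix.eq(PAD)                      # [B, Tp]
            logits = model((src, src_pad_mask, prefix, prefix_pad_mask))  # [B, V]
            next_token = logits.argmax(dim=-1, keepdim=True)              # [B, 1]
            prefix = torch.cat([prefix, next_token], dim=1)
            if (next_token == EOS).all():                         # every sequence finished
                break
    return prefix[:, 1:]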