added decoder_input method to build the batch tensor to give in input to the deocder

This commit is contained in:
GassiGiuseppe 2025-10-11 16:18:43 +02:00
parent 443f54fffd
commit 1649cd7768

View File

@ -0,0 +1,15 @@
import torch
from Project_Model.Libs.Transformer import normalize_sequence
from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
def get_decoder_input(batch_size, sos_token,pad_token, seq_len, embedder: Embedder):
single_decoder_input, _ = normalize_sequence([sos_token],seq_len,pad_token, end_token=0, add_ending=False)
tensor_decoder_input = torch.tensor(single_decoder_input[:])
embedded_decoder_intput = embedder(tensor_decoder_input)
batch_decoder_input = embedded_decoder_intput.unsqueeze(0).repeat(batch_size, 1)
return batch_decoder_input