From 1649cd7768b081e3be2c3eacc0749de7ce2897cb Mon Sep 17 00:00:00 2001 From: GassiGiuseppe Date: Sat, 11 Oct 2025 16:18:43 +0200 Subject: [PATCH] added decoder_input method to build the batch tensor to give in input to the deocder --- .../Libs/Transformer/Utils/decoder_input.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 Project_Model/Libs/Transformer/Utils/decoder_input.py diff --git a/Project_Model/Libs/Transformer/Utils/decoder_input.py b/Project_Model/Libs/Transformer/Utils/decoder_input.py new file mode 100644 index 0000000..86c3fae --- /dev/null +++ b/Project_Model/Libs/Transformer/Utils/decoder_input.py @@ -0,0 +1,15 @@ +import torch +from Project_Model.Libs.Transformer import normalize_sequence +from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder + + +def get_decoder_input(batch_size, sos_token,pad_token, seq_len, embedder: Embedder): + + single_decoder_input, _ = normalize_sequence([sos_token],seq_len,pad_token, end_token=0, add_ending=False) + tensor_decoder_input = torch.tensor(single_decoder_input[:]) + embedded_decoder_intput = embedder(tensor_decoder_input) + + batch_decoder_input = embedded_decoder_intput.unsqueeze(0).repeat(batch_size, 1) + return batch_decoder_input + +