diff --git a/Project_Model/Libs/Transformer/Utils/decoder_input.py b/Project_Model/Libs/Transformer/Utils/decoder_input.py
new file mode 100644
index 0000000..86c3fae
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Utils/decoder_input.py
@@ -0,0 +1,17 @@
+import torch
+from Project_Model.Libs.Transformer import normalize_sequence
+from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
+
+
+def get_decoder_input(batch_size, sos_token, pad_token, seq_len, embedder: Embedder):
+    """Build a batch of embedded decoder start inputs: [SOS] followed by padding."""
+
+    # Pad the single [SOS] token up to seq_len without appending an end token.
+    single_decoder_input, _ = normalize_sequence([sos_token], seq_len, pad_token, end_token=0, add_ending=False)
+    tensor_decoder_input = torch.tensor(single_decoder_input)
+    embedded_decoder_input = embedder(tensor_decoder_input)
+
+    # Replicate the embedded sequence across the batch dimension
+    # (assumes the embedder returns a (seq_len, d_model) tensor).
+    batch_decoder_input = embedded_decoder_input.unsqueeze(0).repeat(batch_size, 1, 1)
+    return batch_decoder_input
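
Usage note (not part of the patch): a minimal sketch of how the helper is expected to be called, assuming integer ids for the [SOS] and [PAD] tokens, a sequence length of 128, that NanoSocratesEmbedder can be constructed without arguments, and that the new module is importable at the path added by the patch; the real ids and constructor signature may differ.

    from Project_Model.Libs.Embedder import NanoSocratesEmbedder
    from Project_Model.Libs.Transformer.Utils.decoder_input import get_decoder_input

    embedder = NanoSocratesEmbedder()   # constructor arguments assumed
    decoder_input = get_decoder_input(
        batch_size=32,
        sos_token=1,                    # assumed [SOS] id
        pad_token=0,                    # assumed [PAD] id
        seq_len=128,
        embedder=embedder,
    )
    print(decoder_input.shape)          # expected (32, 128, d_model)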