import torch
from ..Utils import normalize_sequence
# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder


def get_decoder_input(batch_size, sos_token, pad_token, seq_len):
    """Build a batch of identical decoder start sequences.

    A one-token sequence ``[sos_token]`` is passed through
    ``normalize_sequence`` with ``add_ending=False``. Presumably this pads it
    with ``pad_token`` up to ``seq_len`` (TODO: confirm against
    ``Utils.normalize_sequence``). The result is stacked ``batch_size`` times.

    Args:
        batch_size: Number of rows in the returned batch.
        sos_token: Start-of-sequence token id placed first in the sequence.
        pad_token: Padding token id forwarded to ``normalize_sequence``.
        seq_len: Target length forwarded to ``normalize_sequence``.

    Returns:
        A tensor built by ``torch.tensor``, so its dtype is inferred from the
        values that ``normalize_sequence`` returns. Assuming that helper
        returns a flat sequence (TODO: confirm), the shape is
        ``(batch_size, L)``, where ``L`` is the normalized length.
    """
    single_sequence, _ = normalize_sequence(
        [sos_token], seq_len, pad_token, end_token=0, add_ending=False
    )

    # torch.tensor always copies its input, so no defensive slice is needed.
    sequence_tensor = torch.tensor(single_sequence)

    # Add a leading batch axis, then tile it. Unlike expand, repeat
    # materializes independent rows, so writing into one row in place does
    # not affect the others.
    return sequence_tensor.unsqueeze(0).repeat(batch_size, 1)