From 49946727d8ee91917a62ea2f639c0547428618d5 Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Sat, 11 Oct 2025 16:53:36 +0200
Subject: [PATCH] updated decoder_input to work without embedder

---
 Project_Model/Libs/Transformer/Utils/decoder_input.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Utils/decoder_input.py b/Project_Model/Libs/Transformer/Utils/decoder_input.py
index 86c3fae..fb5c9cc 100644
--- a/Project_Model/Libs/Transformer/Utils/decoder_input.py
+++ b/Project_Model/Libs/Transformer/Utils/decoder_input.py
@@ -1,15 +1,14 @@
 import torch
 from Project_Model.Libs.Transformer import normalize_sequence
-from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
+# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
 
-def get_decoder_input(batch_size, sos_token,pad_token, seq_len, embedder: Embedder):
+def get_decoder_input(batch_size, sos_token,pad_token, seq_len):
 
     single_decoder_input, _ = normalize_sequence([sos_token],seq_len,pad_token, end_token=0, add_ending=False)
     tensor_decoder_input = torch.tensor(single_decoder_input[:])
-    embedded_decoder_intput = embedder(tensor_decoder_input)
+    # embedded_decoder_intput = embedder(tensor_decoder_input)
 
-    batch_decoder_input = embedded_decoder_intput.unsqueeze(0).repeat(batch_size, 1)
+    batch_decoder_input = tensor_decoder_input.unsqueeze(0).repeat(batch_size, 1)
 
     return batch_decoder_input
-