updated decoder_input to work without embedder
This commit is contained in:
parent
1649cd7768
commit
49946727d8
@ -1,15 +1,14 @@
|
|||||||
import torch
|
import torch
|
||||||
from Project_Model.Libs.Transformer import normalize_sequence
|
from Project_Model.Libs.Transformer import normalize_sequence
|
||||||
from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
|
# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_input(batch_size, sos_token,pad_token, seq_len, embedder: Embedder):
|
def get_decoder_input(batch_size, sos_token,pad_token, seq_len):
|
||||||
|
|
||||||
single_decoder_input, _ = normalize_sequence([sos_token],seq_len,pad_token, end_token=0, add_ending=False)
|
single_decoder_input, _ = normalize_sequence([sos_token],seq_len,pad_token, end_token=0, add_ending=False)
|
||||||
tensor_decoder_input = torch.tensor(single_decoder_input[:])
|
tensor_decoder_input = torch.tensor(single_decoder_input[:])
|
||||||
embedded_decoder_intput = embedder(tensor_decoder_input)
|
# embedded_decoder_intput = embedder(tensor_decoder_input)
|
||||||
|
|
||||||
batch_decoder_input = embedded_decoder_intput.unsqueeze(0).repeat(batch_size, 1)
|
batch_decoder_input = tensor_decoder_input.unsqueeze(0).repeat(batch_size, 1)
|
||||||
return batch_decoder_input
|
return batch_decoder_input
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user