diff --git a/Project_Model/Libs/Transformer/Classes/Decoder.py b/Project_Model/Libs/Transformer/Classes/Decoder.py
index 11e9aa7..0a818ee 100644
--- a/Project_Model/Libs/Transformer/Classes/Decoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Decoder.py
@@ -37,7 +37,7 @@ class Decoder(nn.Module):
 
-    def forward(self, x, k_x, v_x, padding_mask = None) -> torch.Tensor: # k_x = v_x . While x_q = x
+    def forward(self, x, k_x, v_x, padding_mask = None): #-> list[torch.Tensor]: # k_x = v_x . While x_q = x
 
         # build of attention mask
         attention_mask = get_causal_attention_mask(x.size(1))
@@ -88,7 +88,7 @@ class Decoder(nn.Module):
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
 
-        return x
+        return x, k_x, v_x, padding_mask
 
 # use eval to disable dropout ecc
diff --git a/Project_Model/Libs/Transformer/Classes/Encoder.py b/Project_Model/Libs/Transformer/Classes/Encoder.py
index cdec92a..0c46fe0 100644
--- a/Project_Model/Libs/Transformer/Classes/Encoder.py
+++ b/Project_Model/Libs/Transformer/Classes/Encoder.py
@@ -62,7 +62,7 @@ class Encoder(
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
 
-        return x
+        return x,padding_mask
 
 # use eval to disable dropout ecc
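
What the return-type changes buy: Decoder.forward now re-emits everything it consumed (x, k_x, v_x, padding_mask), and Encoder.forward likewise returns (x, padding_mask). With that shape, each layer's output tuple can be splatted straight into the next layer's forward, so a stack of decoder blocks threads its state with no glue code. A minimal sketch of the pattern, where TinyDecoderLayer is a hypothetical stand-in (not code from this PR) that only shares Decoder.forward's signature:

import torch
import torch.nn as nn

class TinyDecoderLayer(nn.Module):
    """Stand-in for Decoder: same forward signature and tuple return."""
    def __init__(self, d_model: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, k_x, v_x, padding_mask=None):
        x = self.norm(x)                   # placeholder for the real attention/FFN sub-layers
        return x, k_x, v_x, padding_mask   # re-emit the threaded state, as in this PR

layers = nn.ModuleList([TinyDecoderLayer(16) for _ in range(3)])
state = (torch.randn(2, 5, 16),  # x: decoder input
         torch.randn(2, 7, 16),  # k_x: encoder keys
         torch.randn(2, 7, 16),  # v_x: encoder values
         None)                   # padding_mask
for layer in layers:
    state = layer(*state)        # output tuple feeds the next layer unchanged
print(state[0].shape)            # torch.Size([2, 5, 16])

One caveat: nn.Sequential passes a single value between modules and does not unpack tuples, so an explicit loop like the one above (or a small wrapper module that unpacks) is the usual way to chain layers with this multi-value signature.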