Modified encoder and decoder for sequential architecture
parent 456ce724fe
commit e1549d4458
@@ -37,7 +37,7 @@ class Decoder(nn.Module):
-    def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor:  # k_x = v_x, while x_q = x
+    def forward(self, x, k_x, v_x, padding_mask=None):  # -> list[torch.Tensor]; k_x = v_x, while x_q = x
 
         # build the attention mask
         attention_mask = get_causal_attention_mask(x.size(1))
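For context, get_causal_attention_mask itself is not shown in this diff. A minimal sketch of such a helper, assuming the common PyTorch convention of a boolean mask in which True marks positions a query may not attend to (the repo's actual helper may instead use a float mask with -inf):

import torch

def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
    # True above the diagonal: position i may not attend to any j > i.
    # This matches the attn_mask convention of nn.MultiheadAttention,
    # where True means "masked out" (an assumption, not confirmed by
    # this diff).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)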
@@ -88,7 +88,7 @@ class Decoder(nn.Module):
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
 
-        return x
+        return x, k_x, v_x, padding_mask
 
 
     # use eval to disable dropout etc.
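The point of returning the full (x, k_x, v_x, padding_mask) tuple instead of just x is the "sequential architecture" named in the commit title: each block now re-emits exactly the inputs the next block expects, so blocks can be chained. Note that nn.Sequential passes a single positional argument, so a tuple-returning forward() needs either a forward() that accepts the tuple, or an explicit loop as in this sketch (DecoderStack is a hypothetical name, not from the repo):

import torch.nn as nn

class DecoderStack(nn.Module):
    # hypothetical wrapper illustrating why Decoder.forward now returns
    # (x, k_x, v_x, padding_mask): one block's outputs are the next
    # block's inputs, so stacking reduces to a simple loop
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, k_x, v_x, padding_mask=None):
        for block in self.blocks:
            x, k_x, v_x, padding_mask = block(x, k_x, v_x, padding_mask)
        return x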
@@ -62,7 +62,7 @@ class Encoder(
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
 
-        return x
+        return x, padding_mask
 
 
     # use eval to disable dropout etc.
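The "use eval to disable dropout etc." comment refers to standard PyTorch behaviour rather than anything introduced by this diff: module.eval() switches Dropout (and similar train-time layers) to inference mode. A minimal usage sketch with a stand-in module; the repo's Encoder/Decoder would be used the same way:

import torch
import torch.nn as nn

# stand-in module; Dropout is the part eval() affects
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
model.eval()                    # Dropout becomes a no-op in eval mode
with torch.no_grad():           # also skip gradient tracking at inference
    out = model(torch.randn(2, 16))
model.train()                   # switch back before further training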