Modified encoder and decoder for sequential architecture
commit e1549d4458
parent 456ce724fe
@@ -37,7 +37,7 @@ class Decoder(nn.Module):

-    def forward(self, x, k_x, v_x, padding_mask = None) -> torch.Tensor:  # k_x = v_x, while x_q = x
+    def forward(self, x, k_x, v_x, padding_mask = None):  # -> list[torch.Tensor]; k_x = v_x, while x_q = x

         # build the attention mask
         attention_mask = get_causal_attention_mask(x.size(1))
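The forward pass builds its causal mask with get_causal_attention_mask, which this diff calls but does not define. A minimal sketch of what such a helper typically looks like, assuming the boolean-mask convention of nn.MultiheadAttention (True marks positions a query may not attend to):

    import torch

    def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
        # Upper-triangular boolean mask: entry (i, j) is True when j > i,
        # so each position may attend only to itself and earlier positions.
        return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)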
@@ -88,7 +88,7 @@ class Decoder(nn.Module):

         # 12) Layer Normalization
         x = self.__layer_norm_3(x)

-        return x
+        return x, k_x, v_x, padding_mask

 # use eval to disable dropout etc.
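Returning (x, k_x, v_x, padding_mask) is what makes the layer stackable, but plain nn.Sequential hands each module's return value to the next module as a single argument, so the tuple would arrive un-unpacked. A common workaround, and presumably the point of this commit, is a small nn.Sequential subclass that unpacks tuples between layers. TupleSequential below is a hypothetical name, not something shown in this diff:

    import torch.nn as nn

    class TupleSequential(nn.Sequential):
        # Hypothetical helper (not in this diff): unpacks the tuple returned by
        # one layer into the positional arguments of the next, so Decoder layers
        # returning (x, k_x, v_x, padding_mask) can be chained.
        def forward(self, *inputs):
            for module in self:
                inputs = module(*inputs)
                if not isinstance(inputs, tuple):
                    inputs = (inputs,)
            return inputs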
@@ -62,7 +62,7 @@ class Encoder(

         # 8) Layer Normalization
         x = self.__layer_norm_2(x)

-        return x
+        return x, padding_mask

 # use eval to disable dropout etc.
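With both classes threading their extra arguments through, a stack of layers can be driven end to end. A hedged usage sketch; the constructor arguments are hypothetical, since the diff shows only the forward signatures:

    import torch

    # Hypothetical: Decoder's __init__ arguments do not appear in this diff.
    decoder_stack = TupleSequential(*(Decoder() for _ in range(6)))

    decoder_stack.eval()  # as the comment above notes: disable dropout etc. for inference
    with torch.no_grad():
        x, k_x, v_x, padding_mask = decoder_stack(x, k_x, v_x, padding_mask)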