Modified decoder and encoder for sequential architecture

GassiGiuseppe 2025-10-06 18:20:46 +02:00
parent 456ce724fe
commit e1549d4458
2 changed files with 3 additions and 3 deletions


@@ -37,7 +37,7 @@ class Decoder(nn.Module):
-    def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor:  # k_x = v_x, while x_q = x
+    def forward(self, x, k_x, v_x, padding_mask=None):  # -> list[torch.Tensor]; k_x = v_x, while x_q = x
         # build the causal attention mask
         attention_mask = get_causal_attention_mask(x.size(1))
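
For reference, get_causal_attention_mask is called here but not shown in this diff. A minimal sketch of what such a helper typically looks like, assuming the usual upper-triangular boolean mask (this is an assumption, not the repository's actual implementation):

    import torch

    def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
        # Sketch only: True above the diagonal marks positions a token may
        # NOT attend to, matching the convention of nn.MultiheadAttention's
        # attn_mask argument. Shape (seq_len, seq_len).
        return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
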
@@ -88,7 +88,7 @@ class Decoder(nn.Module):
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
-        return x
+        return x, k_x, v_x, padding_mask
     # use eval() to disable dropout, etc.
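
The decoder change above is what makes the stack "sequential": forward now returns (x, k_x, v_x, padding_mask) instead of x alone, so a container can thread all four values through a pile of decoder blocks. Note that nn.Sequential itself forwards only a single positional argument between modules, so chaining these layers needs a small unpacking wrapper. The sketch below (TupleSequential is a hypothetical name, not from this repository) shows one way to do it under that assumption:

    import torch.nn as nn

    class TupleSequential(nn.Sequential):
        # Hypothetical container: unpacks each layer's returned tuple into
        # the next layer's positional arguments, so the new
        # (x, k_x, v_x, padding_mask) return value can be chained.
        def forward(self, *inputs):
            for module in self:
                outputs = module(*inputs)
                inputs = outputs if isinstance(outputs, tuple) else (outputs,)
            return inputs

For example, a stack built as TupleSequential(*[Decoder(...) for _ in range(n)]) can then be called once with (x, enc_out, enc_out, padding_mask).
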


@@ -62,7 +62,7 @@ class Encoder(
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
-        return x
+        return x, padding_mask
     # use eval() to disable dropout, etc.
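
The trailing comment in both files ("use eval to disable dropout etc.") refers to PyTorch's standard train/eval switch: calling .eval() on a module disables dropout and puts normalization layers into inference mode. A minimal, self-contained illustration:

    import torch
    import torch.nn as nn

    layer = nn.Dropout(p=0.5)
    x = torch.ones(4)

    layer.train()    # training mode: elements are randomly zeroed and the
                     # survivors are scaled by 1/(1-p), here 2.0
    print(layer(x))  # e.g. tensor([2., 0., 0., 2.]) -- random
    layer.eval()     # inference mode: dropout becomes a no-op
    print(layer(x))  # tensor([1., 1., 1., 1.])
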