Quick fix to architecture
@@ -56,12 +56,12 @@ class Decoder(nn.Module):
 )

 # 2) Dropout
-DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
-del MASKED_ATTENTION
+# DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
+# del MASKED_ATTENTION

 # 3) Residual Connection
-x = x + DROPPED_MASKED_ATTENTION
-del DROPPED_MASKED_ATTENTION
+x = x + MASKED_ATTENTION
+del MASKED_ATTENTION

 # 4) Layer Normalization
 x = self.__layer_norm_1(x)
@@ -72,12 +72,12 @@ class Decoder(nn.Module):
 )

 # 6) Dropout
-DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
-del CROSS_ATTENTION
+# DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
+# del CROSS_ATTENTION

 # 7) Residual Connection
-x = x + DROPPED_CROSS_ATTENTION
-del DROPPED_CROSS_ATTENTION
+x = x + CROSS_ATTENTION
+del CROSS_ATTENTION

 # 8) Layer Normalization
 x = self.__layer_norm_2(x)
@@ -86,12 +86,12 @@ class Decoder(nn.Module):
 FEED_FORWARD = self.__feed_forward_network(x)

 # 10) Dropout
-DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-del FEED_FORWARD
+# DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+# del FEED_FORWARD

 # 11) Residual Connection
-x = x + DROPPED_FEED_FORWARD
-del DROPPED_FEED_FORWARD
+x = x + FEED_FORWARD
+del FEED_FORWARD

 # 12) Layer Normalization
 x = self.__layer_norm_3(x)
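Taken together, the Decoder change makes each numbered sub-block reduce to sub-layer, then residual add, then layer norm, with the dropout step commented out rather than deleted. Below is a minimal, hypothetical sketch of that pattern for a single self-attention sub-block. It is not the repository's Decoder class; the AttentionSubBlock name, the embed_dim/num_heads values, and the use of nn.MultiheadAttention with batch_first=True are assumptions for illustration only.

import torch
import torch.nn as nn

# Hypothetical illustration only -- not the repository's Decoder. It shows the
# sub-block pattern as it stands after this commit: attention, residual add,
# layer norm, with the dropout call skipped.
class AttentionSubBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, padding_mask=None):
        # 1) Self-attention; nn.MultiheadAttention returns (output, weights).
        attention, _ = self.attention(x, x, x, key_padding_mask=padding_mask)
        # 2) Dropout skipped, as in this "quick fix".
        # 3) Residual connection taken straight from the raw attention output.
        x = x + attention
        # 4) Layer normalization.
        return self.layer_norm(x)

# Usage sketch: a batch of 2 sequences, length 5, model width 16.
block = AttentionSubBlock(embed_dim=16, num_heads=4)
print(block(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])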
@@ -43,11 +43,12 @@ class Encoder(
 ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)

 # 2) Dropout
-DROPPED_ATTENTION = self.__dropout(ATTENTION)
-del ATTENTION
+# DROPPED_ATTENTION = self.__dropout(ATTENTION)
+# del ATTENTION

 # 3) Residual Connection
-x = x + DROPPED_ATTENTION
+x = x + ATTENTION
+del ATTENTION

 # 4) Layer Normalization
 x = self.__layer_norm_1(x)
@@ -56,12 +57,12 @@ class Encoder(
 FEED_FORWARD = self.__feed_forward(x)

 # 6) Dropout
-DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-del FEED_FORWARD
+# DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+# del FEED_FORWARD

 # 7) Residual Connection
-x = x + DROPPED_FEED_FORWARD
-del DROPPED_FEED_FORWARD
+x = x + FEED_FORWARD
+del FEED_FORWARD

 # 8) Layer Normalization
 x = self.__layer_norm_2(x)
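A side note on what these edits disable: nn.Dropout is only active in training mode and acts as the identity under model.eval(), so commenting out the dropout calls in both Encoder and Decoder removes a training-time regularizer but should leave inference outputs unchanged. A quick standalone check of that behavior (unrelated to the repository's modules):

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.1)
x = torch.ones(4)

dropout.train()
print(dropout(x))  # some entries zeroed, survivors scaled by 1 / (1 - 0.1)

dropout.eval()
print(dropout(x))  # identical to x: dropout is a no-op at inference time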