Activated Dropout to avoid overfitting
parent f463f699cf
commit 4ca1d0a189
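
For context on the changes below: torch.nn.Dropout randomly zeroes elements of its input while the module is in training mode and rescales the surviving elements by 1 / (1 - p), and it acts as the identity in eval mode. That is the regularization these previously commented-out calls now apply to the attention and feed-forward outputs. A minimal sketch, independent of this repository's code (the rate of 0.1 mirrors the dropout=0.1 passed to MultiHeadAttention in the first hunk):

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.1)

x = torch.ones(3, 5)

dropout.train()        # training mode: roughly 10% of entries are zeroed,
print(dropout(x))      # the survivors are scaled by 1 / (1 - 0.1)

dropout.eval()         # evaluation mode: dropout is the identity
print(dropout(x))      # prints x unchanged
```
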
@@ -19,7 +19,7 @@ class Decoder(nn.Module):
         self.__attention_heads = number_of_attention_heads
         super().__init__()



         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
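
The forward-pass hunks below call self.__dropout, whose construction is not part of the visible context. Presumably each module creates it in __init__; a hypothetical skeleton of that fragment (the attribute name comes from the hunks, the rate is an assumption chosen to match the dropout=0.1 above):

```python
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, embedding_dimension: int, number_of_attention_heads: int):
        super().__init__()
        # assumed, not shown in this diff: the dropout module invoked in forward()
        self.__dropout = nn.Dropout(p=0.1)
```
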
@@ -68,12 +68,12 @@ class Decoder(nn.Module):
         )

         # 2) Dropout
-        # DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
-        # del MASKED_ATTENTION
+        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
+        del MASKED_ATTENTION

         # 3) Residual Connection
-        x = x + MASKED_ATTENTION
-        del MASKED_ATTENTION
+        x = x + DROPPED_MASKED_ATTENTION
+        del DROPPED_MASKED_ATTENTION

         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
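
This hunk establishes the post-norm sub-layer pattern that every remaining hunk in the Decoder and Encoder repeats: run the sub-layer, apply dropout to its output, add the residual, then layer-normalize. The same pattern as a compact, repository-independent sketch:

```python
import torch
import torch.nn as nn

def post_norm_sublayer(x: torch.Tensor,
                       sublayer_output: torch.Tensor,
                       dropout: nn.Dropout,
                       layer_norm: nn.LayerNorm) -> torch.Tensor:
    """LayerNorm(x + Dropout(sublayer(x))) -- steps 2) to 4) above."""
    return layer_norm(x + dropout(sublayer_output))

# e.g. the hunk above is equivalent to:
# x = post_norm_sublayer(x, MASKED_ATTENTION, self.__dropout, self.__layer_norm_1)
```
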
@@ -86,12 +86,12 @@ class Decoder(nn.Module):
         )

         # 6) Dropout
-        # DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
-        # del CROSS_ATTENTION
+        DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
+        del CROSS_ATTENTION

         # 7) Residual Connection
-        x = x + CROSS_ATTENTION
-        del CROSS_ATTENTION
+        x = x + DROPPED_CROSS_ATTENTION
+        del DROPPED_CROSS_ATTENTION

         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
@@ -100,12 +100,12 @@ class Decoder(nn.Module):
         FEED_FORWARD = self.__feed_forward_network(x)

         # 10) Dropout
-        # DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-        # del FEED_FORWARD
+        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+        del FEED_FORWARD

         # 11) Residual Connection
-        x = x + FEED_FORWARD
-        del FEED_FORWARD
+        x = x + DROPPED_FEED_FORWARD
+        del DROPPED_FEED_FORWARD

         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
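
Taken together, the four Decoder hunks wire dropout into all three sub-layers: masked self-attention, cross-attention, and the feed-forward network. Below is a standalone illustration of that structure built only on torch.nn primitives; it is not the repository's Decoder class, and the module names, call signatures, and the 4x feed-forward width are assumptions:

```python
import torch
import torch.nn as nn

class ToyDecoderBlock(nn.Module):
    """Illustrative post-norm decoder block mirroring steps 1)-12) above."""

    def __init__(self, embedding_dimension: int, number_of_attention_heads: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.masked_attention = nn.MultiheadAttention(
            embedding_dimension, number_of_attention_heads,
            dropout=dropout_rate, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(
            embedding_dimension, number_of_attention_heads,
            dropout=dropout_rate, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dimension, 4 * embedding_dimension),
            nn.ReLU(),
            nn.Linear(4 * embedding_dimension, embedding_dimension))
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm_1 = nn.LayerNorm(embedding_dimension)
        self.layer_norm_2 = nn.LayerNorm(embedding_dimension)
        self.layer_norm_3 = nn.LayerNorm(embedding_dimension)

    def forward(self, x, encoder_output, causal_mask=None, padding_mask=None):
        # 1)-4) masked self-attention -> dropout -> residual -> layer norm
        attended, _ = self.masked_attention(x, x, x, attn_mask=causal_mask)
        x = self.layer_norm_1(x + self.dropout(attended))

        # 5)-8) cross-attention over the encoder output -> dropout -> residual -> layer norm
        attended, _ = self.cross_attention(x, encoder_output, encoder_output,
                                           key_padding_mask=padding_mask)
        x = self.layer_norm_2(x + self.dropout(attended))

        # 9)-12) feed-forward -> dropout -> residual -> layer norm
        x = self.layer_norm_3(x + self.dropout(self.feed_forward(x)))
        return x
```
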
@@ -43,12 +43,12 @@ class Encoder(
         ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)

         # 2) Dropout
-        # DROPPED_ATTENTION = self.__dropout(ATTENTION)
-        # del ATTENTION
+        DROPPED_ATTENTION = self.__dropout(ATTENTION)
+        del ATTENTION

         # 3) Residual Connection
-        x = x + ATTENTION
-        del ATTENTION
+        x = x + DROPPED_ATTENTION
+        del DROPPED_ATTENTION

         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
@@ -57,12 +57,12 @@ class Encoder(
         FEED_FORWARD = self.__feed_forward(x)

         # 6) Dropout
-        # DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-        # del FEED_FORWARD
+        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+        del FEED_FORWARD

         # 7) Residual Connection
-        x = x + FEED_FORWARD
-        del FEED_FORWARD
+        x = x + DROPPED_FEED_FORWARD
+        del DROPPED_FEED_FORWARD

         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
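
The Encoder hunks apply the same pattern to its two sub-layers (self-attention, then feed-forward). One practical consequence of having dropout active throughout both stacks: evaluation and inference code must put the model into eval mode, or outputs will still be randomly masked. A self-contained sketch with a hypothetical stand-in model (any nn.Module containing nn.Dropout behaves the same way):

```python
import torch
import torch.nn as nn

# hypothetical stand-in for the full Encoder/Decoder stack
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Dropout(p=0.1), nn.Linear(16, 4))
batch = torch.randn(8, 16)

model.train()                  # dropout active: stochastic forward passes while fitting
training_output = model(batch)

model.eval()                   # dropout layers become the identity
with torch.no_grad():
    evaluation_output = model(batch)   # deterministic outputs for validation/inference
```
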
@@ -35,7 +35,7 @@ class FeedForwardNetwork(nn.Module):
         x = self.__relu(x)

         # 3) Dropout
-        # x = self.__dropout(x)
+        x = self.__dropout(x)

         # 4) Linear Layer
         x = self.__fully_connected_2(x)
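
The FeedForwardNetwork hunk enables dropout between the ReLU and the second projection. A hypothetical reconstruction of the full module using the attribute names visible in the diff; only the four forward-pass steps are confirmed by the hunk, while the first linear layer's name, the hidden width, and the dropout rate are assumptions:

```python
import torch
import torch.nn as nn

class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network with the newly activated dropout."""

    def __init__(self, embedding_dimension: int, hidden_dimension: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.__fully_connected_1 = nn.Linear(embedding_dimension, hidden_dimension)
        self.__relu = nn.ReLU()
        self.__dropout = nn.Dropout(p=dropout_rate)
        self.__fully_connected_2 = nn.Linear(hidden_dimension, embedding_dimension)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1) Linear Layer
        x = self.__fully_connected_1(x)
        # 2) ReLU
        x = self.__relu(x)
        # 3) Dropout (activated by this commit)
        x = self.__dropout(x)
        # 4) Linear Layer
        x = self.__fully_connected_2(x)
        return x
```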