Quick fix to architecture

Christian Risi
2025-10-08 12:34:09 +02:00
parent 14c3914571
commit c2e13bc9c6
8 changed files with 89 additions and 126 deletions


@@ -56,12 +56,12 @@ class Decoder(nn.Module):
         )
         # 2) Dropout
-        DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
-        del MASKED_ATTENTION
+        # DROPPED_MASKED_ATTENTION = self.__dropout(MASKED_ATTENTION)
+        # del MASKED_ATTENTION
         # 3) Residual Connection
-        x = x + DROPPED_MASKED_ATTENTION
-        del DROPPED_MASKED_ATTENTION
+        x = x + MASKED_ATTENTION
+        del MASKED_ATTENTION
         # 4) Layer Normalization
         x = self.__layer_norm_1(x)
@@ -72,12 +72,12 @@ class Decoder(nn.Module):
         )
         # 6) Dropout
-        DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
-        del CROSS_ATTENTION
+        # DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
+        # del CROSS_ATTENTION
         # 7) Residual Connection
-        x = x + DROPPED_CROSS_ATTENTION
-        del DROPPED_CROSS_ATTENTION
+        x = x + CROSS_ATTENTION
+        del CROSS_ATTENTION
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)
@@ -86,12 +86,12 @@ class Decoder(nn.Module):
         FEED_FORWARD = self.__feed_forward_network(x)
         # 10) Dropout
-        DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
-        del FEED_FORWARD
+        # DROPPED_FEED_FORWARD = self.__dropout(FEED_FORWARD)
+        # del FEED_FORWARD
         # 11) Residual Connection
-        x = x + DROPPED_FEED_FORWARD
-        del DROPPED_FEED_FORWARD
+        x = x + FEED_FORWARD
+        del FEED_FORWARD
         # 12) Layer Normalization
         x = self.__layer_norm_3(x)
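
The net effect of the three hunks: each sub-layer's output is now added directly into the residual stream, with the dropout step commented out rather than removed. Below is a minimal sketch of the resulting decoder block. Only the forward-pass wiring (steps 2-4, 6-8, 10-12) comes from this diff; the constructor, the nn.MultiheadAttention sub-layers, the dimensions, and steps 1, 5, and 9 are assumptions inferred from the variable names and step numbering.

import torch
import torch.nn as nn

class Decoder(nn.Module):
    # Sketch only: constructor and sub-layer internals are assumptions;
    # the commit above shows just the forward-pass changes.
    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.__self_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.__cross_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.__feed_forward_network = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.__layer_norm_1 = nn.LayerNorm(d_model)
        self.__layer_norm_2 = nn.LayerNorm(d_model)
        self.__layer_norm_3 = nn.LayerNorm(d_model)
        self.__dropout = nn.Dropout(0.1)  # still constructed, but bypassed below

    def forward(self, x, encoder_out, tgt_mask=None):
        # 1) Masked Self-Attention (assumed)
        MASKED_ATTENTION, _ = self.__self_attention(x, x, x, attn_mask=tgt_mask)
        # 2) Dropout -- disabled by this commit
        # 3) Residual Connection
        x = x + MASKED_ATTENTION
        del MASKED_ATTENTION
        # 4) Layer Normalization
        x = self.__layer_norm_1(x)
        # 5) Cross-Attention over encoder output (assumed)
        CROSS_ATTENTION, _ = self.__cross_attention(x, encoder_out, encoder_out)
        # 6) Dropout -- disabled by this commit
        # 7) Residual Connection
        x = x + CROSS_ATTENTION
        del CROSS_ATTENTION
        # 8) Layer Normalization
        x = self.__layer_norm_2(x)
        # 9) Feed-Forward Network (assumed internals)
        FEED_FORWARD = self.__feed_forward_network(x)
        # 10) Dropout -- disabled by this commit
        # 11) Residual Connection
        x = x + FEED_FORWARD
        del FEED_FORWARD
        # 12) Layer Normalization
        x = self.__layer_norm_3(x)
        return x

A quick smoke test under these assumptions: Decoder()(torch.randn(2, 10, 512), torch.randn(2, 16, 512)) returns a (2, 10, 512) tensor. Note the design trade-off: the original Transformer (Vaswani et al., 2017) applies dropout to each sub-layer's output before the residual add, so commenting it out removes that regularization from the decoder blocks entirely; the "Quick fix" framing suggests this is a temporary simplification rather than the intended final architecture.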