Made model batch-ready
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 from Project_Model.Libs.Transformer.Classes.FeedForwardNetwork import FeedForwardNetwork
 from Project_Model.Libs.Transformer.Classes.TorchMultiHeadAttention import (
@@ -29,14 +30,17 @@ class Encoder(
             embedding_dimension
         ) # norm of second "Add and Normalize"
         self.__dropout = nn.Dropout(0.1) # ...
         pass

-    def forward(self, x, padding_mask = None):
+    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
+        # WARNING: args is a single tuple so this block can be chained inside nn.Sequential
+        x, padding_mask = args
+
         # -> ATTENTION -> dropout -> add and normalize -> FF -> dropout -> add and normalize ->
         # Attention with Residual Connection [input + self-attention]

         # 1) Multi Head Attention
-        ATTENTION = self.__attention(x, x, x,key_padding_mask= padding_mask)
+        ATTENTION = self.__attention(x, x, x, key_padding_mask=padding_mask)

         # 2) Dropout
         DROPPED_ATTENTION = self.__dropout(ATTENTION)

@@ -62,7 +66,7 @@ class Encoder(
         # 8) Layer Normalization
         x = self.__layer_norm_2(x)

-        return x,padding_mask
+        return (x, padding_mask)


 # use eval() to disable dropout etc.
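The switch to a single tuple argument is what makes the block stackable: nn.Sequential threads exactly one positional value from module to module, so packing (x, padding_mask) together lets the padding mask ride along through every encoder layer. Below is a minimal, self-contained sketch of the same pattern; ToyEncoderBlock and its constructor arguments are hypothetical stand-ins, not the repo's Encoder (which wraps its own TorchMultiHeadAttention and FeedForwardNetwork).

import torch
import torch.nn as nn


class ToyEncoderBlock(nn.Module):
    """Stripped-down stand-in for the repo's Encoder (hypothetical names)."""

    def __init__(self, embedding_dimension: int, number_of_heads: int):
        super().__init__()
        self.__attention = nn.MultiheadAttention(
            embedding_dimension, number_of_heads, batch_first=True
        )
        self.__layer_norm = nn.LayerNorm(embedding_dimension)
        self.__dropout = nn.Dropout(0.1)

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        # Single tuple argument: nn.Sequential only forwards one positional value.
        x, padding_mask = args
        # key_padding_mask: True marks padded positions the attention must ignore.
        attention_output, _ = self.__attention(x, x, x, key_padding_mask=padding_mask)
        x = self.__layer_norm(x + self.__dropout(attention_output))
        # Return the same tuple shape so the next block can unpack it.
        return (x, padding_mask)


if __name__ == "__main__":
    batch_size, sequence_length, embedding_dimension = 2, 5, 16

    # Batch of embedded tokens; the second sequence has 2 padded positions.
    x = torch.randn(batch_size, sequence_length, embedding_dimension)
    padding_mask = torch.tensor([
        [False, False, False, False, False],
        [False, False, False, True, True],
    ])

    # Stacking works because every block is tuple-in / tuple-out.
    encoder_stack = nn.Sequential(
        ToyEncoderBlock(embedding_dimension, number_of_heads=4),
        ToyEncoderBlock(embedding_dimension, number_of_heads=4),
    )

    encoder_stack.eval()  # disables dropout, as the trailing comment in the file notes
    output, mask = encoder_stack((x, padding_mask))
    print(output.shape)  # torch.Size([2, 5, 16])

Here padding_mask follows nn.MultiheadAttention's key_padding_mask convention (a boolean tensor where True marks padded positions); whether the project's TorchMultiHeadAttention wrapper uses the same convention is an assumption.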