Added attention_mask
parent b303affd18
commit 6f219f634f
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
+from ..Utils.attention_mask import get_attention_mask
+
 
 
 class Decoder(nn.Module):
@@ -15,7 +17,7 @@ class Decoder(nn.Module):
         super().__init__()
 
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1
+            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
         )
 
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -32,6 +34,8 @@ class Decoder(nn.Module):
         )
         self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)
 
+
+
     def forward(self, x, k_x, v_x, attention_mask) -> torch.Tensor:  # k_x = v_x . While x_q = x
 
         # 1) Masked Attention
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
+from typing import Optional
 
 
 class TorchMultiHeadAttention(nn.Module):
@@ -9,6 +9,7 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -18,12 +19,13 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
 
+        self.__attention_mask = attention_mask
+
 
     def forward(
         self,
         x_q: torch.Tensor,
         x_k: torch.Tensor,
         x_v: torch.Tensor,
-        attention_mask=None,
         key_padding_mask=None,
     ) -> torch.Tensor:
@@ -32,7 +34,7 @@ class TorchMultiHeadAttention(nn.Module):
         # x * Wv -> V
 
         y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
+            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
        )
         return y
 
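Note: a minimal usage sketch of the wrapper after this change, not part of the commit. The import paths assume the attention classes live under Project_Model.Libs.Transformer.Classes (only the Utils path is confirmed by the file listing below); shapes and the seq_len value are illustrative. With batch_first=True, nn.MultiheadAttention expects inputs of shape (batch, seq_len, embedding_dimension) and a boolean attn_mask of shape (seq_len, seq_len), where True marks key positions a query may not attend to.

    import torch
    from Project_Model.Libs.Transformer.Classes.TorchMultiHeadAttention import TorchMultiHeadAttention  # assumed path
    from Project_Model.Libs.Transformer.Utils.attention_mask import get_attention_mask

    seq_len, embedding_dimension, heads = 8, 8, 2

    # The mask is now fixed at construction time and reused on every forward call.
    # nn.MultiheadAttention expects attn_mask of shape (seq_len, seq_len).
    attention = TorchMultiHeadAttention(
        embedding_dimension,
        heads,
        dropout=0.1,
        attention_mask=get_attention_mask(seq_len),
    )

    x = torch.randn(1, seq_len, embedding_dimension)
    y = attention(x, x, x)  # causal self-attention; output shape (1, seq_len, embedding_dimension)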
Project_Model/Libs/Transformer/Utils/__init__.py  (new file, 3 lines)
@@ -0,0 +1,3 @@
+from .attention_mask import get_attention_mask
+
+__all__ = ["get_attention_mask"]
Project_Model/Libs/Transformer/Utils/attention_mask.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
+import torch
+
+def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
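For reference, a small sketch (not part of the commit) of what this helper produces: an upper-triangular boolean matrix with True strictly above the diagonal. Interpreted as a boolean attn_mask by nn.MultiheadAttention, True means "may not attend", so query position i can only attend to key positions j <= i, i.e. the standard causal mask.

    import torch

    # Reproduces the new helper locally for illustration.
    mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
    print(mask)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])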
@@ -1,3 +1,5 @@
 from .Classes import *
+from .Utils import *
 
 from . import Classes
+from . import Utils
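Since Utils lists get_attention_mask in __all__ and the package __init__ now star-imports it, the helper should also be reachable from the package root. A short sketch, assuming Project_Model and Libs are importable packages:

    from Project_Model.Libs.Transformer import get_attention_mask  # re-exported via Utils

    mask = get_attention_mask(8)  # 8x8 boolean causal mask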