update to batch attention mask

This commit is contained in:
GassiGiuseppe 2025-10-06 13:03:03 +02:00
parent 87409fecd5
commit 948c3fd7ac
3 changed files with 19 additions and 13 deletions

View File

@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask
+# B, L(T), E_D
 class Decoder(nn.Module):
@@ -17,7 +18,7 @@ class Decoder(nn.Module):
         super().__init__()
         self.__masked_attention = MultiHeadAttention(
-            embedding_dimension, number_of_attention_heads, dropout=0.1, attention_mask=get_attention_mask(embedding_dimension)
+            embedding_dimension, number_of_attention_heads, dropout=0.1
         )
         self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)
@@ -38,9 +39,12 @@ class Decoder(nn.Module):
     def forward(self, x, k_x, v_x, padding_mask=None) -> torch.Tensor:  # k_x = v_x, while x_q = x
+        # build the causal attention mask from the runtime sequence length
+        attention_mask = get_causal_attention_mask(x.size(1))
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
-            x, x, x, key_padding_mask=padding_mask
+            x, x, x, key_padding_mask=padding_mask, attn_mask=attention_mask
         )
         # 2) Dropout
@@ -57,7 +61,7 @@ class Decoder(nn.Module):
         x = self.__layer_norm_1(x)
         # 5) Encoder-decoder (cross) attention
-        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x key_padding_mask=padding_mask)
+        CROSS_ATTENTION = self.__cross_attention(x, k_x, v_x, key_padding_mask=padding_mask)
         # 6) Dropout
         DROPPED_CROSS_ATTENTION = self.__dropout(CROSS_ATTENTION)
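
With this change the causal mask in Decoder.forward now follows the runtime sequence length x.size(1) rather than being fixed to embedding_dimension at construction time. A minimal sketch of the new call pattern, with the mask helper inlined so it runs standalone (the shape values are illustrative assumptions, not taken from the repo):

import torch

B, L, E_D = 2, 7, 32                        # illustrative (B, L(T), E_D)
x = torch.randn(B, L, E_D)

# equivalent to get_causal_attention_mask(x.size(1)) from ..Utils.attention_mask
attention_mask = torch.triu(torch.ones(x.size(1), x.size(1), dtype=torch.bool), diagonal=1)
print(attention_mask.shape)                 # torch.Size([7, 7]) -- tracks L, not E_D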

View File

@@ -9,7 +9,6 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
-        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -19,7 +18,6 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )
-        self.__attention_mask = attention_mask
     def forward(
         self,
@@ -27,14 +25,16 @@ class TorchMultiHeadAttention(nn.Module):
         x_k: torch.Tensor,
         x_v: torch.Tensor,
         key_padding_mask=None,
+        attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         # x * Wq -> Q
         # x * Wk -> K
         # x * Wv -> V
-        y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
+        # REMEMBER: torch's MultiheadAttention internally uses the batch size to expand the 2-D attention mask into the 3-D one
+        y, _ = self.attention(
+            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask,
+            need_weights=False
         )
         return y
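
Since the mask is no longer stored on the module, it now travels through forward as attn_mask, and a 2-D (L, L) boolean mask is broadcast by nn.MultiheadAttention over every sequence in the batch and every head. A self-contained sketch of that behaviour (sizes are illustrative assumptions, not values from the repo):

import torch
import torch.nn as nn

B, L, E, H = 4, 10, 32, 4                   # illustrative batch, length, embed dim, heads
mha = nn.MultiheadAttention(E, H, dropout=0.1, batch_first=True)
x = torch.randn(B, L, E)

# 2-D boolean causal mask: True marks positions that must NOT be attended to
attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

# the same (L, L) mask is applied to each batch element and each head
y, _ = mha(x, x, x, attn_mask=attn_mask, need_weights=False)
print(y.shape)                              # torch.Size([4, 10, 32])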

View File

@@ -1,9 +1,11 @@
 import torch
-def get_causal_attention_mask(embedding_dimension: int) -> torch.Tensor:
-    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
-def get_causal_attention_mask_batched(embedding_dimension: int, batch_size: int) -> torch.Tensor:
-    base_mask = get_causal_attention_mask(embedding_dimension)
+def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
+    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
+# there is no need for this, since torch's MultiheadAttention does it internally
+def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
+    base_mask = get_causal_attention_mask(seq_len)
     return base_mask.unsqueeze(0).expand(batch_size, -1, -1)  # add a leading dimension of size batch_size
     # the result is (batch_size, seq_len, seq_len): the same 2-D mask repeated along the batch dimension
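
A quick shape check of the two helpers (a sketch with the bodies inlined so it runs standalone; the values are illustrative):

import torch

seq_len, batch_size = 4, 2

base = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # get_causal_attention_mask(4)
batched = base.unsqueeze(0).expand(batch_size, -1, -1)                         # get_causal_attention_mask_batched(4, 2)

print(base.shape)     # torch.Size([4, 4])
print(batched.shape)  # torch.Size([2, 4, 4])

Note that in current PyTorch a 3-D attn_mask must have shape (batch_size * num_heads, L, L), so the (batch_size, L, L) tensor above would still need repeating per head before use; passing the 2-D mask and letting torch expand it, as the comment above says, is the simpler route.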