update to batch attention mask

GassiGiuseppe
2025-10-06 13:03:03 +02:00
parent 87409fecd5
commit 948c3fd7ac
3 changed files with 19 additions and 13 deletions


@@ -9,7 +9,6 @@ class TorchMultiHeadAttention(nn.Module):
         embedding_dimension: int,
         number_of_attention_heads: int,
         dropout: float = 0.0,
-        attention_mask: Optional[torch.Tensor] = None
     ):
         super().__init__()
         self.attention = nn.MultiheadAttention(
@@ -19,7 +18,6 @@ class TorchMultiHeadAttention(nn.Module):
             batch_first=True,
         )

-        self.__attention_mask = attention_mask

     def forward(
         self,
@@ -27,14 +25,16 @@ class TorchMultiHeadAttention(nn.Module):
         x_k: torch.Tensor,
         x_v: torch.Tensor,
         key_padding_mask=None,
+        attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         # x * Wq -> Q
         # x * Wk -> K
         # x * Wv -> V
-        y, _ = self.attention.forward(
-            x_q, x_k, x_v, attn_mask=self.__attention_mask, key_padding_mask=key_padding_mask
+        # REMEMBER: torch's MultiheadAttention broadcasts a 2-D attn_mask across
+        # the batch internally to build the 3-D attention mask
+        y, _ = self.attention(
+            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask,
+            need_weights=False
         )
         return y
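
After this change the mask is supplied per forward call rather than frozen at construction, so every batch can bring its own mask; need_weights=False additionally skips materializing the averaged attention weights, which lets PyTorch take its fused attention fast path where available. Below is a minimal usage sketch, not part of the commit: the class is reconstructed from the hunks above, the positional arguments to nn.MultiheadAttention are an assumption (they fall outside the diff context), and all tensor sizes are illustrative.

from typing import Optional

import torch
import torch.nn as nn


class TorchMultiHeadAttention(nn.Module):
    # Reconstructed from the hunks above; the arguments passed to
    # nn.MultiheadAttention are an assumption (outside the diff context).
    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embedding_dimension,
            number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        key_padding_mask=None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        y, _ = self.attention(
            x_q, x_k, x_v, attn_mask=attention_mask,
            key_padding_mask=key_padding_mask, need_weights=False,
        )
        return y


batch_size, seq_len, embed_dim, num_heads = 2, 5, 16, 4  # illustrative sizes
mha = TorchMultiHeadAttention(embed_dim, num_heads)
x = torch.randn(batch_size, seq_len, embed_dim)

# 2-D (seq_len x seq_len) causal mask; True marks positions that may NOT attend.
# torch broadcasts it over batch and heads, per the REMEMBER note in the diff.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Per-element padding mask: True marks padded key positions to ignore.
pad_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
pad_mask[0, -1] = True  # pretend the last token of the first sequence is padding

y = mha(x, x, x, key_padding_mask=pad_mask, attention_mask=causal_mask)
assert y.shape == (batch_size, seq_len, embed_dim)  # batch_first=True layout

With batch_first=True the output keeps the (batch, sequence, embedding) layout, and because both masks are boolean, torch merges the broadcast 2-D causal mask with the per-element key_padding_mask internally.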