from typing import Optional

import torch
import torch.nn as nn


class TorchMultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embedding_dimension,
            number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # nn.MultiheadAttention applies the projections internally:
        # x_q * Wq -> Q
        # x_k * Wk -> K
        # x_v * Wv -> V
        y, _ = self.attention(
            x_q,
            x_k,
            x_v,
            attn_mask=attention_mask,
            key_padding_mask=key_padding_mask,
        )
        return y


# batch_first=False (historical default)
# Format: (L, N, E)
# L = sequence length (time steps / positions)
# N = batch size
# E = d_model (embedding dimension)

# batch_first=True
# Format: (N, L, E) (more natural for many models)
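
# A minimal usage sketch (not part of the original; the batch, sequence and
# model sizes below are assumed purely for illustration). With batch_first=True
# the module expects inputs of shape (N, L, E) and returns the same shape.
if __name__ == "__main__":
    batch_size, sequence_length, d_model = 2, 10, 64  # assumed example sizes
    mha = TorchMultiHeadAttention(
        embedding_dimension=d_model,
        number_of_attention_heads=8,  # d_model must be divisible by the head count
    )

    x = torch.randn(batch_size, sequence_length, d_model)  # (N, L, E)

    # Self-attention: query, key and value all come from the same sequence.
    out = mha(x, x, x)
    print(out.shape)  # torch.Size([2, 10, 64])

    # Optional key padding mask: shape (N, L), True marks positions to ignore.
    key_padding_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    key_padding_mask[:, -2:] = True  # pretend the last two positions are padding
    out_masked = mha(x, x, x, key_padding_mask=key_padding_mask)
    print(out_masked.shape)  # torch.Size([2, 10, 64])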