Refactoring
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn


class TorchMultiHeadAttention(nn.Module):

    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embedding_dimension,
            number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        attention_mask=None,
        key_padding_mask=None,
    ) -> torch.Tensor:

        # Learned projections are applied internally by nn.MultiheadAttention:
        # x * Wq -> Q
        # x * Wk -> K
        # x * Wv -> V

        y, _ = self.attention(
            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
        )
        return y


# batch_first=False (the historical default)
# Format: (L, N, E)
#   L = sequence length (time steps / positions)
#   N = batch size
#   E = embedding dimension (d_model)
#
# batch_first=True
# Format: (N, L, E) (more natural for many models)