import torch
import torch.nn as nn
from typing import Optional


class TorchMultiHeadAttention(nn.Module):
    """Thin wrapper around torch.nn.MultiheadAttention, configured with batch_first=True."""

    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(
            embedding_dimension,
            num_heads=number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Internal projections performed by MultiheadAttention:
        # x_q @ Wq -> Q, x_k @ Wk -> K, x_v @ Wv -> V
        # REMEMBER: torch.nn.MultiheadAttention accepts a 2-D attn_mask of shape
        # (L, S) and broadcasts it over the batch (and heads) internally; a 3-D
        # mask must already have shape (N * num_heads, L, S).
        # (See the mask-building sketch below the class.)
        y, _ = self.attention(
            x_q, x_k, x_v,
            attn_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return y
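# The helpers below are an illustrative sketch (not part of the original module)
# of how the two masks accepted by forward() are typically built:
#   attention_mask   -> 2-D boolean (L, S), True = attention blocked (broadcast over batch/heads)
#   key_padding_mask -> 2-D boolean (N, S), True = padding position to ignore
def build_causal_mask(sequence_length: int) -> torch.Tensor:
    # Upper-triangular True entries forbid attending to future positions.
    return torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1
    )


def build_key_padding_mask(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    # lengths: (N,) tensor of valid sequence lengths; True marks padded positions.
    return torch.arange(max_length)[None, :] >= lengths[:, None]
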
# batch_first=False (historical default)
# Format: (L, N, E)
#   L = sequence length (time/positions)
#   N = batch size
#   E = d_model (embedding) dimension
# batch_first=True
# Format: (N, L, E) (more natural for many models)
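
# Minimal usage sketch (illustrative shapes and values, not part of the original module).
# With batch_first=True, inputs and output are all (N, L, E).
if __name__ == "__main__":
    batch_size, sequence_length, embedding_dimension = 2, 8, 64
    attention = TorchMultiHeadAttention(
        embedding_dimension=embedding_dimension,
        number_of_attention_heads=4,
        dropout=0.1,
    )
    x = torch.randn(batch_size, sequence_length, embedding_dimension)
    causal_mask = build_causal_mask(sequence_length)  # helper sketched above
    y = attention(x, x, x, attention_mask=causal_mask)  # self-attention
    print(y.shape)  # torch.Size([2, 8, 64])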