NanoSocrates/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py

import torch
import torch.nn as nn


class TorchMultiHeadAttention(nn.Module):
    """Thin wrapper around nn.MultiheadAttention using the batch_first layout."""

    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embedding_dimension,
            number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        attention_mask=None,
        key_padding_mask=None,
    ) -> torch.Tensor:
        # nn.MultiheadAttention applies the learned projections internally:
        # x_q * Wq -> Q
        # x_k * Wk -> K
        # x_v * Wv -> V
        y, _ = self.attention(
            x_q, x_k, x_v, attn_mask=attention_mask, key_padding_mask=key_padding_mask
        )
        return y
# Note on nn.MultiheadAttention input layouts:
# batch_first=False (historical default)
#   Format: (L, N, E)
#   L = sequence length (time steps / positions)
#   N = batch size
#   E = d_model dimension (embedding size)
# batch_first=True
#   Format: (N, L, E) (more natural for many models)
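
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of how this wrapper might be exercised, assuming
# the batch_first=True layout (N, L, E) documented above. The tensor sizes
# (N=2, L=5, E=16), the number of heads, and the variable names below are
# illustrative assumptions, not values taken from the project.
if __name__ == "__main__":
    attention = TorchMultiHeadAttention(
        embedding_dimension=16,
        number_of_attention_heads=4,
    )

    x = torch.randn(2, 5, 16)  # (N, L, E): batch of 2 sequences, length 5

    # key_padding_mask has shape (N, L); True marks positions to be ignored.
    padding_mask = torch.zeros(2, 5, dtype=torch.bool)
    padding_mask[:, -1] = True  # pretend the last position of each sequence is padding

    # Optional causal mask of shape (L, L); True blocks attention to future positions.
    causal_mask = torch.triu(torch.ones(5, 5), diagonal=1).bool()

    # Self-attention: query, key and value are the same tensor.
    y = attention(x, x, x, attention_mask=causal_mask, key_padding_mask=padding_mask)
    print(y.shape)  # torch.Size([2, 5, 16])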