NanoSocrates/Project_Model/Libs/Transformer/Classes/TorchMultiHeadAttention.py

import torch
import torch.nn as nn
from typing import Optional


class TorchMultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension: int,
        number_of_attention_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(
            embedding_dimension,
            num_heads=number_of_attention_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor,
        x_v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # x * Wq -> Q
        # x * Wk -> K
        # x * Wv -> V
        # REMEMBER: torch's MultiheadAttention broadcasts a 2-D attn_mask of shape
        # (L, S) over the batch (and heads) to build the 3-D attention mask internally,
        # so a single causal mask can be shared by every sequence in the batch.
        y, _ = self.attention(
            x_q, x_k, x_v,
            attn_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return y

# Shape conventions for nn.MultiheadAttention:
# batch_first=False (historical default)
#   Format: (L, N, E)
#   L = sequence length (time steps / positions)
#   N = batch size
#   E = model dimension d_model (embedding size)
# batch_first=True
#   Format: (N, L, E) (more natural for many models)
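

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Shows how the wrapper is called with a 2-D causal attn_mask and a boolean
# key_padding_mask. The dimension values and the dummy padding pattern below are
# assumptions made for this demo, not values taken from the project.
if __name__ == "__main__":
    batch_size, sequence_length, d_model, n_heads = 2, 5, 16, 4

    mha = TorchMultiHeadAttention(d_model, n_heads, dropout=0.0)
    x = torch.randn(batch_size, sequence_length, d_model)  # (N, L, E) because batch_first=True

    # 2-D causal mask of shape (L, S): True marks positions that must NOT be attended.
    # nn.MultiheadAttention broadcasts it over the batch and heads (see comment above).
    causal_mask = torch.triu(
        torch.ones(sequence_length, sequence_length), diagonal=1
    ).bool()

    # Key padding mask of shape (N, S): True marks padded key positions to ignore.
    key_padding_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    key_padding_mask[:, -1] = True  # pretend the last token of every sequence is padding

    y = mha(x, x, x, key_padding_mask=key_padding_mask, attention_mask=causal_mask)
    print(y.shape)  # torch.Size([2, 5, 16]) -> same (N, L, E) shape as the query input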