2025-10-04 21:07:58 +02:00
# multi-head attention -> then into the feed-forward (ff) sublayer
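
# A sketch of that ordering, assuming the standard residual-plus-layernorm
# wiring around each sublayer (an assumption; the note only states the order):
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # hypothetical toy block: self-attention sublayer, then feed-forward sublayer
    def __init__(self, d_model=256, n_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # multi-head attention first...
        x = self.norm1(x + attn_out)       # residual + norm (assumed wiring)
        x = self.norm2(x + self.ff(x))     # ...then the feed-forward sublayer
        return x
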
# attention: q, k, v -> scores = q @ k^T -> divide by sqrt(d_k) -> softmax
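
# The pipeline above written out for one head (a sketch; sizes are illustrative):
import torch

d_k = 64
q = torch.rand(10, d_k)                  # queries, one row per token
k = torch.rand(10, d_k)                  # keys
v = torch.rand(10, d_k)                  # values
scores = q @ k.T                         # score = q . k for every query/key pair
scores = scores / d_k ** 0.5             # divide by sqrt(d_k) to keep logits tame
weights = torch.softmax(scores, dim=-1)  # softmax over the key axis
z = weights @ v                          # weighted sum of the values
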
# multi-head -> Q/K/V are different in each head (built by X @ W_Q, X @ W_K, X @ W_V)
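
# Per-head projections as described: every head owns its own W_Q / W_K / W_V
# (random matrices here purely for illustration):
import torch

d_model, d_head = 256, 32
X = torch.rand(10, d_model)              # 10 token embeddings
W_Q = torch.rand(d_model, d_head)        # this head's query projection
W_K = torch.rand(d_model, d_head)        # this head's key projection
W_V = torch.rand(d_model, d_head)        # this head's value projection
Q, K, V = X @ W_Q, X @ W_K, X @ W_V      # (10, d_head) each, repeated per head
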
# z = softmax(Q @ K^T / sqrt(d_k)) @ V
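
# The formula packaged as a helper, reusing Q, K, V from the sketch above:
def attention(Q, K, V):
    d_k = K.shape[-1]
    return torch.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1) @ V

z = attention(Q, K, V)                   # (10, d_head) output for this head
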
# recombine the heads into Z: 1) concatenate z_0..z_7  2) [z_0 ... z_7] @ W_O = Z
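
# Recombining per the two steps above: concatenate the eight per-head outputs,
# then project with an output matrix W_O (shapes illustrative):
import torch

num_heads, d_head, d_model = 8, 32, 256
zs = [torch.rand(10, d_head) for _ in range(num_heads)]  # stand-ins for z_0..z_7
W_O = torch.rand(num_heads * d_head, d_model)            # learned output projection
Z = torch.cat(zs, dim=-1) @ W_O                          # (10, 256) recombined output
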
# we expect padding tokens later; their positions must be masked out of the scores
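
# Masking sketch: padding positions get -inf scores before the softmax, so they
# end up with zero attention weight (the mask layout here is an assumption):
import torch

scores = torch.rand(10, 10)                    # raw attention scores
pad = torch.zeros(10, dtype=torch.bool)
pad[7:] = True                                 # pretend the last 3 tokens are padding
scores = scores.masked_fill(pad[None, :], float("-inf"))
weights = torch.softmax(scores, dim=-1)        # padded keys now get weight 0
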
########################
# WIP
########################
2025-10-04 21:07:58 +02:00
import torch
import torch.nn as nn
embed_dim = 256  # model width (d_model)
num_heads = 8  # embed_dim must be divisible by num_heads (256 / 8 = 32 per head)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
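
# A minimal smoke test of the module above; the random input is illustrative.
# nn.MultiheadAttention defaults to batch_first=False, so x is (seq_len, batch, embed_dim):
x = torch.rand(10, 2, embed_dim)                      # 10 tokens, batch of 2
attn_output, attn_weights = multihead_attn(x, x, x)   # self-attention: query = key = value
print(attn_output.shape)                              # torch.Size([10, 2, 256])
print(attn_weights.shape)                             # torch.Size([2, 10, 10]), averaged over heads
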
2025-10-05 15:40:29 +02:00
class MultiheadAttention(nn.Module):
    """From-scratch reimplementation of multi-head attention (WIP)."""

    def __init__(
        self,
        num_heads=8,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # TODO: per-head W_Q/W_K/W_V projections, scaled dot-product, W_O recombine