Deleted MultiHeadAttention

This commit is contained in:
Christian Risi 2025-10-07 16:36:11 +02:00
parent 9b5bb6d5f8
commit f9545aca1d

View File

@ -1,24 +0,0 @@
# multi-head attention -> (then to) ff
# attention: qkv -> score = qk -> divide -> softmax
# multihead -> QKV different in each head ( built by : X*[WQ/WK/WV])
# z = softmax(Q*K'/sqrt(d))*V
# recombine Z: 1) concatenate. 2) [z01234] * W = Z
# we expect later to have padding token
########################
# WIP
########################
import torch.nn as nn
embed_dim = 256
num_heads = 8
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
class MultiheadAttention:
def __init__(
self,
num_heads=8,
) -> None:
pass