# multi-head attention -> (then to) feed-forward (ff)
# attention: Q, K, V -> score = Q*K' -> divide by sqrt(d) -> softmax
# multi-head -> Q, K, V are different in each head (built by: X*[WQ/WK/WV])
# z = softmax(Q*K' / sqrt(d)) * V
# recombine Z: 1) concatenate the heads; 2) [z0 z1 z2 z3 z4 ...] * W = Z
# we expect later to have a padding token (-> attention mask)

########################
# WIP
########################
import torch.nn as nn

embed_dim = 256
num_heads = 8
# PyTorch's built-in implementation, kept here as a reference
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)


class MultiheadAttention:
    # WIP: from-scratch multi-head attention following the steps above
    def __init__(
        self,
        num_heads=8,
    ) -> None:
        pass
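
# --- sketch (not the WIP class above) ---------------------------------------
# A minimal sketch of the steps listed in the comments at the top, assuming
# inputs of shape (batch, seq_len, embed_dim) and an optional boolean padding
# mask of shape (batch, seq_len). The class name `SketchMultiheadAttention`
# and the mask handling are illustrative assumptions, not the final WIP API.
import math
from typing import Optional

import torch


class SketchMultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int = 256, num_heads: int = 8) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # X*[WQ/WK/WV]: one projection per role, split into heads afterwards
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.wv = nn.Linear(embed_dim, embed_dim)
        # recombination matrix W applied after concatenating the heads
        self.wo = nn.Linear(embed_dim, embed_dim)

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch, seq_len, embed_dim = x.shape

        # project, then reshape to (batch, num_heads, seq_len, head_dim)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.wq(x))
        k = split_heads(self.wk(x))
        v = split_heads(self.wv(x))

        # score = Q*K' / sqrt(d)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if padding_mask is not None:
            # padding_mask is True where the token is padding -> drop it from the softmax
            scores = scores.masked_fill(padding_mask[:, None, None, :], float("-inf"))
        attn = torch.softmax(scores, dim=-1)

        # z = softmax(Q*K'/sqrt(d)) * V, then concatenate heads and apply W
        z = attn @ v                                   # (batch, heads, seq, head_dim)
        z = z.transpose(1, 2).reshape(batch, seq_len, embed_dim)
        return self.wo(z)


# usage example (shapes only, illustrative)
if __name__ == "__main__":
    x = torch.randn(2, 10, embed_dim)
    sketch = SketchMultiheadAttention(embed_dim, num_heads)
    out = sketch(x)
    print(out.shape)  # torch.Size([2, 10, 256])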