added builder for prefix mask
parent 5e3878ea17
commit f1886e5be1
@@ -15,12 +15,16 @@ def get_causal_attention_mask_with_prefix(seq_len, prefix):
     mask[:,:prefix] = False
     return mask
 
 
-def get_prefix_causal_mask_from_padding_mask(seq_len, prefix_mask):
+def get_prefix_causal_mask_from_padding_mask(seq_len:int, prefix_mask, att_heads:int=1):
+    expanded_padding_mask = prefix_mask.unsqueeze(-1).repeat(1, 1, seq_len) # B,T,T
+    expanded_padding_mask = expanded_padding_mask.permute(0,2,1) # B,T,T
+    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) # T,T
+    tri_batched = mask.unsqueeze(0) # 1,T,T will broadcast over B
+    prefix_causal_mask = expanded_padding_mask & tri_batched
+    prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0) # B*H,T,T
+    return prefix_causal_mask
 
 
 """
 print(get_causal_attention_mask_with_prefix(10,3))
 seq_len = 10
 prefix = 3
 mask = torch.arange(seq_len) >= prefix
 """
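
A minimal usage sketch (not part of the commit), assuming the new get_prefix_causal_mask_from_padding_mask and torch are in scope, and that prefix_mask follows the convention of the commented-out example at the bottom of the file (True for positions that are NOT part of the prefix):

import torch

# Hypothetical example: a batch of two sequences of length 10 with
# per-sample prefix sizes 3 and 5 (these numbers are illustrative).
B, T, H = 2, 10, 4
prefix_lengths = torch.tensor([3, 5])
# True where a position is not in the prefix, matching
# mask = torch.arange(seq_len) >= prefix from the example above.
prefix_mask = torch.arange(T).unsqueeze(0) >= prefix_lengths.unsqueeze(1)  # B,T

attn_mask = get_prefix_causal_mask_from_padding_mask(T, prefix_mask, att_heads=H)
print(attn_mask.shape)  # torch.Size([8, 10, 10]) == (B*H, T, T)

# True entries are masked out: every query can attend to all prefix
# positions, and to non-prefix positions only causally (at or before it).
# A (B*H, T, T) bool mask in this True-means-masked convention matches the
# 3-D attn_mask accepted by torch.nn.MultiheadAttention.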
|
|||||||
Loading…
x
Reference in New Issue
Block a user