WIP for the new prefix mask
This commit is contained in:
parent 92ae40013d
commit 82462078f8
@@ -8,4 +8,19 @@ def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
 def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
     base_mask = get_causal_attention_mask(seq_len)
     return base_mask.unsqueeze(0).expand(batch_size, -1, -1)  # add a leading dimension of size batch_size
     # the result has shape (batch_size, seq_len, seq_len): the (x, y) mask repeated along the new batch axis z
+
+def get_causal_attention_mask_with_prefix(seq_len, prefix):
+    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
+    mask[:, :prefix] = False
+    return mask
+
+def get_prefix_causal_mask_from_padding_mask(seq_len, prefix_mask):
+
+
+    """
+    print(get_causal_attention_mask_with_prefix(10, 3))
+    seq_len = 10
+    prefix = 3
+    mask = torch.arange(seq_len) >= prefix
+    """
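A small worked example of the new prefix mask, using seq_len=5 and prefix=2 rather than the 10/3 case printed in the WIP docstring. True means blocked; the first two columns stay False, so every position can attend to the prefix while attention over the rest remains causal.

print(get_causal_attention_mask_with_prefix(5, 2))
# tensor([[False, False,  True,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])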
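get_prefix_causal_mask_from_padding_mask is left as an empty stub in this WIP commit; its docstring only records the scratch idea torch.arange(seq_len) >= prefix. A possible completion is sketched below, assuming prefix_mask is a (batch_size, seq_len) boolean tensor with True on prefix positions; that interpretation and the broadcasting are my assumptions, not part of the commit.

def get_prefix_causal_mask_from_padding_mask(seq_len, prefix_mask):
    # assumed semantics: prefix_mask is (batch_size, seq_len), True where a token belongs
    # to the prefix; result is (batch_size, seq_len, seq_len) with True = blocked,
    # matching the convention of get_causal_attention_mask_with_prefix above
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # ~prefix_mask plays the role of the docstring's torch.arange(seq_len) >= prefix:
    # a position stays blocked only if it is in the future AND not part of the prefix
    return causal.unsqueeze(0) & ~prefix_mask.unsqueeze(1)

# e.g. per-example prefix lengths [3, 1] for two sequences of length 5:
prefix_mask = torch.arange(5).unsqueeze(0) < torch.tensor([[3], [1]])
masks = get_prefix_causal_mask_from_padding_mask(5, prefix_mask)  # shape (2, 5, 5)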