29 lines
1.4 KiB
Python
Raw Normal View History

2025-10-05 17:49:01 +02:00
import torch
2025-10-06 13:03:03 +02:00
def get_causal_attention_mask(seq_len: int, device=None) -> torch.Tensor:
    """Return a boolean causal (look-ahead) attention mask.

    Entry ``[i, j]`` is True when ``j > i``. The strictly upper triangle
    above the main diagonal is True and everything else is False. Under the
    boolean ``attn_mask`` convention of ``torch.nn.MultiheadAttention``,
    True marks a position that may NOT be attended to. Each query therefore
    sees only itself and earlier positions.

    Args:
        seq_len: Number of positions; the mask has shape ``(seq_len, seq_len)``.
        device: Optional device to allocate the mask on. The default ``None``
            keeps the previous behaviour (PyTorch's default device). Pass the
            device of the attention inputs to avoid a host-to-device copy or
            a device-mismatch error.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)``.
    """
    return torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
        diagonal=1,
    )
2025-10-06 13:03:03 +02:00
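# Usually unnecessary: torch.nn.MultiheadAttention broadcasts a 2-D (L, S) attn_mask across the batch by itself. (A 3-D mask passed to it must have shape (batch_size * num_heads, L, S).)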
def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
    """Return the causal mask of get_causal_attention_mask tiled over a batch.

    The result has shape ``(batch_size, seq_len, seq_len)``. Every slice
    ``[b]`` is the same ``(seq_len, seq_len)`` causal mask.

    The batch axis is produced with ``expand``, so the result is a
    non-contiguous view sharing one mask's storage. Call ``.clone()`` or
    ``.contiguous()`` before modifying it in place.

    NOTE(review): nn.MultiheadAttention expects a 3-D mask of shape
    (batch_size * num_heads, L, S). This shape only matches that when
    num_heads == 1.
    """
    single = get_causal_attention_mask(seq_len)
    # [None] adds a leading size-1 axis; expand broadcasts it to batch_size
    # without copying any data.
    return single[None].expand(batch_size, seq_len, seq_len)
2025-10-11 11:28:15 +02:00
# get_causal_attention_mask_batched above returns shape (z, x, y): the same (x, y) mask repeated along z.
def get_causal_attention_mask_with_prefix(seq_len: int, prefix: int, device=None) -> torch.Tensor:
    """Return a boolean causal mask whose first ``prefix`` key positions are always visible.

    Starts from the causal mask (True where ``j > i``, meaning "may not
    attend") and clears columns ``0 .. prefix-1`` to False. Every query may
    attend to the whole prefix, including prefix positions that come after
    it. All remaining positions stay causal.

    Args:
        seq_len: Number of positions; the mask has shape ``(seq_len, seq_len)``.
        prefix: Number of leading key positions visible to every query.
            Values >= ``seq_len`` unmask every column; 0 yields the plain
            causal mask.
        device: Optional device to allocate the mask on (default: PyTorch's
            default device, as before).

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)``.

    Raises:
        ValueError: If ``prefix`` is negative. Without this check, slicing
            would read ``[:, :prefix]`` as "all but the last |prefix|
            columns" and silently unmask almost the entire future.
    """
    if prefix < 0:
        raise ValueError(f"prefix must be non-negative, got {prefix}")
    mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
        diagonal=1,
    )
    # Make the prefix columns visible to every query position.
    mask[:, :prefix] = False
    return mask
2025-10-11 15:19:09 +02:00
def get_prefix_causal_mask_from_padding_mask(seq_len: int, prefix_mask, att_heads: int = 1):
    """Build a per-sample boolean attention mask from a per-key-position mask.

    For every sample ``b``, query ``i`` and key ``j`` the result is::

        out[b, i, j] = prefix_mask[b, j] and (j > i)

    A key position is blocked (True) only if it lies in the query's future
    AND ``prefix_mask`` is True there. Key positions where ``prefix_mask``
    is False are visible to every query regardless of order.

    Args:
        seq_len: Sequence length T; should equal ``prefix_mask.shape[1]``.
        prefix_mask: Boolean tensor, presumably of shape ``(B, T)`` (TODO
            confirm with callers). NOTE(review): by the formula above it
            must be False on the bidirectional prefix positions and True on
            positions that should stay causal.
        att_heads: Number of attention heads H. Each sample's mask is
            repeated H times consecutively along dim 0.

    Returns:
        A tensor of shape ``(B * H, T, T)`` on ``prefix_mask``'s device,
        ordered batch-major (all heads of sample 0, then sample 1, ...).
    """
    # Causal (look-ahead) triangle, True where j > i. Allocate it on
    # prefix_mask's device so the `&` below does not fail with a device
    # mismatch when prefix_mask lives on a GPU.
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=prefix_mask.device),
        diagonal=1,
    )  # (T, T)
    # (B, 1, T) & (1, T, T) broadcasts to (B, T, T) with
    # out[b, i, j] = prefix_mask[b, j] & causal[i, j]; no repeat/permute copy needed.
    prefix_causal_mask = prefix_mask.unsqueeze(1) & causal.unsqueeze(0)
    return prefix_causal_mask.repeat_interleave(att_heads, dim=0)  # (B*H, T, T)
2025-10-11 11:28:15 +02:00