import torch


def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
    # Boolean (seq_len, seq_len) mask, True above the diagonal (blocked keys).
    # Often unnecessary: torch.nn.MultiheadAttention can apply the causal mask
    # itself (e.g. via is_causal=True).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)


def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
    base_mask = get_causal_attention_mask(seq_len)
    # Add a leading batch dimension of size batch_size; the (seq_len, seq_len)
    # mask is repeated along that new axis as a view, without copying memory.
    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)


def get_causal_attention_mask_with_prefix(seq_len: int, prefix: int) -> torch.Tensor:
    # Causal mask whose first `prefix` key columns are fully visible (prefix-LM style).
    mask = get_causal_attention_mask(seq_len)
    mask[:, :prefix] = False
    return mask


def get_prefix_causal_mask_from_padding_mask(seq_len: int, prefix_mask: torch.Tensor) -> torch.Tensor:
    """Prefix-causal mask built from a boolean position mask.

    `prefix_mask` is True for positions that are NOT part of the prefix, e.g.

        print(get_causal_attention_mask_with_prefix(10, 3))
        seq_len = 10
        prefix = 3
        mask = torch.arange(seq_len) >= prefix

    NOTE: the body below reconstructs the intent of this scratch-work; the
    original function was left unimplemented.
    """
    causal = get_causal_attention_mask(seq_len)
    # Broadcast prefix_mask over the query (row) dimension: a False entry
    # (prefix position) clears the causal restriction for that key column,
    # so a 1D mask reproduces get_causal_attention_mask_with_prefix, and a
    # (batch_size, seq_len) mask yields a (batch_size, seq_len, seq_len) result.
    return causal & prefix_mask.unsqueeze(-2)
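

# Minimal usage sketch (not part of the original file): a quick visual check of
# the masks above, assuming the reconstructed body of
# get_prefix_causal_mask_from_padding_mask. True marks blocked positions.
if __name__ == "__main__":
    print(get_causal_attention_mask(5))
    print(get_causal_attention_mask_batched(5, batch_size=2).shape)  # torch.Size([2, 5, 5])
    print(get_causal_attention_mask_with_prefix(5, prefix=2))
    prefix_mask = torch.arange(5) >= 2  # True for non-prefix positions
    print(get_prefix_causal_mask_from_padding_mask(5, prefix_mask))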