import torch


def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
    # Boolean (seq_len, seq_len) mask: True above the main diagonal marks the future
    # positions a query must not attend to (the convention used by the boolean
    # attn_mask of torch.nn.MultiheadAttention).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
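
# For illustration, with seq_len = 3 the strictly upper-triangular entries are True,
# i.e. each position is blocked from attending to positions that come after it:
#
#   get_causal_attention_mask(3)
#   tensor([[False,  True,  True],
#           [False, False,  True],
#           [False, False, False]])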


# Note: the batched variant below is usually unnecessary; torch.nn.MultiheadAttention
# accepts a 2D (seq_len, seq_len) attn_mask and broadcasts it across the batch
# internally.
def get_causal_attention_mask_batched(seq_len: int, batch_size: int) -> torch.Tensor:
    base_mask = get_causal_attention_mask(seq_len)
    # unsqueeze(0) adds a leading batch dimension; expand repeats the 2D mask
    # batch_size times as a view (no data is copied).
    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)
# The result has shape (batch_size, seq_len, seq_len): the same 2D causal mask
# repeated along the new leading batch dimension.
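

# A minimal usage sketch (an addition, not part of the original module). It assumes
# arbitrary toy sizes (embed_dim=16, num_heads=4, batch_first=True) and shows that
# nn.MultiheadAttention broadcasts the 2D mask over the whole batch, so the batched
# helper above is rarely needed in practice.
if __name__ == "__main__":
    batch_size, seq_len, embed_dim = 2, 5, 16
    mha = torch.nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
    x = torch.randn(batch_size, seq_len, embed_dim)

    # True entries mark positions a query is NOT allowed to attend to.
    mask = get_causal_attention_mask(seq_len)
    out, weights = mha(x, x, x, attn_mask=mask)

    print(out.shape)      # torch.Size([2, 5, 16])
    print(weights.shape)  # torch.Size([2, 5, 5])
    # All attention weights above the diagonal are zero, i.e. no token attends
    # to a future position.
    print(bool((weights.triu(diagonal=1) == 0).all()))  # True

    # The batched helper just repeats the same 2D mask once per sample.
    batched = get_causal_attention_mask_batched(seq_len, batch_size)
    print(batched.shape)  # torch.Size([2, 5, 5])

    # In PyTorch 2.0+ one can often skip building the mask entirely, e.g.
    # torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True);
    # note that SDPA's boolean attn_mask uses the opposite convention
    # (True means "may attend").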