2025-10-06 12:00:11 +02:00

Python

import torch

def get_causal_attention_mask(sequence_length: int) -> torch.Tensor:
    # Upper-triangular boolean mask: True marks the strictly-future positions a query must not attend to.
    return torch.triu(torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1)

def get_causal_attention_mask_batched(sequence_length: int, batch_size: int) -> torch.Tensor:
    base_mask = get_causal_attention_mask(sequence_length)
    # Add a leading batch dimension and expand it to batch_size; expand() returns a view,
    # so the same (seq, seq) mask is repeated along the batch axis without copying memory.
    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)
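
A short usage sketch (hypothetical sizes, not part of the original paste) to confirm the output shapes and that True marks the masked, future positions:

mask = get_causal_attention_mask(4)
print(mask.shape)     # torch.Size([4, 4])
print(mask[0])        # tensor([False,  True,  True,  True]) -> position 0 may only see itself

batched = get_causal_attention_mask_batched(4, batch_size=2)
print(batched.shape)  # torch.Size([2, 4, 4])
print(torch.equal(batched[0], mask), torch.equal(batched[1], mask))  # True True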