4 lines
185 B
Python
Raw Normal View History

2025-10-05 17:49:01 +02:00
import torch
def get_attention_mask(
    embedding_dimension: int,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """Return a square boolean causal (look-ahead) mask.

    Entry ``[i, j]`` is ``True`` exactly when ``j > i``. Every position
    strictly above the main diagonal is ``True``; the diagonal and everything
    below it are ``False``.

    ``True`` marks positions to *block*. This matches the boolean
    ``attn_mask`` convention of ``torch.nn.MultiheadAttention``.
    ``torch.nn.functional.scaled_dot_product_attention`` uses the opposite
    convention (``True`` = may attend), so pass ``~mask`` there.

    Args:
        embedding_dimension: Side length of the square mask. Despite the
            name, this is the number of positions being attended over
            (presumably the sequence length), not a model/embedding width.
            The name is kept so existing keyword callers keep working.
        device: Device to allocate the mask on. ``None`` (the default)
            preserves the original behaviour of using PyTorch's current
            default device. Passing the device of the attention scores
            avoids building the mask on one device and copying it to another.

    Returns:
        A ``(embedding_dimension, embedding_dimension)`` tensor of dtype
        ``torch.bool``.
    """
    ones = torch.ones(
        embedding_dimension,
        embedding_dimension,
        dtype=torch.bool,
        device=device,
    )
    # diagonal=1 keeps only the entries strictly above the main diagonal.
    # The diagonal itself stays False, so under the True-means-blocked
    # convention each position can still attend to itself and to earlier
    # positions.
    return torch.triu(ones, diagonal=1)