4 lines
185 B
Python
4 lines
185 B
Python
import torch
|
|
|
|
def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
|
|
return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1) |