import torch
def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
    """Build a square causal (look-ahead) attention mask.

    Positions strictly above the main diagonal are ``True`` (i.e. future
    positions to be masked out); the diagonal and everything below it are
    ``False``.

    NOTE(review): the argument sizes the square mask, so it presumably is
    the sequence length rather than an embedding dimension — confirm with
    callers before renaming.

    Args:
        embedding_dimension: Side length of the square boolean mask.

    Returns:
        A ``(embedding_dimension, embedding_dimension)`` ``torch.bool``
        tensor, ``True`` strictly above the diagonal.
    """
    blocked = torch.full(
        (embedding_dimension, embedding_dimension), True, dtype=torch.bool
    )
    # Keep only the strict upper triangle (diagonal=1 excludes the diagonal).
    return blocked.triu(diagonal=1)