Added attention_mask
Project_Model/Libs/Transformer/Utils/attention_mask.py | 4 ++++ (new file)
@@ -0,0 +1,4 @@
+import torch
+
+def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
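
The function builds a boolean causal mask: torch.triu(..., diagonal=1) leaves True strictly above the diagonal, marking the future positions each token must not attend to. One caveat worth noting: despite its name, embedding_dimension sets the side length of the square mask, so callers have to pass the sequence length (attention scores are seq_len x seq_len), not the model's embedding width.

Below is a minimal sketch of how such a mask is typically consumed in scaled dot-product attention. The tensor shapes, variable names, and import path are illustrative assumptions, not part of the commit:

    import torch

    # Assumed import path, matching the file added in this commit.
    from Project_Model.Libs.Transformer.Utils.attention_mask import get_attention_mask

    # Hypothetical shapes: batch of 2 sequences, 5 tokens, head width 8.
    q = torch.randn(2, 5, 8)
    k = torch.randn(2, 5, 8)
    v = torch.randn(2, 5, 8)

    mask = get_attention_mask(5)  # (5, 5); True = future position

    scores = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5  # (2, 5, 5)
    scores = scores.masked_fill(mask, float("-inf"))      # block the future
    out = scores.softmax(dim=-1) @ v                      # (2, 5, 8)

Filling masked scores with -inf before the softmax drives the corresponding attention weights to exactly zero, which is the standard way a boolean causal mask is applied in PyTorch.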