Added attention_mask

This commit is contained in:
GassiGiuseppe
2025-10-05 17:49:01 +02:00
parent b303affd18
commit 6f219f634f
5 changed files with 20 additions and 5 deletions

View File

@@ -0,0 +1,3 @@
from .attention_mask import get_attention_mask
__all__ = ["get_attention_mask"]

View File

@@ -0,0 +1,4 @@
import torch
def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)