added method for batched attention_mask
This commit is contained in:
parent 7e40a36701
commit 87409fecd5
__init__.py
@@ -1,4 +1,4 @@
-from .attention_mask import get_attention_mask
+from .attention_mask import get_causal_attention_mask, get_causal_attention_mask_batched
 from .task_type import TaskType
 
-__all__ = ["get_attention_mask", "TaskType"]
+__all__ = ["get_causal_attention_mask", "TaskType", "get_causal_attention_mask_batched"]
attention_mask.py
@@ -1,4 +1,9 @@
 import torch
 
-def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
-    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
+def get_causal_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
+
+
+def get_causal_attention_mask_batched(embedding_dimension: int, batch_size: int) -> torch.Tensor:
+    base_mask = get_causal_attention_mask(embedding_dimension)
+    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)  # shape (batch_size, n, n): the (n, n) mask repeated along the new leading batch dimension
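For reference, a minimal usage sketch of the new helper. This is not part of the commit: the two functions are copied inline because the diff does not show the package path, and the `masked_fill` call is an assumed downstream use, not something this commit introduces.

import torch

def get_causal_attention_mask(embedding_dimension: int) -> torch.Tensor:
    # True above the diagonal marks future positions that must be masked out
    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)

def get_causal_attention_mask_batched(embedding_dimension: int, batch_size: int) -> torch.Tensor:
    base_mask = get_causal_attention_mask(embedding_dimension)
    # expand() returns a broadcast view: the (n, n) mask is shared, not copied per batch element
    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)

mask = get_causal_attention_mask_batched(4, batch_size=2)
print(mask.shape)   # torch.Size([2, 4, 4])
print(mask[0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# typical use: block attention to future positions before the softmax
scores = torch.randn(2, 4, 4)
scores = scores.masked_fill(mask, float("-inf"))

Because `expand` only creates a broadcast view, the batched variant costs no extra memory over the single mask, whereas `repeat` would materialize `batch_size` copies. Note also that `embedding_dimension` is used here as the side length of the mask, i.e. what is usually the sequence length in an attention setup.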