diff --git a/Project_Model/Libs/Transformer/Utils/__init__.py b/Project_Model/Libs/Transformer/Utils/__init__.py
index 2831ec4..d4dfba3 100644
--- a/Project_Model/Libs/Transformer/Utils/__init__.py
+++ b/Project_Model/Libs/Transformer/Utils/__init__.py
@@ -1,4 +1,4 @@
-from .attention_mask import get_attention_mask
+from .attention_mask import get_causal_attention_mask, get_causal_attention_mask_batched
 from .task_type import TaskType
 
-__all__ = ["get_attention_mask", "TaskType"]
\ No newline at end of file
+__all__ = ["get_causal_attention_mask", "get_causal_attention_mask_batched", "TaskType"]
\ No newline at end of file
diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
index a6c595e..cb0ddcf 100644
--- a/Project_Model/Libs/Transformer/Utils/attention_mask.py
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -1,4 +1,9 @@
 import torch
 
-def get_attention_mask(embedding_dimension: int) -> torch.Tensor:
-    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
\ No newline at end of file
+def get_causal_attention_mask(embedding_dimension: int) -> torch.Tensor:
+    return torch.triu(torch.ones(embedding_dimension, embedding_dimension, dtype=torch.bool), diagonal=1)
+
+def get_causal_attention_mask_batched(embedding_dimension: int, batch_size: int) -> torch.Tensor:
+    base_mask = get_causal_attention_mask(embedding_dimension)
+    # Prepend a batch dimension and expand it to batch_size: the same (seq, seq) causal mask is shared across the batch.
+    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)
\ No newline at end of file
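
For reference, a minimal usage sketch (not part of the patch) of the renamed helpers exported from `Project_Model.Libs.Transformer.Utils`. It checks the shapes the two functions return and shows one common way a "True = masked" boolean mask is applied to raw attention scores via `masked_fill`; the `scores` tensor and the softmax step are illustrative assumptions, not taken from the project.

```python
import torch

from Project_Model.Libs.Transformer.Utils import (
    get_causal_attention_mask,
    get_causal_attention_mask_batched,
)

seq_len, batch_size = 4, 2

mask = get_causal_attention_mask(seq_len)                         # (seq_len, seq_len), True above the diagonal
batched = get_causal_attention_mask_batched(seq_len, batch_size)  # (batch_size, seq_len, seq_len)

assert mask.shape == (seq_len, seq_len)
assert batched.shape == (batch_size, seq_len, seq_len)
assert torch.equal(batched[0], mask)  # every batch element carries the same causal mask

# Illustrative only: mask out future positions in some attention scores, then normalize.
scores = torch.randn(batch_size, seq_len, seq_len)
scores = scores.masked_fill(batched, float("-inf"))
weights = torch.softmax(scores, dim=-1)  # each row attends only to itself and earlier positions
```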