From 82462078f8412285518ad19d626f41e31a7520bb Mon Sep 17 00:00:00 2001
From: GassiGiuseppe
Date: Sat, 11 Oct 2025 11:28:15 +0200
Subject: [PATCH] WIP for the new prefix mask

---
 .../Libs/Transformer/Utils/attention_mask.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
index b1e97f3..76455ca 100644
--- a/Project_Model/Libs/Transformer/Utils/attention_mask.py
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -8,4 +8,17 @@ def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
 def get_causal_attention_mask_batched(seq_len: int, batch_size: int ) -> torch.Tensor:
     base_mask = get_causal_attention_mask(seq_len)
     return base_mask.unsqueeze(0).expand(batch_size, -1, -1) # add another dimension at the beginning, big as batch_size
-    # the result is that z,x,y where x,y are repeated along z
\ No newline at end of file
+    # the result has shape (batch_size, seq_len, seq_len): the 2-D mask repeated along the new batch dim
+
+def get_causal_attention_mask_with_prefix(seq_len: int, prefix: int) -> torch.Tensor:
+    # Causal mask (True = masked) whose first `prefix` columns are left
+    # fully visible: every position may attend to the prefix tokens.
+    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
+    mask[:, :prefix] = False
+    return mask
+
+def get_prefix_causal_mask_from_padding_mask(seq_len: int, prefix_mask: torch.Tensor) -> torch.Tensor:
+    # TODO (WIP): build the per-sample prefix-causal mask from a boolean
+    # mask over positions, e.g. torch.arange(seq_len) >= prefix, instead of
+    # a single fixed prefix length as in get_causal_attention_mask_with_prefix.
+    raise NotImplementedError
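
Note: a possible completion of the WIP helper, as a minimal sketch only. It assumes prefix_mask is a (batch, seq_len) boolean tensor that is True on prefix positions (the patch does not fix this convention; its scratch notes use the opposite polarity, torch.arange(seq_len) >= prefix) and keeps the convention above where True means "masked out". The name prefix_causal_mask_from_padding_mask_sketch is hypothetical.

import torch

# Assumes the patch above has been applied to this module:
from Project_Model.Libs.Transformer.Utils.attention_mask import (
    get_causal_attention_mask_with_prefix,
)

def prefix_causal_mask_from_padding_mask_sketch(seq_len: int, prefix_mask: torch.Tensor) -> torch.Tensor:
    # prefix_mask: (batch, seq_len) bool, assumed True on prefix positions.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # Entry (i, j) stays masked only where it is causally masked AND column j
    # is outside that sample's prefix; prefix columns become visible to all rows.
    return causal.unsqueeze(0) & ~prefix_mask.unsqueeze(1)  # (batch, seq_len, seq_len)

# Consistency check against the fixed-prefix version added by the patch,
# mirroring the commented-out scratch test (seq_len=10, prefix=3):
seq_len, prefix, batch_size = 10, 3, 2
prefix_mask = (torch.arange(seq_len) < prefix).unsqueeze(0).expand(batch_size, -1)
per_sample = prefix_causal_mask_from_padding_mask_sketch(seq_len, prefix_mask)
assert torch.equal(per_sample[0], get_causal_attention_mask_with_prefix(seq_len, prefix))

Because the boolean AND broadcasts instead of writing in place, this variant also works when the causal mask is an expand()-ed view, which would reject the in-place column assignment used in get_causal_attention_mask_with_prefix.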