diff --git a/Project_Model/Libs/Transformer/Utils/attention_mask.py b/Project_Model/Libs/Transformer/Utils/attention_mask.py
index 76455ca..34a6f86 100644
--- a/Project_Model/Libs/Transformer/Utils/attention_mask.py
+++ b/Project_Model/Libs/Transformer/Utils/attention_mask.py
@@ -15,12 +15,16 @@ def get_causal_attention_mask_with_prefix(seq_len, prefix):
     mask[:,:prefix] = False
     return mask
 
-def get_prefix_causal_mask_from_padding_mask(seq_len, prefix_mask):
+def get_prefix_causal_mask_from_padding_mask(seq_len: int, prefix_mask, att_heads: int = 1):
+    expanded_padding_mask = prefix_mask.unsqueeze(-1).repeat(1, 1, seq_len)  # B,T,T
+    expanded_padding_mask = expanded_padding_mask.permute(0, 2, 1)  # B,T,T
+    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # T,T
+    tri_batched = mask.unsqueeze(0)  # 1,T,T, will broadcast over B
+    prefix_causal_mask = expanded_padding_mask & tri_batched
+    prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0)  # B*H,T,T
+    return prefix_causal_mask
 
-""" print(get_causal_attention_mask_with_prefix(10,3))
-seq_len = 10
-prefix = 3
-mask = torch.arange(seq_len) >= prefix
-"""
\ No newline at end of file
+
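
For context, a minimal usage sketch of the new helper (not part of the diff). It assumes the module is importable under the path shown above, that prefix_mask follows the convention of the removed scratch code (True for positions at or past the prefix), and that the boolean mask is fed to torch.nn.MultiheadAttention, which treats True as "not allowed to attend". Names such as batch, heads, and mha are illustrative only.

    import torch
    import torch.nn as nn

    from Project_Model.Libs.Transformer.Utils.attention_mask import get_prefix_causal_mask_from_padding_mask

    batch, seq_len, d_model, heads, prefix = 2, 10, 32, 4, 3

    # True for every position at or past the prefix, per batch element (assumed convention).
    prefix_mask = (torch.arange(seq_len) >= prefix).unsqueeze(0).repeat(batch, 1)  # B,T

    # B*H,T,T boolean mask: prefix columns stay attendable, everything else is causal.
    attn_mask = get_prefix_causal_mask_from_padding_mask(seq_len, prefix_mask, att_heads=heads)

    mha = nn.MultiheadAttention(d_model, heads, batch_first=True)
    x = torch.randn(batch, seq_len, d_model)
    out, _ = mha(x, x, x, attn_mask=attn_mask)  # out: B,T,d_model

The repeat_interleave over dim=0 matches MultiheadAttention's expected (batch * num_heads, L, S) layout, where the heads of each batch element are stored consecutively.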