Update the batcher to resolve a bug in the 4th construction
@@ -20,7 +20,6 @@ class Decoder(nn.Module):
        super().__init__()

        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
@@ -58,7 +57,7 @@ class Decoder(nn.Module):
        # build the attention mask
        # TODO: create a prefix causal mask if needed
        if decoder_only:
-            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), src_padding_mask, self.__attention_heads)  # the correct one is tgt, however ...
        else:
            attention_mask = get_causal_attention_mask(x.size(1))
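get_causal_attention_mask is called in the hunk above but its body is not part of this commit. For context, a minimal sketch of the usual construction, assuming a boolean mask in which True marks key positions a query is not allowed to attend to:

import torch

def get_causal_attention_mask(seq_len: int) -> torch.Tensor:
    # Boolean (T, T) mask, True strictly above the diagonal: query i is
    # blocked from attending to any key j > i (standard autoregressive masking).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

Whether the project's MultiHeadAttention expects this boolean convention or an additive float mask is not visible in the diff, so treat the dtype as an assumption.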
@@ -24,5 +24,7 @@ def get_prefix_causal_mask_from_padding_mask(seq_len:int, prefix_mask, att_heads
    prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0)  # B*H, T, T
    return prefix_causal_mask

+#def get_prefix_causal_mask():
+#    continue_rdf =
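Only the tail of get_prefix_causal_mask_from_padding_mask appears in this hunk. A sketch of one common way the full helper could look, assuming prefix_mask is a (B, T) boolean tensor that is True on the prefix tokens every position may attend to bidirectionally; the last two lines mirror the ones shown above:

import torch

def get_prefix_causal_mask_from_padding_mask(seq_len: int, prefix_mask: torch.Tensor,
                                              att_heads: int) -> torch.Tensor:
    # Assumption: prefix_mask is (B, T) boolean, True on prefix tokens that any
    # position may attend to; True in the returned mask means "blocked",
    # matching the causal sketch above.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # (T, T)
    visible_prefix = prefix_mask.unsqueeze(1).expand(-1, seq_len, -1)                # (B, T, T), prefix keys visible to all queries
    prefix_causal_mask = causal.unsqueeze(0) & ~visible_prefix                       # block future keys unless they belong to the prefix
    prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0)      # (B*H, T, T), as in the hunk above
    return prefix_causal_mask

The repeat_interleave over the batch dimension produces the (B*H, T, T) layout that per-sample attention masks typically use; how the repository actually derives the prefix region from its padding masks is not shown here, so this is illustrative only.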