Update of the batcher to resolve a bug in the 4th construction
This commit is contained in:
@@ -24,5 +24,7 @@ def get_prefix_causal_mask_from_padding_mask(seq_len:int, prefix_mask, att_heads
|
||||
prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0) # B*H,T,T
|
||||
return prefix_causal_mask
|
||||
|
||||
#def get_prefix_causal_mask():
|
||||
# continue_rdf =
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user