Update of the batcher to resolve a bug in the 4th construction

This commit is contained in:
GassiGiuseppe
2025-10-12 16:35:42 +02:00
parent 37a2501a79
commit 144f8724d6
3 changed files with 46 additions and 4 deletions

View File

@@ -24,5 +24,7 @@ def get_prefix_causal_mask_from_padding_mask(seq_len:int, prefix_mask, att_heads
prefix_causal_mask = prefix_causal_mask.repeat_interleave(att_heads, dim=0) # B*H,T,T
return prefix_causal_mask
#def get_prefix_causal_mask():
# continue_rdf =