Added util to create padding mask

This commit is contained in:
Christian Risi 2025-10-06 17:29:05 +02:00
parent ffdb312d58
commit 44307cd917
2 changed files with 18 additions and 3 deletions

View File

@ -1,6 +1,6 @@
from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_batched from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_batched
from .task_type import TaskType from .task_type import TaskType
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence, create_padding_mask
from .inference_masking import inference_masking from .inference_masking import inference_masking
from .truncate_rdf_list import truncate_rdf_list from .truncate_rdf_list import truncate_rdf_list
@ -10,6 +10,7 @@ __all__ = [
"get_causal_attention_mask_batched", "get_causal_attention_mask_batched",
"truncate_sequence", "truncate_sequence",
"pad_sequence", "pad_sequence",
"create_padding_mask",
"normalize_sequence", "normalize_sequence",
"inference_masking", "inference_masking",
"truncate_rdf_list" "truncate_rdf_list"

View File

@ -25,15 +25,29 @@ def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[in
return PADDED_SEQUENCE return PADDED_SEQUENCE
def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
    """Return a boolean mask over *sequence* marking padding positions.

    Each element is True where the corresponding token equals *pad_token*
    and False for real (non-padding) tokens. The mask has the same length
    as *sequence*; an empty input yields an empty mask.
    """
    return [token == pad_token for token in sequence]
def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
) -> tuple[list[int], list[bool]]:
    """Normalize *sequence* to *max_length* and return it with its padding mask.

    The sequence is first padded with *pad_token* up to *max_length*, then
    truncated to *max_length* (terminating with *end_token* — see
    ``truncate_sequence``), and finally a boolean mask is derived where True
    marks padding positions of the returned sequence.

    Returns:
        A ``(normalized_sequence, padding_mask)`` tuple; both have the same
        length so the mask lines up element-wise with the returned sequence.
    """
    new_sequence = pad_sequence(sequence, max_length, pad_token)
    # BUG FIX: the original passed the raw `sequence` here, silently
    # discarding the padding applied on the previous line.
    new_sequence = truncate_sequence(new_sequence, max_length, end_token)
    # BUG FIX: the mask must describe the sequence being returned, not the
    # raw input — otherwise its length/content diverge from `new_sequence`
    # whenever padding or truncation actually happened.
    padding_mask = create_padding_mask(new_sequence, pad_token)
    return (new_sequence, padding_mask)