From 44307cd9177fd27879aa70b33264a673e72e9159 Mon Sep 17 00:00:00 2001
From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com>
Date: Mon, 6 Oct 2025 17:29:05 +0200
Subject: [PATCH] Added util to create padding mask

Compute the padding mask from the NORMALIZED sequence (after padding and
truncation), not from the raw input: otherwise the mask's length can
differ from the returned sequence's length and no pad positions are
marked. Also chain the truncation onto the padded sequence instead of
the raw input, so the result of pad_sequence is no longer discarded.
---
 .../Libs/Transformer/Utils/__init__.py        |  3 ++-
 .../Transformer/Utils/post_tokenization.py    | 20 +++++++++++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/Project_Model/Libs/Transformer/Utils/__init__.py b/Project_Model/Libs/Transformer/Utils/__init__.py
index b12ee89..0c44c83 100644
--- a/Project_Model/Libs/Transformer/Utils/__init__.py
+++ b/Project_Model/Libs/Transformer/Utils/__init__.py
@@ -1,6 +1,6 @@
 from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_batched
 from .task_type import TaskType
-from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
+from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence, create_padding_mask
 from .inference_masking import inference_masking
 from .truncate_rdf_list import truncate_rdf_list
 
@@ -10,6 +10,7 @@ __all__ = [
     "get_causal_attention_mask_batched",
     "truncate_sequence",
     "pad_sequence",
+    "create_padding_mask",
     "normalize_sequence",
     "inference_masking",
     "truncate_rdf_list"
diff --git a/Project_Model/Libs/Transformer/Utils/post_tokenization.py b/Project_Model/Libs/Transformer/Utils/post_tokenization.py
index 5b4ae89..fc68363 100644
--- a/Project_Model/Libs/Transformer/Utils/post_tokenization.py
+++ b/Project_Model/Libs/Transformer/Utils/post_tokenization.py
@@ -25,15 +25,29 @@ def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[in
 
     return PADDED_SEQUENCE
 
 
+def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
+
+    PADDING_MASK = [False] * len(sequence)
+
+    for i in range(0, len(sequence)):
+
+        if sequence[i] != pad_token:
+            continue
+
+        PADDING_MASK[i] = True
+
+    return PADDING_MASK
+
 def normalize_sequence(
     sequence: list[int],
     max_length: int,
     pad_token: int,
     end_token: int,
-) -> list[int]:
+) -> tuple[list[int], list[bool]]:
 
     new_sequence = pad_sequence(sequence, max_length, pad_token)
-    new_sequence = truncate_sequence(sequence, max_length, end_token)
+    new_sequence = truncate_sequence(new_sequence, max_length, end_token)
+    PADDING_MASK = create_padding_mask(new_sequence, pad_token)
 
-    return new_sequence
+    return (new_sequence, PADDING_MASK)