Added util to create padding mask

This commit is contained in:
Christian Risi
2025-10-06 17:29:05 +02:00
parent ffdb312d58
commit 44307cd917
2 changed files with 18 additions and 3 deletions

View File

@@ -25,15 +25,29 @@ def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[in
return PADDED_SEQUENCE
def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
    """Build a boolean mask marking the padding positions of *sequence*.

    Args:
        sequence: Token ids to inspect.
        pad_token: Token id that represents padding.

    Returns:
        A list with the same length as ``sequence`` where position ``i`` is
        ``True`` if ``sequence[i] == pad_token`` and ``False`` otherwise.
        An empty sequence yields an empty mask.
    """
    return [token == pad_token for token in sequence]
def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
) -> tuple[list[int], list[bool]]:
    """Pad then truncate *sequence* and return it with its padding mask.

    The sequence is first passed through ``pad_sequence`` (padding up to
    ``max_length`` with ``pad_token``) and the padded result is then passed
    through ``truncate_sequence`` with ``max_length`` and ``end_token``.
    The exact truncation / end-token behaviour is defined by
    ``truncate_sequence``, which lives elsewhere in this module.

    Args:
        sequence: Token ids to normalize.
        max_length: Length forwarded to both the padding and truncation steps.
        pad_token: Token id used for padding and for building the mask.
        end_token: Token id forwarded to ``truncate_sequence``.

    Returns:
        A ``(normalized_sequence, padding_mask)`` tuple, where
        ``padding_mask[i]`` is ``True`` when ``normalized_sequence[i]`` equals
        ``pad_token``.
    """
    new_sequence = pad_sequence(sequence, max_length, pad_token)
    # Truncate the *padded* sequence; truncating the raw input here would
    # silently discard the padding applied on the previous line.
    new_sequence = truncate_sequence(new_sequence, max_length, end_token)
    # Build the mask from the normalized sequence so that mask and sequence
    # always have matching length and content.
    padding_mask = create_padding_mask(new_sequence, pad_token)
    return (new_sequence, padding_mask)