Fixed a bug about sequence normalizations

This commit is contained in:
Christian Risi 2025-10-07 16:37:43 +02:00
parent fdece42462
commit b97282179d

View File

@ -47,7 +47,7 @@ def normalize_sequence(
) -> tuple[list[int], list[bool]]: ) -> tuple[list[int], list[bool]]:
new_sequence = pad_sequence(sequence, max_length, pad_token) new_sequence = pad_sequence(sequence, max_length, pad_token)
new_sequence = truncate_sequence(sequence, max_length, end_token) new_sequence = truncate_sequence(new_sequence, max_length, end_token)
PADDING_MASK = create_padding_mask(sequence, pad_token) PADDING_MASK = create_padding_mask(new_sequence, pad_token)
return (new_sequence, PADDING_MASK) return (new_sequence, PADDING_MASK)