57 lines
1.4 KiB
Python
Raw Normal View History

2025-10-06 17:01:18 +02:00
def truncate_sequence(
    sequence: list[int], truncate_at: int, end_token: int
) -> list[int]:
    """Return a copy of *sequence* terminated by *end_token*.

    The result keeps at most ``truncate_at - 1`` leading tokens of
    *sequence* and always ends with ``end_token``, so its length never
    exceeds ``truncate_at``. The input list is never modified.

    Args:
        sequence: Token ids to terminate.
        truncate_at: Maximum length of the result, end token included.
        end_token: Token id written as the last element of the result.

    Returns:
        A new list of length ``min(len(sequence) + 1, truncate_at)``.

    Raises:
        ValueError: If ``truncate_at`` is smaller than 1, because there
            would be no room for the end token.
    """
    if truncate_at < 1:
        raise ValueError("truncate_at must be at least 1")

    # Slicing copies, so the caller's list stays untouched. Reserving exactly
    # one slot for the end token keeps every real token that still fits.
    truncated = sequence[: truncate_at - 1]
    truncated.append(end_token)
    return truncated
def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[int]:
    """Return a copy of *sequence* right-padded with *pad_token*.

    Args:
        sequence: Token ids to pad.
        pad_until: Minimum length of the result.
        pad_token: Token id appended until the result is long enough.

    Returns:
        A new list. If *sequence* already has ``pad_until`` or more
        elements, the result is an unpadded copy; nothing is truncated.
    """
    # Multiplying a list by a non-positive count yields [], so the
    # "already long enough" case needs no separate branch, and every call
    # returns a fresh list rather than sometimes aliasing the input.
    missing = pad_until - len(sequence)
    return sequence + [pad_token] * missing
2025-10-06 17:29:05 +02:00
def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
    """Build a mask flagging the padding positions of *sequence*.

    Args:
        sequence: Token ids to inspect.
        pad_token: Token id that counts as padding.

    Returns:
        A list as long as *sequence* whose entry is ``True`` where the
        token equals ``pad_token`` and ``False`` everywhere else.
    """
    return [token == pad_token for token in sequence]
2025-10-06 17:01:18 +02:00
def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
) -> tuple[list[int], list[bool]]:
    """Terminate, pad and mask *sequence*.

    The sequence is passed through ``truncate_sequence`` (with
    ``max_length`` and ``end_token``), then through ``pad_sequence`` (with
    ``max_length`` and ``pad_token``), and finally a padding mask is built
    with ``create_padding_mask``.

    Args:
        sequence: Token ids to normalize.
        max_length: Length limit forwarded to truncation and padding.
        pad_token: Token id used for padding and for building the mask.
        end_token: Token id used to terminate the sequence.

    Returns:
        A ``(tokens, padding_mask)`` tuple, where ``padding_mask`` has one
        entry per token and is ``True`` where the token equals ``pad_token``.
    """
    # Work on a private copy so the caller's list is never modified,
    # whatever the helpers do internally.
    working = list(sequence)

    terminated = truncate_sequence(working, max_length, end_token)
    padded = pad_sequence(terminated, max_length, pad_token)
    padding_mask = create_padding_mask(padded, pad_token)

    return padded, padding_mask