def truncate_sequence(
    sequence: list[int], truncate_at: int, end_token: int, add_ending: bool
) -> list[int]:
    """Cap *sequence* at *truncate_at* tokens, optionally terminating it.

    Always returns a NEW list; the input is never mutated (the original
    mutated its argument on two of three paths, which was inconsistent).

    Behavior (unchanged from the original for non-crashing inputs):
    - length < truncate_at - 1: end token is appended (if requested);
    - length == truncate_at - 1: the last token is overwritten with the
      end token (if requested) — preserved as-is, even though there is
      room to append; presumably intentional, reserving the final slot;
    - length >= truncate_at: the sequence is cut to *truncate_at* tokens
      and the last kept token is overwritten with the end token.

    Fix: guard against an empty result before writing to index -1
    (previously an empty input with truncate_at <= 1 raised IndexError).
    """
    result = list(sequence)
    if len(result) < truncate_at - 1:
        if add_ending:
            result.append(end_token)
        return result
    if len(result) < truncate_at:
        if add_ending:
            if result:
                result[-1] = end_token
            else:
                # Empty input: append instead of crashing on result[-1].
                result.append(end_token)
        return result
    result = result[:truncate_at]
    if add_ending and result:  # guard: truncate_at == 0 yields []
        result[-1] = end_token
    return result


def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[int]:
    """Right-pad *sequence* with *pad_token* up to length *pad_until*.

    Returns the sequence unchanged (same object) when it is already at
    least *pad_until* long; otherwise returns a new, padded list.
    """
    if len(sequence) >= pad_until:
        return sequence
    return sequence + [pad_token] * (pad_until - len(sequence))


def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
    """Return a mask that is True exactly where *sequence* holds *pad_token*.

    NOTE: any real token equal to *pad_token* is also flagged — the mask
    is purely value-based.
    """
    return [token == pad_token for token in sequence]


def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
    add_ending: bool = True,
) -> tuple[list[int], list[bool]]:
    """Truncate-then-pad *sequence* to exactly *max_length* tokens.

    Returns (normalized_sequence, padding_mask) where the mask marks the
    positions occupied by *pad_token*.
    """
    normalized = truncate_sequence(sequence, max_length, end_token, add_ending)
    normalized = pad_sequence(normalized, max_length, pad_token)
    return normalized, create_padding_mask(normalized, pad_token)