diff --git a/Project_Model/Libs/Transformer/Utils/post_tokenization.py b/Project_Model/Libs/Transformer/Utils/post_tokenization.py
new file mode 100644
index 0000000..5b4ae89
--- /dev/null
+++ b/Project_Model/Libs/Transformer/Utils/post_tokenization.py
@@ -0,0 +1,42 @@
+def truncate_sequence(
+    sequence: list[int], truncate_at: int, end_token: int
+) -> list[int]:
+    """Truncate `sequence` to at most `truncate_at` tokens, ending with `end_token`."""
+    if len(sequence) < truncate_at:
+        # Work on a copy so the caller's list is never mutated.
+        truncated_sequence = sequence[:]
+        truncated_sequence[-1] = end_token
+        return truncated_sequence
+
+    truncated_sequence = sequence[:truncate_at]
+    truncated_sequence[-1] = end_token
+
+    return truncated_sequence
+
+
+def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[int]:
+    """Right-pad `sequence` with `pad_token` until it is `pad_until` tokens long."""
+    if len(sequence) >= pad_until:
+        return sequence
+
+    num_of_paddings = pad_until - len(sequence)
+    paddings = [pad_token] * num_of_paddings
+
+    padded_sequence = sequence[:]
+    padded_sequence.extend(paddings)
+
+    return padded_sequence
+
+
+def normalize_sequence(
+    sequence: list[int],
+    max_length: int,
+    pad_token: int,
+    end_token: int,
+) -> list[int]:
+    """Pad or truncate `sequence` to exactly `max_length` tokens."""
+    new_sequence = pad_sequence(sequence, max_length, pad_token)
+    # Chain the padded result; passing the original `sequence` here would discard the padding.
+    new_sequence = truncate_sequence(new_sequence, max_length, end_token)
+
+    return new_sequence
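
For review context, a quick usage sketch of normalize_sequence as defined above; the token ids, pad id 0, and end id 2 are arbitrary illustrations, not values taken from the codebase:

    # Usage sketch; pad_token=0 and end_token=2 are made-up ids for illustration.
    normalize_sequence([5, 6, 7], max_length=6, pad_token=0, end_token=2)
    # -> [5, 6, 7, 0, 0, 2]   (padded to length 6, final slot set to the end token)

    normalize_sequence([5, 6, 7, 8, 9, 10, 11], max_length=6, pad_token=0, end_token=2)
    # -> [5, 6, 7, 8, 9, 2]   (truncated to length 6, last kept token replaced by the end token)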