Added post tokenization utils
This commit is contained in:
parent
ee8e56798c
commit
9c1043e0ba
39
Project_Model/Libs/Transformer/Utils/post_tokenization.py
Normal file
39
Project_Model/Libs/Transformer/Utils/post_tokenization.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
def truncate_sequence(
    sequence: list[int], truncate_at: int, end_token: int
) -> list[int]:
    """Return a copy of ``sequence`` cut to ``truncate_at`` tokens, ending in ``end_token``.

    The last token of the (possibly shortened) result is always overwritten
    with ``end_token``; this also happens when the input is already shorter
    than ``truncate_at``. The input list is never modified.

    Args:
        sequence: Token ids to truncate.
        truncate_at: Maximum number of tokens to keep. Must be non-negative.
        end_token: Token id written into the last position of the result.

    Returns:
        A new list of at most ``truncate_at`` tokens. It is empty when
        ``sequence`` is empty or ``truncate_at`` is 0, since there is no
        position to hold the end token.

    Raises:
        ValueError: If ``truncate_at`` is negative.
    """
    if truncate_at < 0:
        # A negative bound would make the slice below drop tokens from the
        # end instead of limiting the length.
        raise ValueError(f"truncate_at must be non-negative, got {truncate_at}")

    # Slicing always copies, so the caller's list is left untouched even
    # when no token is actually dropped.
    truncated = sequence[:truncate_at]

    # Guard against indexing an empty list (empty input or truncate_at == 0).
    if truncated:
        truncated[-1] = end_token

    return truncated
|
||||||
|
|
||||||
|
|
||||||
|
def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[int]:
    """Right-pad ``sequence`` with ``pad_token`` up to ``pad_until`` tokens.

    Args:
        sequence: Token ids to pad.
        pad_until: Target length of the result.
        pad_token: Token id appended to fill the missing positions.

    Returns:
        ``sequence`` itself (the same object) when it already holds at least
        ``pad_until`` tokens; otherwise a new list made of the original
        tokens followed by the padding.
    """
    missing = pad_until - len(sequence)

    # Already long enough: hand the input back unchanged.
    if missing <= 0:
        return sequence

    # Concatenation builds a new list, leaving the input untouched.
    return sequence + [pad_token] * missing
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
) -> list[int]:
    """Truncate ``sequence`` to ``max_length`` tokens, then pad it up to ``max_length``.

    The sequence is first passed through ``truncate_sequence`` (which caps the
    length and writes ``end_token`` into the last kept position) and the
    result is then passed through ``pad_sequence`` (which appends
    ``pad_token`` until ``max_length`` is reached). For a non-empty sequence
    and a positive ``max_length`` the result therefore has exactly
    ``max_length`` tokens.

    Args:
        sequence: Token ids to normalize. Not modified by this function.
        max_length: Length used both as truncation bound and padding target.
        pad_token: Token id used to fill missing positions.
        end_token: Token id written as the last non-padding token.

    Returns:
        The truncated-then-padded token list.
    """
    # Truncate before padding: the end token must sit on the last real
    # token, with padding after it. Truncating must also operate on the
    # intermediate result, otherwise the padding step is thrown away.
    # A copy is passed because truncate_sequence may write the end token
    # into the list it is given.
    new_sequence = truncate_sequence(list(sequence), max_length, end_token)
    new_sequence = pad_sequence(new_sequence, max_length, pad_token)

    return new_sequence
|
||||||
Loading…
x
Reference in New Issue
Block a user