Fix the pipeline and add a decoding utility
This commit is contained in:
@@ -1,17 +1,20 @@
|
||||
def truncate_sequence(
    sequence: list[int],
    truncate_at: int,
    end_token: int,
    add_ending: bool = True,
) -> list[int]:
    """Return a copy of *sequence* limited to *truncate_at* tokens.

    When ``add_ending`` is true the result is terminated with ``end_token``:

    * if the sequence is shorter than ``truncate_at - 1``, the end token is
      appended;
    * otherwise the last kept token is overwritten with the end token.

    The input list is never modified; a new list is always returned.

    Args:
        sequence: Token ids to truncate.
        truncate_at: Maximum number of tokens to keep.
        end_token: Token id written at the end when ``add_ending`` is true.
        add_ending: Whether to terminate the result with ``end_token``.

    Returns:
        A new list holding at most ``truncate_at`` tokens (for a
        non-negative ``truncate_at``).
    """
    # Slicing enforces the length limit and also copies, so the caller's
    # list is left untouched on every path.
    result = sequence[:truncate_at]

    if not add_ending:
        return result

    if not result:
        # Nothing to overwrite; only add the end token if there is room for
        # at least one token, instead of indexing into an empty list.
        if truncate_at > 0:
            result.append(end_token)
    elif len(sequence) < truncate_at - 1:
        result.append(end_token)
    else:
        # NOTE(review): when len(sequence) == truncate_at - 1 there is still
        # room to append, yet the last token is replaced. Kept as-is because
        # it looks deliberate -- confirm with callers.
        result[-1] = end_token

    return result
|
||||
|
||||
@@ -48,8 +51,9 @@ def normalize_sequence(
|
||||
max_length: int,
|
||||
pad_token: int,
|
||||
end_token: int,
|
||||
add_ending: bool = True
|
||||
) -> tuple[list[int], list[bool]]:
|
||||
new_sequence = truncate_sequence(sequence, max_length, end_token)
|
||||
new_sequence = truncate_sequence(sequence, max_length, end_token, add_ending)
|
||||
new_sequence = pad_sequence(new_sequence, max_length, pad_token)
|
||||
PADDING_MASK = create_padding_mask(new_sequence, pad_token)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user