Added util to create padding mask
This commit is contained in:
parent
ffdb312d58
commit
44307cd917
@ -1,6 +1,6 @@
|
|||||||
from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_batched
|
from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_batched
|
||||||
from .task_type import TaskType
|
from .task_type import TaskType
|
||||||
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
|
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence, create_padding_mask
|
||||||
from .inference_masking import inference_masking
|
from .inference_masking import inference_masking
|
||||||
from .truncate_rdf_list import truncate_rdf_list
|
from .truncate_rdf_list import truncate_rdf_list
|
||||||
|
|
||||||
@ -10,6 +10,7 @@ __all__ = [
|
|||||||
"get_causal_attention_mask_batched",
|
"get_causal_attention_mask_batched",
|
||||||
"truncate_sequence",
|
"truncate_sequence",
|
||||||
"pad_sequence",
|
"pad_sequence",
|
||||||
|
"create_padding_mask",
|
||||||
"normalize_sequence",
|
"normalize_sequence",
|
||||||
"inference_masking",
|
"inference_masking",
|
||||||
"truncate_rdf_list"
|
"truncate_rdf_list"
|
||||||
|
|||||||
@ -25,15 +25,29 @@ def pad_sequence(sequence: list[int], pad_until: int, pad_token: int) -> list[in
|
|||||||
|
|
||||||
return PADDED_SEQUENCE
|
return PADDED_SEQUENCE
|
||||||
|
|
||||||
|
def create_padding_mask(sequence: list[int], pad_token: int) -> list[bool]:
    """Return a boolean mask marking the padding positions of *sequence*.

    Args:
        sequence: Token ids to inspect.
        pad_token: Token id that denotes padding.

    Returns:
        A list with the same length as ``sequence`` whose i-th entry is
        ``True`` when ``sequence[i]`` equals ``pad_token`` and ``False``
        otherwise. An empty ``sequence`` yields an empty mask.
    """
    return [token == pad_token for token in sequence]
|
||||||
|
|
||||||
|
|
||||||
def normalize_sequence(
    sequence: list[int],
    max_length: int,
    pad_token: int,
    end_token: int,
) -> tuple[list[int], list[bool]]:
    """Truncate and pad *sequence*, and build the matching padding mask.

    Args:
        sequence: Token ids to normalize.
        max_length: Length passed to both ``truncate_sequence`` and
            ``pad_sequence``.
        pad_token: Token id used for padding and for building the mask.
        end_token: Token id forwarded to ``truncate_sequence``.

    Returns:
        A ``(new_sequence, padding_mask)`` tuple. ``padding_mask`` is computed
        from ``new_sequence`` itself, so both lists always have the same
        length and ``padding_mask[i]`` is ``True`` exactly where
        ``new_sequence[i] == pad_token``.

    Note:
        ``truncate_sequence`` and ``pad_sequence`` are defined elsewhere in
        this module; presumably the former shortens over-long input and the
        latter extends short input up to ``max_length`` -- confirm against
        their implementations.
    """
    # Chain the steps: each helper must receive the previous step's output.
    # Passing the raw input to both (as before) discarded the first result,
    # so short sequences were returned unpadded.
    # Truncation still sees the caller's original sequence, exactly as the
    # previous code did; padding is then applied to its result.
    new_sequence = truncate_sequence(sequence, max_length, end_token)
    new_sequence = pad_sequence(new_sequence, max_length, pad_token)

    # Build the mask from the sequence that is actually returned, not from
    # the untouched input, so mask and tokens stay aligned.
    padding_mask = create_padding_mask(new_sequence, pad_token)

    return (new_sequence, padding_mask)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user