Added a util to create truncated RDF lists
This commit is contained in:
parent
0007c38212
commit
ffdb312d58
@ -2,6 +2,7 @@ from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_
|
|||||||
from .task_type import TaskType
|
from .task_type import TaskType
|
||||||
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
|
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
|
||||||
from .inference_masking import inference_masking
|
from .inference_masking import inference_masking
|
||||||
|
from .truncate_rdf_list import truncate_rdf_list
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TaskType",
|
"TaskType",
|
||||||
@ -10,5 +11,6 @@ __all__ = [
|
|||||||
"truncate_sequence",
|
"truncate_sequence",
|
||||||
"pad_sequence",
|
"pad_sequence",
|
||||||
"normalize_sequence",
|
"normalize_sequence",
|
||||||
"inference_masking"
|
"inference_masking",
|
||||||
|
"truncate_rdf_list"
|
||||||
]
|
]
|
||||||
54
Project_Model/Libs/Transformer/Utils/truncate_rdf_list.py
Normal file
54
Project_Model/Libs/Transformer/Utils/truncate_rdf_list.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_rdf_list(
    sequence: list[int],
    truncation_probability: float,
    continue_triple_token: int,
    end_of_triple_token: int,
    seed: "int | None" = None,
) -> list[int]:
    """Return a randomly truncated prefix of a tokenized list of triples.

    The sequence is split at every occurrence of ``end_of_triple_token``
    (the token itself belongs to the triple it closes). The first triple is
    always kept; after each kept triple the truncation stops with
    probability ``truncation_probability``, otherwise the next triple is
    kept as well. ``continue_triple_token`` is appended to the kept prefix.

    When the sequence holds more than one triple, the final triple is never
    kept, so the result is always a strict prefix of the triples.
    NOTE(review): this follows the ``len(...) - 2`` bound of the previous
    implementation; confirm that leaving out the last triple is intended.

    Tokens that follow the last ``end_of_triple_token`` are never part of
    the result. The input list is not modified.

    Args:
        sequence: Token ids of the triples list.
        truncation_probability: Chance, in ``[0, 1]``, of stopping after
            each kept triple.
        continue_triple_token: Token id appended after the kept triples.
        end_of_triple_token: Token id that closes a triple.
        seed: Seed for the random generator. ``None`` (the default) lets
            ``random.Random`` pick a fresh seed on every call; pass an int
            for reproducible truncation.

    Returns:
        A new list with the kept triples followed by
        ``continue_triple_token``. If ``sequence`` contains no
        ``end_of_triple_token`` the result is ``[continue_triple_token]``.

    Raises:
        ValueError: If ``truncation_probability`` is outside ``[0, 1]``.
    """
    if truncation_probability < 0 or truncation_probability > 1:
        raise ValueError("A probability must be between 0 and 1")

    # The seed must be resolved per call: a ``random.randint(...)`` default
    # in the signature is evaluated only once, when the function is defined.
    rng = random.Random(seed)

    # Exclusive end index of every triple (one past its closing token).
    triple_ends = [
        index + 1
        for index, token in enumerate(sequence)
        if token == end_of_triple_token
    ]

    if not triple_ends:
        # No complete triple to keep: only the continuation marker is left.
        return [continue_triple_token]

    # Index of the last triple that may be kept: the second to last one,
    # or the only one when the sequence holds a single triple.
    last_keepable = max(len(triple_ends) - 2, 0)

    kept = 0
    while kept < last_keepable and rng.random() >= truncation_probability:
        kept += 1

    truncated = list(sequence[: triple_ends[kept]])
    truncated.append(continue_triple_token)

    return truncated
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user