Added a util to create truncated RDF lists
This commit is contained in:
parent
0007c38212
commit
ffdb312d58
@ -2,6 +2,7 @@ from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_
|
||||
from .task_type import TaskType
|
||||
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
|
||||
from .inference_masking import inference_masking
|
||||
from .truncate_rdf_list import truncate_rdf_list
|
||||
|
||||
__all__ = [
|
||||
"TaskType",
|
||||
@ -10,5 +11,6 @@ __all__ = [
|
||||
"truncate_sequence",
|
||||
"pad_sequence",
|
||||
"normalize_sequence",
|
||||
"inference_masking"
|
||||
"inference_masking",
|
||||
"truncate_rdf_list"
|
||||
]
|
||||
54
Project_Model/Libs/Transformer/Utils/truncate_rdf_list.py
Normal file
54
Project_Model/Libs/Transformer/Utils/truncate_rdf_list.py
Normal file
@ -0,0 +1,54 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
def truncate_rdf_list(
|
||||
sequence: list[int],
|
||||
truncation_probability: float,
|
||||
continue_triple_token: int,
|
||||
end_of_triple_token: int,
|
||||
seed: int = random.randint(0, sys.maxsize),
|
||||
) -> list[int]:
|
||||
|
||||
if truncation_probability < 0 or truncation_probability > 1:
|
||||
raise ValueError("A probability must be between 0 and 1")
|
||||
|
||||
RNG = random.Random(seed)
|
||||
|
||||
END_OF_TRIPLES = []
|
||||
|
||||
for i in range(0, len(sequence)):
|
||||
|
||||
TOKEN = sequence[i]
|
||||
if TOKEN != end_of_triple_token:
|
||||
continue
|
||||
|
||||
END_OF_TRIPLES.append(i + 1)
|
||||
|
||||
TRIPLES_TOKENS: list[int] = []
|
||||
|
||||
start_of_triple = 0
|
||||
eot_index = 0
|
||||
exit_loop = False
|
||||
|
||||
while not exit_loop:
|
||||
|
||||
EOT = END_OF_TRIPLES[eot_index]
|
||||
|
||||
TRIPLE = sequence[start_of_triple:EOT]
|
||||
TRIPLES_TOKENS.extend(TRIPLE)
|
||||
|
||||
start_of_triple = EOT
|
||||
|
||||
if RNG.random() < truncation_probability:
|
||||
exit_loop = True
|
||||
|
||||
if eot_index < len(END_OF_TRIPLES) - 2:
|
||||
exit_loop = True
|
||||
|
||||
TRIPLES_TOKENS.append(
|
||||
continue_triple_token
|
||||
)
|
||||
|
||||
return TRIPLES_TOKENS
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user