Added a util to create truncated RDF lists

This commit is contained in:
Christian Risi 2025-10-06 17:22:13 +02:00
parent 0007c38212
commit ffdb312d58
2 changed files with 57 additions and 1 deletions

View File

@ -2,6 +2,7 @@ from .attention_mask import get_causal_attention_mask,get_causal_attention_mask_
from .task_type import TaskType from .task_type import TaskType
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence
from .inference_masking import inference_masking from .inference_masking import inference_masking
from .truncate_rdf_list import truncate_rdf_list
__all__ = [ __all__ = [
"TaskType", "TaskType",
@ -10,5 +11,6 @@ __all__ = [
"truncate_sequence", "truncate_sequence",
"pad_sequence", "pad_sequence",
"normalize_sequence", "normalize_sequence",
"inference_masking" "inference_masking",
"truncate_rdf_list"
] ]

View File

@ -0,0 +1,54 @@
import random
import sys
def truncate_rdf_list(
sequence: list[int],
truncation_probability: float,
continue_triple_token: int,
end_of_triple_token: int,
seed: int = random.randint(0, sys.maxsize),
) -> list[int]:
if truncation_probability < 0 or truncation_probability > 1:
raise ValueError("A probability must be between 0 and 1")
RNG = random.Random(seed)
END_OF_TRIPLES = []
for i in range(0, len(sequence)):
TOKEN = sequence[i]
if TOKEN != end_of_triple_token:
continue
END_OF_TRIPLES.append(i + 1)
TRIPLES_TOKENS: list[int] = []
start_of_triple = 0
eot_index = 0
exit_loop = False
while not exit_loop:
EOT = END_OF_TRIPLES[eot_index]
TRIPLE = sequence[start_of_triple:EOT]
TRIPLES_TOKENS.extend(TRIPLE)
start_of_triple = EOT
if RNG.random() < truncation_probability:
exit_loop = True
if eot_index < len(END_OF_TRIPLES) - 2:
exit_loop = True
TRIPLES_TOKENS.append(
continue_triple_token
)
return TRIPLES_TOKENS