"""Utility that builds a truncated RDF token list.

The package ``__init__`` re-exports ``truncate_rdf_list`` from this module.
"""

import random
import sys
from typing import Optional


def truncate_rdf_list(
    sequence: list[int],
    truncation_probability: float,
    continue_triple_token: int,
    end_of_triple_token: int,
    seed: Optional[int] = None,
) -> list[int]:
    """Return a random prefix of whole triples followed by a continuation token.

    The sequence is split into triples at every ``end_of_triple_token``
    (the end token belongs to the triple it closes). The first triple is
    always kept. After each kept triple the function stops with probability
    ``truncation_probability``; otherwise it keeps the next triple. When the
    sequence holds more than one triple, the last one is never kept, so the
    result is always a strict truncation. Tokens after the last
    ``end_of_triple_token`` are dropped.

    Args:
        sequence: Token ids of the RDF list. It is not modified.
        truncation_probability: Chance, in ``[0, 1]``, of stopping after each
            kept triple.
        continue_triple_token: Token id appended at the end of the result.
        end_of_triple_token: Token id that closes a triple.
        seed: Seed for the private random generator. ``None`` (the default)
            lets ``random.Random`` pick fresh entropy on every call.

    Returns:
        A new list with the kept triples followed by
        ``continue_triple_token``. If ``sequence`` contains no complete
        triple, the result is ``[continue_triple_token]``.

    Raises:
        ValueError: If ``truncation_probability`` is outside ``[0, 1]``
            (NaN included).
    """
    # Chained comparison also rejects NaN, which passes two separate checks.
    if not 0 <= truncation_probability <= 1:
        raise ValueError("A probability must be between 0 and 1")

    # The default seed must not be computed in the signature: a default value
    # is evaluated once at definition time and would be shared by every call.
    rng = random.Random(seed)

    # Exclusive end index of every complete triple.
    triple_ends = [
        index + 1
        for index, token in enumerate(sequence)
        if token == end_of_triple_token
    ]

    # Keep at least one triple; hold back the last one when there are several.
    max_kept = max(1, len(triple_ends) - 1)

    kept = 0
    for _ in triple_ends[:max_kept]:
        kept += 1
        if rng.random() < truncation_probability:
            break

    cut = triple_ends[kept - 1] if kept else 0

    truncated = list(sequence[:cut])
    truncated.append(continue_triple_token)
    return truncated