Added the capability to return the target triples after truncating

This commit is contained in:
Christian Risi 2025-10-06 17:43:01 +02:00
parent 44307cd917
commit 456ce724fe

View File

@ -1,3 +1,4 @@
from collections import deque
import random
import sys
@ -8,14 +9,14 @@ def truncate_rdf_list(
continue_triple_token: int,
end_of_triple_token: int,
seed: int = random.randint(0, sys.maxsize),
) -> list[int]:
) -> tuple[list[int], list[int]]:
if truncation_probability < 0 or truncation_probability > 1:
raise ValueError("A probability must be between 0 and 1")
RNG = random.Random(seed)
END_OF_TRIPLES = []
END_OF_TRIPLES: deque[int] = deque()
for i in range(0, len(sequence)):
@ -26,14 +27,14 @@ def truncate_rdf_list(
END_OF_TRIPLES.append(i + 1)
TRIPLES_TOKENS: list[int] = []
TARGET_TRIPLES: list[int] = []
start_of_triple = 0
eot_index = 0
exit_loop = False
while not exit_loop:
EOT = END_OF_TRIPLES[eot_index]
EOT = END_OF_TRIPLES.popleft()
TRIPLE = sequence[start_of_triple:EOT]
TRIPLES_TOKENS.extend(TRIPLE)
@ -43,12 +44,22 @@ def truncate_rdf_list(
if RNG.random() < truncation_probability:
exit_loop = True
if eot_index < len(END_OF_TRIPLES) - 2:
if len(END_OF_TRIPLES) == 1:
exit_loop = True
TRIPLES_TOKENS.append(
continue_triple_token
)
return TRIPLES_TOKENS
while len(END_OF_TRIPLES) > 0:
EOT = END_OF_TRIPLES.popleft()
TRIPLE = sequence[start_of_triple:EOT]
TARGET_TRIPLES.extend(TRIPLE)
start_of_triple = EOT
return (TRIPLES_TOKENS, TARGET_TRIPLES)