Added the capability to return the target triples (the tokens remaining after the truncation point) alongside the truncated sequence
This commit is contained in:
parent
44307cd917
commit
456ce724fe
@ -1,3 +1,4 @@
|
|||||||
|
from collections import deque
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -8,14 +9,14 @@ def truncate_rdf_list(
|
|||||||
continue_triple_token: int,
|
continue_triple_token: int,
|
||||||
end_of_triple_token: int,
|
end_of_triple_token: int,
|
||||||
seed: int = random.randint(0, sys.maxsize),
|
seed: int = random.randint(0, sys.maxsize),
|
||||||
) -> list[int]:
|
) -> tuple[list[int], list[int]]:
|
||||||
|
|
||||||
if truncation_probability < 0 or truncation_probability > 1:
|
if truncation_probability < 0 or truncation_probability > 1:
|
||||||
raise ValueError("A probability must be between 0 and 1")
|
raise ValueError("A probability must be between 0 and 1")
|
||||||
|
|
||||||
RNG = random.Random(seed)
|
RNG = random.Random(seed)
|
||||||
|
|
||||||
END_OF_TRIPLES = []
|
END_OF_TRIPLES: deque[int] = deque()
|
||||||
|
|
||||||
for i in range(0, len(sequence)):
|
for i in range(0, len(sequence)):
|
||||||
|
|
||||||
@ -26,14 +27,14 @@ def truncate_rdf_list(
|
|||||||
END_OF_TRIPLES.append(i + 1)
|
END_OF_TRIPLES.append(i + 1)
|
||||||
|
|
||||||
TRIPLES_TOKENS: list[int] = []
|
TRIPLES_TOKENS: list[int] = []
|
||||||
|
TARGET_TRIPLES: list[int] = []
|
||||||
|
|
||||||
start_of_triple = 0
|
start_of_triple = 0
|
||||||
eot_index = 0
|
|
||||||
exit_loop = False
|
exit_loop = False
|
||||||
|
|
||||||
while not exit_loop:
|
while not exit_loop:
|
||||||
|
|
||||||
EOT = END_OF_TRIPLES[eot_index]
|
EOT = END_OF_TRIPLES.popleft()
|
||||||
|
|
||||||
TRIPLE = sequence[start_of_triple:EOT]
|
TRIPLE = sequence[start_of_triple:EOT]
|
||||||
TRIPLES_TOKENS.extend(TRIPLE)
|
TRIPLES_TOKENS.extend(TRIPLE)
|
||||||
@ -43,12 +44,22 @@ def truncate_rdf_list(
|
|||||||
if RNG.random() < truncation_probability:
|
if RNG.random() < truncation_probability:
|
||||||
exit_loop = True
|
exit_loop = True
|
||||||
|
|
||||||
if eot_index < len(END_OF_TRIPLES) - 2:
|
if len(END_OF_TRIPLES) == 1:
|
||||||
exit_loop = True
|
exit_loop = True
|
||||||
|
|
||||||
TRIPLES_TOKENS.append(
|
TRIPLES_TOKENS.append(
|
||||||
continue_triple_token
|
continue_triple_token
|
||||||
)
|
)
|
||||||
|
|
||||||
return TRIPLES_TOKENS
|
while len(END_OF_TRIPLES) > 0:
|
||||||
|
|
||||||
|
EOT = END_OF_TRIPLES.popleft()
|
||||||
|
|
||||||
|
TRIPLE = sequence[start_of_triple:EOT]
|
||||||
|
TARGET_TRIPLES.extend(TRIPLE)
|
||||||
|
|
||||||
|
start_of_triple = EOT
|
||||||
|
|
||||||
|
|
||||||
|
return (TRIPLES_TOKENS, TARGET_TRIPLES)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user