Pipeline fix and added a util to decode

This commit is contained in:
Christian Risi
2025-10-09 13:24:48 +02:00
parent f3b83eda3d
commit aac7675b30
7 changed files with 78 additions and 29 deletions

View File

@@ -3,6 +3,7 @@ from .task_type import TaskType
from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequence, create_padding_mask
from .inference_masking import inference_masking
from .truncate_rdf_list import truncate_rdf_list
from .decode_out import tensor2token
__all__ = [
"TaskType",
@@ -13,5 +14,6 @@ __all__ = [
"create_padding_mask",
"normalize_sequence",
"inference_masking",
"truncate_rdf_list"
"truncate_rdf_list",
"tensor2token"
]

View File

@@ -0,0 +1,27 @@
from typing import Generator
import torch
def tensor2token(tensor: torch.Tensor, end_token: int) -> Generator[list[int]]:
if len(tensor.shape) < 1 or len(tensor.shape) > 2:
raise ValueError("Shape is not correct")
if len(tensor.shape) == 1:
token_list: list[int] = tensor.tolist()
token_list.append(end_token)
yield token_list
return
batch_len: int
batch_len, _ = tensor.shape
for i in range(batch_len):
smaller_tensor = tensor[i, :]
token_list: list[int] = smaller_tensor.tolist()
token_list.append(end_token)
yield token_list

View File

@@ -1,17 +1,20 @@
def truncate_sequence(
sequence: list[int], truncate_at: int, end_token: int
sequence: list[int], truncate_at: int, end_token: int, add_ending: bool
) -> list[int]:
if len(sequence) < truncate_at - 1:
sequence.append(end_token)
if add_ending:
sequence.append(end_token)
return sequence
if len(sequence) < truncate_at:
sequence[-1] = end_token
if add_ending:
sequence[-1] = end_token
return sequence
TRUNCATED_SEQUENCE = sequence[:truncate_at]
TRUNCATED_SEQUENCE[-1] = end_token
if add_ending:
TRUNCATED_SEQUENCE[-1] = end_token
return TRUNCATED_SEQUENCE
@@ -48,8 +51,9 @@ def normalize_sequence(
max_length: int,
pad_token: int,
end_token: int,
add_ending: bool = True
) -> tuple[list[int], list[bool]]:
new_sequence = truncate_sequence(sequence, max_length, end_token)
new_sequence = truncate_sequence(sequence, max_length, end_token, add_ending)
new_sequence = pad_sequence(new_sequence, max_length, pad_token)
PADDING_MASK = create_padding_mask(new_sequence, pad_token)