Pipeline fix and added a util to decode

This commit is contained in:
Christian Risi
2025-10-09 13:24:48 +02:00
parent f3b83eda3d
commit aac7675b30
7 changed files with 78 additions and 29 deletions

View File

@@ -0,0 +1,27 @@
from typing import Generator
import torch
def tensor2token(tensor: torch.Tensor, end_token: int) -> Generator[list[int]]:
if len(tensor.shape) < 1 or len(tensor.shape) > 2:
raise ValueError("Shape is not correct")
if len(tensor.shape) == 1:
token_list: list[int] = tensor.tolist()
token_list.append(end_token)
yield token_list
return
batch_len: int
batch_len, _ = tensor.shape
for i in range(batch_len):
smaller_tensor = tensor[i, :]
token_list: list[int] = smaller_tensor.tolist()
token_list.append(end_token)
yield token_list