28 lines
630 B
Python
Raw Normal View History

from typing import Generator
import torch
def tensor2token(
    tensor: torch.Tensor, end_token: int
) -> Generator[list[int], None, None]:
    """Yield each sequence in ``tensor`` as a list with ``end_token`` appended.

    Args:
        tensor: A 1-D tensor holding a single sequence, or a 2-D tensor
            whose first dimension indexes the sequences.
        end_token: Value appended to the end of every yielded list.

    Yields:
        One list per sequence: the values produced by ``Tensor.tolist()``
        for that sequence, followed by ``end_token``.

    Raises:
        ValueError: If ``tensor`` is neither 1-D nor 2-D. Because this is
            a generator, the error surfaces on first iteration rather
            than at call time.
    """
    rank = len(tensor.shape)
    if rank < 1 or rank > 2:
        raise ValueError("Shape is not correct")

    # Treat a lone sequence as a batch of one so both ranks share one path.
    rows = [tensor] if rank == 1 else tensor
    for row in rows:
        token_list: list[int] = row.tolist()
        token_list.append(end_token)
        yield token_list