from typing import Generator

import torch


def tensor2token(
    tensor: torch.Tensor, end_token: int
) -> Generator[list[int], None, None]:
    """Yield token lists from *tensor*, each terminated by *end_token*.

    A 1-D tensor is treated as a single sequence and yields exactly one list.
    A 2-D tensor is processed row by row (along dimension 0), yielding one
    list per row; a batch with zero rows yields nothing.

    Because this function is a generator, the shape check runs when iteration
    starts, not when the function is called.

    Args:
        tensor: 1-D or 2-D tensor of token values. Values are converted with
            ``Tensor.tolist()``; presumably an integer dtype is expected — a
            float tensor would yield floats despite the annotation
            (NOTE(review): confirm with callers).
        end_token: Value appended to the end of every yielded list.

    Yields:
        A fresh Python list per sequence, with *end_token* appended.

    Raises:
        ValueError: If *tensor* is neither 1-D nor 2-D.
    """
    if tensor.dim() not in (1, 2):
        raise ValueError("Shape is not correct")

    if tensor.dim() == 1:
        yield [*tensor.tolist(), end_token]
        return

    # Convert one row at a time so only a single row is materialised as a
    # Python list at once, rather than the whole batch.
    for row in tensor:
        yield [*row.tolist(), end_token]