28 lines
630 B
Python
28 lines
630 B
Python
from typing import Generator
|
|
|
|
import torch
|
|
|
|
|
|
# NOTE: the return annotation spells out all three Generator parameters;
# the one-argument form `Generator[list[int]]` only works on Python >= 3.13
# and raises TypeError at definition time on older interpreters.
def tensor2token(
    tensor: torch.Tensor, end_token: int
) -> Generator[list[int], None, None]:
    """Yield each sequence in *tensor* as a list with *end_token* appended.

    A 1-D tensor is treated as a single sequence and yields exactly one list.
    A 2-D tensor is treated as a batch, and each row is yielded in order.

    Args:
        tensor: 1-D or 2-D tensor. Values are converted with
            ``Tensor.tolist()``, so the element type follows the tensor's
            dtype (presumably an integer token-id dtype -- confirm with callers).
        end_token: Value appended to the end of every yielded list.

    Yields:
        A fresh Python list per sequence: its values followed by ``end_token``.

    Raises:
        ValueError: If ``tensor`` is not 1-D or 2-D. Because this is a
            generator, the check runs when iteration starts, not at call time.
    """
    if tensor.dim() not in (1, 2):
        raise ValueError("Shape is not correct")

    # Lift a single sequence to a batch of one so both cases share one path.
    batch = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    # One bulk tolist() call; each row is already a new list, safe to extend.
    for row in batch.tolist():
        row.append(end_token)
        yield row
|
|
|
|
|