Pipeline fix and added a util to decode
This commit is contained in:
27
Project_Model/Libs/Transformer/Utils/decode_out.py
Normal file
27
Project_Model/Libs/Transformer/Utils/decode_out.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Generator
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def tensor2token(tensor: torch.Tensor, end_token: int) -> Generator[list[int]]:
|
||||
|
||||
if len(tensor.shape) < 1 or len(tensor.shape) > 2:
|
||||
raise ValueError("Shape is not correct")
|
||||
|
||||
if len(tensor.shape) == 1:
|
||||
token_list: list[int] = tensor.tolist()
|
||||
token_list.append(end_token)
|
||||
yield token_list
|
||||
return
|
||||
|
||||
batch_len: int
|
||||
batch_len, _ = tensor.shape
|
||||
|
||||
for i in range(batch_len):
|
||||
|
||||
smaller_tensor = tensor[i, :]
|
||||
token_list: list[int] = smaller_tensor.tolist()
|
||||
token_list.append(end_token)
|
||||
yield token_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user