import torch
import Project_Model.Libs.BPE as BPE


def decode_batch(batch: torch.Tensor, tokenizer: BPE.TokeNanoCore, unknown_token: int) -> list[str]:
    """Decode a (BATCH, SEQ) tensor of token ids into a list of strings,
    replacing out-of-vocabulary ids with unknown_token before decoding."""
    strings: list[str] = []
    BATCH, _ = batch.shape
    for i in range(BATCH):
        # Convert only the current row to a Python list instead of the whole batch.
        tokens: list[int] = batch[i].tolist()
        # Valid ids are 0..vocabulary_size-1, so anything >= vocabulary_size is out of range.
        tokens = [unknown_token if t >= tokenizer.vocabulary_size else t for t in tokens]
        strings.append(tokenizer.decode(tokens))
    return strings
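

# Minimal usage sketch. The TokeNanoCore constructor and the concrete <unk> id
# are assumptions (not shown in this file); the tensor values are illustrative.
#
#   tokenizer = BPE.TokeNanoCore(...)  # hypothetical construction
#   batch = torch.tensor([[5, 12, 999_999], [3, 7, 2]])
#   texts = decode_batch(batch, tokenizer, unknown_token=0)
#   # texts is a list[str] with one decoded string per batch row; the
#   # out-of-range id 999_999 is decoded as the unknown token.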