typo
This commit is contained in:
parent
96610612fe
commit
e76dbeb9a7
@ -25,8 +25,8 @@ class LogitsCollector:
|
|||||||
for row in ids.tolist():
|
for row in ids.tolist():
|
||||||
seq: list[int] = []
|
seq: list[int] = []
|
||||||
for tok in row:
|
for tok in row:
|
||||||
if tok == self.__end_token: # stop on END
|
# if tok == self.__end_token: # stop on END
|
||||||
break
|
# break
|
||||||
if tok == self.__pad_token: # skip PAD
|
if tok == self.__pad_token: # skip PAD
|
||||||
continue
|
continue
|
||||||
seq.append(tok)
|
seq.append(tok)
|
||||||
@ -36,6 +36,7 @@ class LogitsCollector:
|
|||||||
def print_decoded(self) -> None:
|
def print_decoded(self) -> None:
|
||||||
for i, seq in enumerate(self.tokens()):
|
for i, seq in enumerate(self.tokens()):
|
||||||
try:
|
try:
|
||||||
|
# text = text + self.__end_token
|
||||||
text = self.__tokenizer.decode(seq) # decode tokens to string
|
text = self.__tokenizer.decode(seq) # decode tokens to string
|
||||||
except Exception:
|
except Exception:
|
||||||
text = str(seq) # fallback to ids
|
text = str(seq) # fallback to ids
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user