This commit is contained in:
GassiGiuseppe 2025-10-10 22:26:06 +02:00
parent 96610612fe
commit e76dbeb9a7

View File

@ -25,8 +25,8 @@ class LogitsCollector:
for row in ids.tolist(): for row in ids.tolist():
seq: list[int] = [] seq: list[int] = []
for tok in row: for tok in row:
if tok == self.__end_token: # stop on END # if tok == self.__end_token: # stop on END
break # break
if tok == self.__pad_token: # skip PAD if tok == self.__pad_token: # skip PAD
continue continue
seq.append(tok) seq.append(tok)
@ -36,6 +36,7 @@ class LogitsCollector:
def print_decoded(self) -> None: def print_decoded(self) -> None:
for i, seq in enumerate(self.tokens()): for i, seq in enumerate(self.tokens()):
try: try:
# text = text + self.__end_token
text = self.__tokenizer.decode(seq) # decode tokens to string text = self.__tokenizer.decode(seq) # decode tokens to string
except Exception: except Exception:
text = str(seq) # fallback to ids text = str(seq) # fallback to ids