diff --git a/Project_Model/Libs/Training/logistic_collector.py b/Project_Model/Libs/Training/logistic_collector.py index 2e1ad36..15db97d 100644 --- a/Project_Model/Libs/Training/logistic_collector.py +++ b/Project_Model/Libs/Training/logistic_collector.py @@ -25,8 +25,8 @@ class LogitsCollector: for row in ids.tolist(): seq: list[int] = [] for tok in row: - if tok == self.__end_token: # stop on END - break + # if tok == self.__end_token: # stop on END + # break if tok == self.__pad_token: # skip PAD continue seq.append(tok) @@ -36,6 +36,7 @@ class LogitsCollector: def print_decoded(self) -> None: for i, seq in enumerate(self.tokens()): try: + # text = text + self.__end_token text = self.__tokenizer.decode(seq) # decode tokens to string except Exception: text = str(seq) # fallback to ids