diff --git a/Project_Model/Libs/Training/logistic_collector.py b/Project_Model/Libs/Training/logistic_collector.py
new file mode 100644
index 0000000..2e1ad36
--- /dev/null
+++ b/Project_Model/Libs/Training/logistic_collector.py
@@ -0,0 +1,42 @@
+import torch
+
+class LogitsCollector:
+    def __init__(self, pad_token: int, end_token: int, tokenizer) -> None:
+        self.__pad_token = pad_token  # used to skip PAD
+        self.__end_token = end_token  # used to stop at END
+        self.__tokenizer = tokenizer  # exposes .decode(list[int]) -> str
+        self.__steps: list[torch.Tensor] = []  # list of per-step logits [B,V]
+
+    def reset(self) -> None:
+        self.__steps.clear()  # clear history
+
+    def add(self, logits_step: torch.Tensor) -> None:
+        if logits_step.dim() == 3:  # handle [B,1,V]
+            logits_step = logits_step[:, -1, :]  # -> [B,V]
+        self.__steps.append(logits_step.detach())  # store raw logits (detached)
+
+    def tokens(self) -> list[list[int]]:
+        if not self.__steps:
+            return []
+        stack = torch.stack(self.__steps, dim=0)  # [T,B,V]
+        probs = torch.softmax(stack, dim=-1)  # softmax over vocab -> [T,B,V]
+        ids = probs.argmax(dim=-1).transpose(0, 1)  # greedy ids -> [B,T]
+        out: list[list[int]] = []
+        for row in ids.tolist():
+            seq: list[int] = []
+            for tok in row:
+                if tok == self.__end_token:  # stop on END
+                    break
+                if tok == self.__pad_token:  # skip PAD
+                    continue
+                seq.append(tok)
+            out.append(seq)
+        return out
+
+    def print_decoded(self) -> None:
+        for i, seq in enumerate(self.tokens()):
+            try:
+                text = self.__tokenizer.decode(seq)  # decode tokens to string
+            except Exception:
+                text = str(seq)  # fallback to ids
+            print(f"[{i}] {text}")  # simple print
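
Minimal usage sketch (not part of the diff): it feeds a few random per-step logits through the collector and reads back the greedy ids and decoded strings. The DummyTokenizer, batch size, vocab size, step count, and token ids are placeholders chosen for illustration only.

    import torch
    from Project_Model.Libs.Training.logistic_collector import LogitsCollector

    class DummyTokenizer:
        def decode(self, ids: list[int]) -> str:
            return " ".join(map(str, ids))  # stand-in for a real tokenizer

    collector = LogitsCollector(pad_token=0, end_token=2, tokenizer=DummyTokenizer())
    collector.reset()
    for _ in range(5):                        # pretend 5 decoding steps
        collector.add(torch.randn(3, 1, 11))  # [B=3, 1, V=11] step logits; add() squeezes to [B,V]
    print(collector.tokens())                 # greedy ids per batch row, stopping at END, skipping PAD
    collector.print_decoded()                 # decoded via the (dummy) tokenizer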