import torch


class LogitsCollector:
    """Collects per-step logits and greedily decodes them into token id sequences."""

    def __init__(self, pad_token: int, end_token: int, tokenizer) -> None:
        self.__pad_token = pad_token           # PAD id, skipped in the output
        self.__end_token = end_token           # END id, stops decoding
        self.__tokenizer = tokenizer           # exposes .decode(list[int]) -> str
        self.__steps: list[torch.Tensor] = []  # per-step logits, each [B, V]

    def reset(self) -> None:
        self.__steps.clear()  # clear collected history

    def add(self, logits_step: torch.Tensor) -> None:
        if logits_step.dim() == 3:                 # handle [B, 1, V]
            logits_step = logits_step[:, -1, :]    # -> [B, V]
        self.__steps.append(logits_step.detach())  # store raw logits (detached)

    def tokens(self) -> list[list[int]]:
        if not self.__steps:
            return []
        stack = torch.stack(self.__steps, dim=0)    # [T, B, V]
        probs = torch.softmax(stack, dim=-1)        # softmax over vocab -> [T, B, V]
        ids = probs.argmax(dim=-1).transpose(0, 1)  # greedy ids -> [B, T]
        out: list[list[int]] = []
        for row in ids.tolist():
            seq: list[int] = []
            for tok in row:
                if tok == self.__end_token:  # stop on END
                    break
                if tok == self.__pad_token:  # skip PAD
                    continue
                seq.append(tok)
            out.append(seq)
        return out

    def print_decoded(self) -> None:
        for i, seq in enumerate(self.tokens()):
            try:
                text = self.__tokenizer.decode(seq)  # decode token ids to string
            except Exception:
                text = str(seq)                      # fall back to raw ids
            print(f"[{i}] {text}")
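

# Usage sketch (illustrative only): the tokenizer and special-token ids below are
# hypothetical stand-ins, assuming a decoder that emits [B, 1, V] logits per step.
if __name__ == "__main__":
    vocab_size, batch_size, steps = 32, 2, 5
    pad_id, end_id = 0, 1  # assumed special-token ids

    class _ToyTokenizer:
        """Minimal stand-in exposing the .decode(list[int]) -> str interface assumed above."""

        def decode(self, ids: list[int]) -> str:
            return " ".join(f"<{i}>" for i in ids)

    collector = LogitsCollector(pad_token=pad_id, end_token=end_id, tokenizer=_ToyTokenizer())
    for _ in range(steps):
        # one decoding step worth of random logits, shaped [B, 1, V]
        collector.add(torch.randn(batch_size, 1, vocab_size))
    collector.print_decoded()  # prints one greedy-decoded line per batch element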