added logistic collector
This commit is contained in:
parent d2fdeb18a2
commit 2036b4015f
Project_Model/Libs/Training/logistic_collector.py (42 lines) Normal file
@@ -0,0 +1,42 @@
import torch


class LogitsCollector:

    def __init__(self, pad_token: int, end_token: int, tokenizer) -> None:
        self.__pad_token = pad_token  # used to skip PAD
        self.__end_token = end_token  # used to stop at END
        self.__tokenizer = tokenizer  # exposes .decode(list[int]) -> str
        self.__steps: list[torch.Tensor] = []  # list of per-step logits [B,V]

    def reset(self) -> None:
        self.__steps.clear()  # clear history

    def add(self, logits_step: torch.Tensor) -> None:
        if logits_step.dim() == 3:  # handle [B,1,V]
            logits_step = logits_step[:, -1, :]  # -> [B,V]
        self.__steps.append(logits_step.detach())  # store raw logits (detached)

    def tokens(self) -> list[list[int]]:
        if not self.__steps:
            return []
        stack = torch.stack(self.__steps, dim=0)  # [T,B,V]
        probs = torch.softmax(stack, dim=-1)  # softmax over vocab -> [T,B,V]
        ids = probs.argmax(dim=-1).transpose(0, 1)  # greedy ids -> [B,T]
        out: list[list[int]] = []
        for row in ids.tolist():
            seq: list[int] = []
            for tok in row:
                if tok == self.__end_token:  # stop on END
                    break
                if tok == self.__pad_token:  # skip PAD
                    continue
                seq.append(tok)
            out.append(seq)
        return out

    def print_decoded(self) -> None:
        for i, seq in enumerate(self.tokens()):
            try:
                text = self.__tokenizer.decode(seq)  # decode tokens to string
            except Exception:
                text = str(seq)  # fallback to ids
            print(f"[{i}] {text}")  # simple print
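
A minimal usage sketch, not part of the commit: it shows how the collector could be fed per-step logits during greedy decoding. The import path is inferred from the file location and assumes the repository root is on PYTHONPATH; DummyTokenizer, the special-token ids, the vocabulary size, and the random logits are illustrative assumptions.

import torch

from Project_Model.Libs.Training.logistic_collector import LogitsCollector  # assumed import path


class DummyTokenizer:
    """Stand-in tokenizer (assumption): joins token ids into a readable string."""

    def decode(self, ids: list[int]) -> str:
        return " ".join(f"<{i}>" for i in ids)


PAD, END, VOCAB = 0, 1, 32  # assumed special-token ids and vocab size
collector = LogitsCollector(pad_token=PAD, end_token=END, tokenizer=DummyTokenizer())

# Simulate a few decoding steps for a batch of 2 with random per-step logits [B, V].
for _ in range(5):
    collector.add(torch.randn(2, VOCAB))
collector.add(torch.randn(2, 1, VOCAB))  # [B,1,V] is also accepted and reduced to [B,V]

print(collector.tokens())  # greedy ids per batch element, truncated at END, PAD skipped
collector.print_decoded()  # decoded strings, one line per batch element
collector.reset()          # clear stored steps before the next sequence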