Refactored to remove tokens that can't be compressed anymore

Christian Risi 2025-10-01 19:42:22 +02:00
parent fbbe6226bb
commit 7cfaf601b4


@@ -105,18 +105,29 @@ class NanoSocraTraineRam:
         return BPE
     def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
         DATA_LEN = len(data)
+        NEW_DATA = []
         counter = 0
         memory = NanoSocratesBatchMemoryBPE({}, 0)
-        for piece, index in zip(data, range(0, DATA_LEN)):
+        while len(data) > 0:
             counter += 1
-            last_batch = index == DATA_LEN - 1
+            last_batch = len(data) == 1
+            piece = data.pop()
             bpe, memory, output = bpe.fit(piece, memory, last_batch)
-            data[index] = output
             if counter % int(1E6) == 0:
                 print(f"Fitted: {counter}/{DATA_LEN}")
-        return (bpe, data, memory)
+            if len(output) < 2:
+                continue
+            NEW_DATA.append(output)
+        return (bpe, NEW_DATA, memory)
     def __gather_data_from_file(self, path: Path) -> list[list[int]]:
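
The refactor stops rewriting `data` in place and instead collects outputs into `NEW_DATA`, dropping any piece whose fitted output shrinks below two tokens: a single-token piece has no adjacent pair left to merge, so carrying it into the next round is wasted work. A minimal sketch of that filtering rule, with a hypothetical helper name that is not part of the repository:

    # Sketch only: illustrates why len(output) < 2 pieces are discarded.
    def filter_compressible(pieces: list[list[int]]) -> list[list[int]]:
        # A piece reduced to fewer than 2 tokens offers no pair to merge,
        # so it is excluded from the next training round.
        return [piece for piece in pieces if len(piece) >= 2]

    rounds = [[1, 2, 3], [7], [4, 5]]
    print(filter_compressible(rounds))  # [[1, 2, 3], [4, 5]]

Popping pieces off `data` as they are consumed also lets the exhausted list shrink during the round, rather than keeping both the old and the filtered corpora alive at once.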