Refactored to remove tokens that can't be compressed anymore
parent fbbe6226bb
commit 7cfaf601b4
@@ -105,18 +105,29 @@ class NanoSocraTraineRam:
         return BPE

     def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):

         DATA_LEN = len(data)
+        NEW_DATA = []
+
         counter = 0
         memory = NanoSocratesBatchMemoryBPE({}, 0)

-        for piece, index in zip(data, range(0, DATA_LEN)):
+        while len(data) > 0:
             counter += 1
-            last_batch = index == DATA_LEN - 1
+            last_batch = len(data) == 1
+            piece = data.pop()
+
             bpe, memory, output = bpe.fit(piece, memory, last_batch)
-            data[index] = output
+
             if counter % int(1E6) == 0:
                 print(f"Fitted: {counter}/{DATA_LEN}")

-        return (bpe, data, memory)
+            if len(output) < 2:
+                continue
+
+            NEW_DATA.append(output)
+
+        return (bpe, NEW_DATA, memory)

     def __gather_data_from_file(self, path: Path) -> list[list[int]]:
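The core idea of the refactor: a piece whose fitted output has fewer than two tokens contains no adjacent pair left to merge, so it can never be compressed further and only wastes RAM in later rounds. Below is a minimal, self-contained sketch of that filtering pattern; merge_most_frequent_pair and round_train_sketch are hypothetical stand-ins for illustration only, not the repository's NanoSocratesBPE.fit API or its batch-memory object.

def merge_most_frequent_pair(piece: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    """Replace every occurrence of `pair` in `piece` with `new_id` (toy merge step)."""
    out: list[int] = []
    i = 0
    while i < len(piece):
        if i + 1 < len(piece) and (piece[i], piece[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(piece[i])
            i += 1
    return out


def round_train_sketch(data: list[list[int]], pair: tuple[int, int], new_id: int) -> list[list[int]]:
    """One training round: merge `pair` in every piece, drop pieces that cannot shrink further.

    A piece of length < 2 has no adjacent token pairs left, so it cannot be
    compressed by any future merge and is not carried into the next round.
    """
    new_data: list[list[int]] = []
    while data:                      # pop() as we go, so `data` shrinks instead of being duplicated
        piece = data.pop()
        output = merge_most_frequent_pair(piece, pair, new_id)
        if len(output) < 2:          # nothing left to merge: discard
            continue
        new_data.append(output)
    return new_data


if __name__ == "__main__":
    corpus = [[1, 2, 1, 2], [1, 2], [7]]
    # Merging the pair (1, 2) into token 300 yields [300, 300], [300] and [7];
    # the two single-token pieces are dropped for the next round.
    print(round_train_sketch(corpus, (1, 2), 300))   # -> [[300, 300]]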