Refactored to remove tokens that can't be compressed anymore
This commit is contained in:
parent
fbbe6226bb
commit
7cfaf601b4
@ -105,18 +105,29 @@ class NanoSocraTraineRam:
|
|||||||
return BPE
|
return BPE
|
||||||
|
|
||||||
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
|
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
|
||||||
|
|
||||||
DATA_LEN = len(data)
|
DATA_LEN = len(data)
|
||||||
|
NEW_DATA = []
|
||||||
|
|
||||||
|
counter = 0
|
||||||
memory = NanoSocratesBatchMemoryBPE({}, 0)
|
memory = NanoSocratesBatchMemoryBPE({}, 0)
|
||||||
for piece, index in zip(data, range(0, DATA_LEN)):
|
while len(data) > 0:
|
||||||
|
counter += 1
|
||||||
|
last_batch = len(data) == 1
|
||||||
|
|
||||||
last_batch = index == DATA_LEN - 1
|
piece = data.pop()
|
||||||
|
|
||||||
bpe, memory, output = bpe.fit(piece, memory, last_batch)
|
bpe, memory, output = bpe.fit(piece, memory, last_batch)
|
||||||
|
|
||||||
data[index] = output
|
if counter % int(1E6) == 0:
|
||||||
|
print(f"Fitted: {counter}/{DATA_LEN}")
|
||||||
|
|
||||||
return (bpe, data, memory)
|
if len(output) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
NEW_DATA.append(output)
|
||||||
|
|
||||||
|
return (bpe, NEW_DATA, memory)
|
||||||
|
|
||||||
def __gather_data_from_file(self, path: Path) -> list[list[int]]:
|
def __gather_data_from_file(self, path: Path) -> list[list[int]]:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user