Fixed a bug while joining frequencies

This commit is contained in:
Christian Risi 2025-10-02 01:50:37 +02:00
parent d19426fa62
commit 3fe4e45ceb

View File

@ -21,15 +21,17 @@ from ..Utils import (
load_json,
)
def split(a, n):
k, m = divmod(len(a), n)
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
bpe, data = object
NEW_DATA: list[list[int]]= []
NEW_DATA: list[list[int]] = []
memory = NanoSocratesBatchMemoryBPE({}, 0)
@ -144,7 +146,7 @@ class NanoSocraTrainerPool:
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
NEW_DATA : list[list[int]] = []
NEW_DATA: list[list[int]] = []
MEMORY = NanoSocratesBatchMemoryBPE({}, 0)
@ -159,7 +161,9 @@ class NanoSocraTrainerPool:
data_chunks = split(data, CPU_COUNT)
JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]
JOB_RESULTS: list[tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]]
JOB_RESULTS: list[
tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]
]
with Pool() as pool:
JOB_RESULTS = pool.map(fit_funct, JOBS)
@ -169,14 +173,20 @@ class NanoSocraTrainerPool:
NEW_DATA.extend(job_output)
for key, value in job_memory.frequencies.items():
MEMORY.frequencies[key] = value
frequency = MEMORY.frequencies.get(key)
if frequency is None:
frequency = 0
MEMORY.frequencies[key] = 0
frequency += value
MEMORY.frequencies[key] = frequency
del job_output
del job_memory
print(f"Joined {i + 1} out of {CPU_COUNT}")
# Get new token
bpe.fit([], MEMORY, True)