Fixed a bug while joining frequencies

This commit is contained in:
Christian Risi 2025-10-02 01:50:37 +02:00
parent d19426fa62
commit 3fe4e45ceb

View File

@ -21,15 +21,17 @@ from ..Utils import (
load_json, load_json,
) )
def split(a, n): def split(a, n):
k, m = divmod(len(a), n) k, m = divmod(len(a), n)
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)) return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]): def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
bpe, data = object bpe, data = object
NEW_DATA: list[list[int]]= [] NEW_DATA: list[list[int]] = []
memory = NanoSocratesBatchMemoryBPE({}, 0) memory = NanoSocratesBatchMemoryBPE({}, 0)
@ -144,7 +146,7 @@ class NanoSocraTrainerPool:
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]): def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
NEW_DATA : list[list[int]] = [] NEW_DATA: list[list[int]] = []
MEMORY = NanoSocratesBatchMemoryBPE({}, 0) MEMORY = NanoSocratesBatchMemoryBPE({}, 0)
@ -159,7 +161,9 @@ class NanoSocraTrainerPool:
data_chunks = split(data, CPU_COUNT) data_chunks = split(data, CPU_COUNT)
JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks] JOBS = [(NanoSocratesBPE(VOCABULARY), chunk) for chunk in data_chunks]
JOB_RESULTS: list[tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]] JOB_RESULTS: list[
tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]
]
with Pool() as pool: with Pool() as pool:
JOB_RESULTS = pool.map(fit_funct, JOBS) JOB_RESULTS = pool.map(fit_funct, JOBS)
@ -169,14 +173,20 @@ class NanoSocraTrainerPool:
NEW_DATA.extend(job_output) NEW_DATA.extend(job_output)
for key, value in job_memory.frequencies.items(): for key, value in job_memory.frequencies.items():
MEMORY.frequencies[key] = value frequency = MEMORY.frequencies.get(key)
if frequency is None:
frequency = 0
MEMORY.frequencies[key] = 0
frequency += value
MEMORY.frequencies[key] = frequency
del job_output del job_output
del job_memory del job_memory
print(f"Joined {i + 1} out of {CPU_COUNT}") print(f"Joined {i + 1} out of {CPU_COUNT}")
# Get new token # Get new token
bpe.fit([], MEMORY, True) bpe.fit([], MEMORY, True)