Added time checking

This commit is contained in:
Christian Risi 2025-10-02 08:48:45 +02:00
parent 17d82f0a4e
commit aa765b4555

View File

@ -5,6 +5,7 @@ from multiprocessing import Pool
import os
from pathlib import Path
import re
import time
from ..Classes import (
NanoSocratesBPE,
NanoSocratesChunker,
@ -49,6 +50,22 @@ def split_fit(object: tuple[NanoSocratesBPE, list[list[int]]]):
return (bpe, NEW_DATA, memory)
def split_encode(object: tuple[NanoSocratesBPE, list[list[int]]]):
    """Encode every piece of a job and keep the results of length >= 2.

    Args:
        object: a ``(bpe, data)`` pair; ``bpe.encode_intermediate`` is
            applied to each element of ``data``.

    Returns:
        The encodings, in input order, whose length is at least 2.
    """
    tokenizer, pieces = object
    # Lazy on purpose: each piece is encoded and then length-checked before
    # the next one is touched, one at a time.
    encoded_pieces = (tokenizer.encode_intermediate(piece) for piece in pieces)
    # Encodings shorter than two items are discarded.
    return [encoded for encoded in encoded_pieces if len(encoded) >= 2]  # type: ignore
class NanoSocraTrainerPool:
@ -96,6 +113,8 @@ class NanoSocraTrainerPool:
exit = False
current_iteration = 0
data = self.__gather_data_from_file(path)
data = self.__encode_from_cache(BPE, data)
while not exit:
@ -105,8 +124,9 @@ class NanoSocraTrainerPool:
last_memory = None
start = time.time_ns()
_, data, last_memory = self.__round_train(BPE, data)
end = time.time_ns()
NEW_VOC_SIZE = BPE.vocabulary_size
VOCABULARY = BPE.vocabulary
@ -122,8 +142,8 @@ class NanoSocraTrainerPool:
DELIMITER,
f"ITERATION: {current_iteration}",
DELIMITER,
f"\tVocabulary size: {BPE.vocabulary_size}\n",
f"\tvocabulary:\n{BPE.vocabulary}",
f"\tVocabulary size: {BPE.vocabulary_size - 256}\n",
f"\tTime elapsed: {(end - start)/1E9}s",
DELIMITER,
"",
]
@ -214,6 +234,37 @@ class NanoSocraTrainerPool:
return DATA
def __encode_from_cache(self, bpe: NanoSocratesBPE, data: list[list[int]]):
    """Re-encode ``data`` in parallel using the vocabulary held by ``bpe``.

    ``data`` is partitioned with ``split(data, CPU_COUNT)``; every chunk is
    paired with a fresh ``NanoSocratesBPE`` built from ``bpe.vocabulary``
    and handed to ``split_encode`` through a ``multiprocessing.Pool``. The
    per-chunk results are concatenated in the order the pool returns them.

    Args:
        bpe: tokenizer whose ``vocabulary`` is used to build the workers'
            tokenizers.
        data: sequences to encode.

    Returns:
        The concatenation of every ``split_encode`` result.

    Raises:
        RuntimeError: if the number of usable CPUs cannot be determined.
    """
    CPU_COUNT = os.process_cpu_count()
    if CPU_COUNT is None:
        # RuntimeError is still an Exception, so existing handlers keep
        # working, but the failure now explains itself.
        raise RuntimeError(
            "Unable to determine the number of usable CPUs "
            "(os.process_cpu_count() returned None)"
        )

    VOCABULARY = bpe.vocabulary
    # Each job gets its own tokenizer instance built from the same vocabulary.
    JOBS = [
        (NanoSocratesBPE(VOCABULARY), chunk)
        for chunk in split(data, CPU_COUNT)
    ]

    JOB_RESULTS: list[list[list[int]]]
    with Pool() as pool:
        JOB_RESULTS = pool.map(split_encode, JOBS)

    # Iterate over every result: pairing the results with range(CPU_COUNT)
    # would silently drop data if there were ever more jobs than CPUs.
    TOTAL_JOBS = len(JOB_RESULTS)
    NEW_DATA: list[list[int]] = []
    for joined, job_output in enumerate(JOB_RESULTS, start=1):
        NEW_DATA.extend(job_output)
        print(f"Joined {joined} out of {TOTAL_JOBS}")

    print(f"Sentences from {len(data)} to {len(NEW_DATA)}")

    return NEW_DATA
def __increment_counter(self, counter: int):
# What if overflows???