Created trainer to train BPE
This commit is contained in: parent c9032cab09, commit b09bd4acba
164 Project_Model/Libs/BPE/Classes/NanoSocraTrainer.py Normal file
@@ -0,0 +1,164 @@
from collections import deque
from pathlib import Path
import re

from ..Classes import NanoSocratesBPE, NanoSocratesChunker, NanoSocratesSplitter, NanoSocratesBatchMemoryBPE
from ..Enums import TokenType
from ..Utils import special_regex_maker, iterator_with_checks


class NanoSocraTrainer:

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        chunk_size: int,
        merge_treshold: int = 0,
        max_iterations: int = 0,
    ) -> None:
        # IDs reserved up front: one per raw byte plus one per special token.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
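
        # Example: max_vocabulary=1000 with 4 special tokens leaves
        # 1000 - 256 - 4 = 740 ids available for learned merges.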
        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__chunk_size = chunk_size
        self.__merge_treshold = merge_treshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)

    def trainBPE(
        self, path: Path, cache_dir: Path, bpe: NanoSocratesBPE | None = None
    ) -> NanoSocratesBPE:

        if not path.is_file():
            raise FileNotFoundError(path)

        if not cache_dir.is_dir():
            raise NotADirectoryError(cache_dir)

        if bpe is None:
            bpe = NanoSocratesBPE()
        BPE = bpe

        # Nothing to do if the vocabulary is already over budget.
        if BPE.vocabulary_size > self.__max_vocabulary:
            return BPE

        done = False
        cached = False
        current_iteration = 0

        PATH_GEN = self.__switch_paths(path, cache_dir)

        input_path = next(PATH_GEN)

        while not done:

            out_path = next(PATH_GEN)
            current_iteration = self.__increment_counter(current_iteration)
            LAST_VOC_SIZE = BPE.vocabulary_size

            # Run one merge round, streaming the re-tokenized corpus
            # to the cache file for the next round to read.
            with open(out_path, "w") as file:
                for _, _, output in self.__round_train(input_path, BPE, cached):
                    file.write(output)

            cached = True
            input_path = out_path

            NEW_VOC_SIZE = BPE.vocabulary_size

            # Stop when a round learns no new merges, when the iteration
            # budget is spent (max_iterations == 0 means no cap), or when
            # the merge budget is full.
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                done = True
            elif current_iteration == self.__max_iterations:
                done = True
            elif BPE.vocabulary_size == self.__max_vocabulary:
                done = True

        return BPE

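    # One streaming pass over the corpus: yields a (bpe, memory, output)
    # triple per piece. trainBPE only consumes the output string, which
    # is the text written to the cache file.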
    def __round_train(
        self,
        path: Path,
        bpe: NanoSocratesBPE,
        cached: bool,
    ):

        CHUNKER = NanoSocratesChunker(self.__chunk_size, self.__special_token_regex)
        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)

        BPE = bpe
        memory = NanoSocratesBatchMemoryBPE({}, self.__merge_treshold)

        CHUNKER_GENERATOR = iterator_with_checks(CHUNKER.chunk(path))

        for chunk, last_chunk in CHUNKER_GENERATOR:

            PIECE_GENERATOR = iterator_with_checks(SPLITTER.split_text(chunk))

            for piece, last_piece in PIECE_GENERATOR:

                LAST_BATCH = last_chunk and last_piece
                PIECE, TOKEN_TYPE = piece

                if TOKEN_TYPE != TokenType.BPE:
                    # Special tokens are never merged: flush the batch
                    # memory and pass the piece through unchanged.
                    BPE.fit([], memory, LAST_BATCH)
                    yield (BPE, memory, PIECE)
                    continue

                PIECE_DATA = self.__make_list_ids(PIECE, cached)

                _, _, out = BPE.fit(PIECE_DATA, memory, LAST_BATCH)

                # out stringifies as the id list, e.g. "[104, 105]",
                # which __make_list_ids parses back on the next round.
                OUT_STRING = f"{out}"
                yield (BPE, memory, OUT_STRING)

    def __increment_counter(self, counter: int) -> int:
        # Python integers have arbitrary precision, so incrementing
        # cannot overflow; no error handling is needed.
        return counter + 1

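    # e.g. __make_list_ids("hi", cached=False)        -> [104, 105]
    #      __make_list_ids("[104, 105]", cached=True) -> [104, 105]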
    def __make_list_ids(self, corpus: str, cached: bool) -> list[int]:

        # First round: the corpus is raw text, so map each character
        # to its code point.
        if not cached:
            return list(map(ord, corpus))

        # Later rounds: the corpus is the cached repr of an id list.
        REDUCED_CORPUS_LEN = len(corpus) - 1

        # Skip the surrounding "[" and "]" characters
        INTS = corpus[1:REDUCED_CORPUS_LEN]
        INT_LIST = list(map(int, INTS.split(",")))
        return INT_LIST

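    # Yields the paths for successive rounds: corpus, tmp1, tmp2,
    # tmp1, tmp2, ... so every round reads the previous round's
    # output and overwrites the older cache file.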
    def __switch_paths(self, path: Path, cache_path: Path):

        yield path

        TMP_1 = cache_path / "tmp1.txt"
        TMP_2 = cache_path / "tmp2.txt"

        switch = True

        while True:
            if switch:
                yield TMP_1
            else:
                yield TMP_2
            switch = not switch
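
A minimal usage sketch (hypothetical corpus path, cache directory, and special tokens; assumes the package layout implied by the relative imports above):

    from pathlib import Path
    from Project_Model.Libs.BPE.Classes.NanoSocraTrainer import NanoSocraTrainer

    trainer = NanoSocraTrainer(
        max_vocabulary=1_000,
        special_vocabulary=["<SOS>", "<EOS>"],
        chunk_size=4_096,
    )
    # Learns merges from corpus.txt, caching intermediate ids in cache/.
    bpe = trainer.trainBPE(Path("corpus.txt"), Path("cache/"))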