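"""RAM-based trainer for the NanoSocrates BPE tokenizer.

Reads a corpus file, splits it into BPE-trainable pieces, and grows the
tokenizer's merge vocabulary until a size, iteration, or convergence
limit is reached. (Docstring added editorially; summary inferred from
the code below.)
"""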
from pathlib import Path

from ..Classes import (
    NanoSocratesBPE,
    NanoSocratesSplitter,
    NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import special_regex_maker


class NanoSocraTraineRam:
    """Trains a NanoSocrates BPE tokenizer entirely in RAM."""

    def __init__(
        self,
        max_vocabulary: int,
        special_vocabulary: list[str],
        merge_threshold: int = 0,
        max_iterations: int = 0,
        print_after_iterations: int = 1,
    ) -> None:
        # The first 256 ids are reserved for raw byte values, followed
        # by one id per special token; only the remaining slots are
        # available for learned merges.
        BYTE_RESERVED_TOKENS = 256
        SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
        RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS

        self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
        self.__max_iterations = max_iterations
        self.__merge_threshold = merge_threshold
        self.__special_token_regex = special_regex_maker(special_vocabulary)
        self.__print_after_iterations = print_after_iterations

    def trainBPE(
        self,
        path: Path,
        bpe: NanoSocratesBPE | None = None,
    ) -> NanoSocratesBPE:
        """Trains `bpe` (or a fresh instance) on the file at `path`."""

        if not path.is_file():
            raise FileNotFoundError(path)

        if bpe is None:
            bpe = NanoSocratesBPE()

        # Nothing to do if the merge budget is already spent.
        if bpe.vocabulary_size >= self.__max_vocabulary:
            return bpe

        done = False
        current_iteration = 0
        data = self.__gather_data_from_file(path)

        while not done:

            current_iteration = self.__increment_counter(current_iteration)

            LAST_VOC_SIZE = bpe.vocabulary_size

            _, data, last_memory = self.__round_train(bpe, data)

            NEW_VOC_SIZE = bpe.vocabulary_size

            if current_iteration % self.__print_after_iterations == 0:

                DELIMITER = "==============="

                DEBUG = "\n".join(
                    [
                        DELIMITER,
                        f"ITERATION: {current_iteration}",
                        DELIMITER,
                        f"\tVocabulary size: {bpe.vocabulary_size}\n",
                        f"\tFrequencies:\n{last_memory.frequencies}\n",
                        f"\tVocabulary:\n{bpe.vocabulary}",
                        DELIMITER,
                        "",
                    ]
                )
                print(DEBUG)

            # Stop when a round yields no new merges, the iteration
            # budget is exhausted, or the vocabulary is full.
            if LAST_VOC_SIZE == NEW_VOC_SIZE:
                done = True
            elif current_iteration == self.__max_iterations:
                done = True
            elif bpe.vocabulary_size >= self.__max_vocabulary:
                done = True

        return bpe

    def __round_train(
        self, bpe: NanoSocratesBPE, data: list[list[int]]
    ) -> tuple[NanoSocratesBPE, list[list[int]], NanoSocratesBatchMemoryBPE]:

        DATA_LEN = len(data)
        NEW_DATA = []

        counter = 0
        memory = NanoSocratesBatchMemoryBPE({}, 0)

        while len(data) > 0:
            counter += 1
            last_batch = len(data) == 1

            piece = data.pop()

            bpe, memory, output = bpe.fit(piece, memory, last_batch)

            if counter % 1_000_000 == 0:
                print(f"Fitted: {counter}/{DATA_LEN}")

            # Pieces shorter than two ids cannot produce further merges,
            # so they are not carried into the next round.
            if len(output) < 2:
                continue

            NEW_DATA.append(output)

        return (bpe, NEW_DATA, memory)

    def __gather_data_from_file(self, path: Path) -> list[list[int]]:

        SPLITTER = NanoSocratesSplitter(self.__special_token_regex)

        DATA: list[list[int]] = []

        with open(path, "r", encoding="utf-8") as file:
            file_string = file.read()

        for piece, token_type in SPLITTER.split_text(file_string):

            # Only plain-text pieces are trained on; special tokens
            # keep their reserved ids.
            if token_type != TokenType.BPE:
                continue

            int_list = self.__make_list_ids(piece)
            DATA.append(int_list)

        return DATA

    def __increment_counter(self, counter: int) -> int:
        # Python integers have arbitrary precision, so incrementing can
        # never overflow and needs no guard.
        return counter + 1

    def __make_list_ids(self, corpus: str) -> list[int]:
        # A fresh piece starts out as its raw UTF-8 byte values (0-255).
        return list(corpus.encode("utf-8"))
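

# A minimal usage sketch, assuming the package layout implied by the
# relative imports above (run with `python -m <package>.<this_module>`).
# The corpus path and the special tokens below are placeholders, not
# values taken from the real project.
if __name__ == "__main__":
    trainer = NanoSocraTraineRam(
        max_vocabulary=1024,  # 256 byte ids + specials + learned merges
        special_vocabulary=["<|sot|>", "<|eot|>"],  # hypothetical tokens
        max_iterations=100,
        print_after_iterations=10,
    )
    trained = trainer.trainBPE(Path("corpus.txt"))  # placeholder path
    print(f"Final vocabulary size: {trained.vocabulary_size}")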