Finished uploading stubs for TokeNano

This commit is contained in:
Christian Risi 2025-10-01 18:56:53 +02:00
parent b3d444979f
commit fbbe6226bb
6 changed files with 246 additions and 2 deletions

View File

@ -0,0 +1,153 @@
from collections import deque
import datetime
from pathlib import Path
import re
from ..Classes import (
NanoSocratesBPE,
NanoSocratesChunker,
NanoSocratesSplitter,
NanoSocratesBatchMemoryBPE,
)
from ..Enums import TokenType
from ..Utils import (
special_regex_maker,
iterator_with_checks,
save_nanos_vocabulary,
load_nanos_vocabulary,
save_json,
load_json,
)
class NanoSocraTraineRam:
def __init__(
self,
max_vocabulary: int,
special_vocabulary: list[str],
merge_treshold: int = 0,
max_iterations: int = 0,
print_after_iterations: int = 1,
) -> None:
# Bytes
BYTE_RESERVED_TOKENS = 256
SPECIAL_RESERVED_TOKENS = len(special_vocabulary)
RESERVED_TOKENS = BYTE_RESERVED_TOKENS + SPECIAL_RESERVED_TOKENS
self.__max_vocabulary = max_vocabulary - RESERVED_TOKENS
self.__max_iterations = max_iterations
self.__merge_treshold = merge_treshold
self.__special_token_regex = special_regex_maker(special_vocabulary)
self.__print_after_iterations = print_after_iterations
def trainBPE(
self,
path: Path,
bpe: NanoSocratesBPE | None = None,
) -> NanoSocratesBPE:
if not path.is_file():
raise FileNotFoundError()
if bpe is None:
bpe = NanoSocratesBPE()
BPE = bpe
if BPE.vocabulary_size > self.__max_vocabulary:
return BPE
exit = False
current_iteration = 0
data = self.__gather_data_from_file(path)
while not exit:
current_iteration = self.__increment_counter(current_iteration)
LAST_VOC_SIZE = BPE.vocabulary_size
last_memory = None
_, data, last_memory = self.__round_train(BPE, data)
NEW_VOC_SIZE = BPE.vocabulary_size
if current_iteration % self.__print_after_iterations == 0:
DELIMITER = "==============="
DEBUG = "\n".join(
[
DELIMITER,
f"ITERATION: {current_iteration}",
DELIMITER,
f"\tVocabulary size: {BPE.vocabulary_size}\n",
f"\tFrequencies:\n{last_memory.frequencies}\n", # type: ignore (pretty sure it's not None)
f"\tvocabulary:\n{BPE.vocabulary}",
DELIMITER,
"",
]
)
print(DEBUG)
if LAST_VOC_SIZE == NEW_VOC_SIZE:
exit = True
continue
if current_iteration == self.__max_iterations:
exit = True
continue
if BPE.vocabulary_size == self.__max_vocabulary:
exit = True
continue
return BPE
def __round_train(self, bpe: NanoSocratesBPE, data: list[list[int]]):
DATA_LEN = len(data)
memory = NanoSocratesBatchMemoryBPE({}, 0)
for piece, index in zip(data, range(0, DATA_LEN)):
last_batch = index == DATA_LEN - 1
bpe, memory, output = bpe.fit(piece, memory, last_batch)
data[index] = output
return (bpe, data, memory)
def __gather_data_from_file(self, path: Path) -> list[list[int]]:
SPLITTER = NanoSocratesSplitter(self.__special_token_regex)
DATA: list[list[int]] = []
FILE = open(path, "r", encoding="utf-8")
file_string = FILE.read()
FILE.close()
for piece, type in SPLITTER.split_text(file_string):
if type != TokenType.BPE:
continue
int_list = self.__make_list_ids(piece)
DATA.append(int_list)
return DATA
def __increment_counter(self, counter: int):
# What if overflows???
try:
counter += 1
except:
print("Integer overflow")
counter = 1
return counter
def __make_list_ids(self, corpus: str):
return list(corpus.encode("utf-8"))

View File

@ -2,10 +2,12 @@ from .NanoSocratesChunker import NanoSocratesChunker
from .NanoSocratesSplitter import NanoSocratesSplitter from .NanoSocratesSplitter import NanoSocratesSplitter
from .NanoSocratesBPE import NanoSocratesBPE, NanoSocratesBatchMemoryBPE from .NanoSocratesBPE import NanoSocratesBPE, NanoSocratesBatchMemoryBPE
from .NanoSocraTrainer import NanoSocraTrainer from .NanoSocraTrainer import NanoSocraTrainer
from .NanoSocraTraineRam import NanoSocraTraineRam
__all__ = [ __all__ = [
"NanoSocratesChunker", "NanoSocratesChunker",
"NanoSocratesSplitter", "NanoSocratesSplitter",
"NanoSocratesBPE", "NanoSocratesBPE",
"NanoSocraTrainer" "NanoSocraTrainer",
"NanoSocraTraineRam"
] ]

View File

@ -1,7 +1,12 @@
from .special_regex_maker import special_regex_maker from .special_regex_maker import special_regex_maker
from .lag_checker_iterator import iterator_with_checks from .lag_checker_iterator import iterator_with_checks
from .vocabulary import save_nanos_vocabulary, load_nanos_vocabulary
from .json_utils import save_json, load_json
__all__ = [ __all__ = [
"special_regex_maker", "special_regex_maker",
"iterator_with_checks" "iterator_with_checks",
"save_nanos_vocabulary",
"load_nanos_vocabulary",
"save_json", "load_json"
] ]

View File

@ -0,0 +1,84 @@
import argparse
import json
from pathlib import Path
import sys
# TODO: make relative imports
import Project_Model.Libs.BPE as BPE
from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
DEFAULT_DEBUG_AFTER_ITER = 1
DEFAULT_MAX_VOCABULARY = int(32E3)
DEFAULT_MERGE_TRESHOLD = 1
DEFAULT_MAX_ITERATIONS = 0
TOKEN_LIST = [token.value for token in SpecialToken]
class ProgramArgs:
def __init__(
self,
input_file: str,
output_file: str,
max_vocabulary: int,
max_iterations: int,
merge_treshold: int,
debug_after: int,
) -> None:
self.input_file = input_file
self.output_file = output_file
self.max_vocabulary = max_vocabulary
self.max_iterations = max_iterations
self.merge_treshold = merge_treshold
self.debug_after = debug_after
def get_args(args: list[str]) -> ProgramArgs:
PARSER = argparse.ArgumentParser()
PARSER.add_argument("--input-file", "--input", "-i", required=True, type=str)
PARSER.add_argument("--output-file", "--output", "-o", required=True, type=str)
PARSER.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
PARSER.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
PARSER.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
PARSER.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)
parsed_args, _ = PARSER.parse_known_args(args)
return ProgramArgs(
parsed_args.input_file,
parsed_args.output_file,
parsed_args.max_vocabulary,
parsed_args.max_iterations,
parsed_args.merge_treshold,
parsed_args.debug_after,
) # type ignore
def train(args: ProgramArgs):
TRAINER = BPE.NanoSocraTraineRam(
args.max_vocabulary,
TOKEN_LIST,
args.merge_treshold,
args.max_iterations,
args.debug_after
)
DATASET_PATH = Path(args.input_file)
VOCABULARY_PATH = Path(args.output_file)
print(f"Training BPE")
BPE_ENCODER = TRAINER.trainBPE(
DATASET_PATH
)
VOCABULARY = BPE_ENCODER.vocabulary
print(f"Saving Vocabulary in {VOCABULARY_PATH}")
BPE.save_nanos_vocabulary(VOCABULARY, VOCABULARY_PATH)
if __name__ == "__main__":
ARGS = get_args(sys.argv)
train(ARGS)