101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
|
|
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path

# TODO: make relative imports
import Project_Model.Libs.BPE as BPE

from Scripts.Libs.CleaningPipeline.special_token import SpecialToken
|
||
|
|
|
||
|
|
# Characters processed per training chunk (180,000).
DEFAULT_CHUNK_SIZE = int(18e4)

# Emit debug output after this many training iterations.
DEFAULT_DEBUG_AFTER_ITER = 1

# Upper bound on the learned vocabulary size (32,000 tokens).
DEFAULT_MAX_VOCABULARY = int(32E3)

# Minimum pair frequency required to apply a merge.
# NOTE(review): "treshold" spelling is part of the public CLI/name surface; kept as-is.
DEFAULT_MERGE_TRESHOLD = 1

# Iteration cap for the trainer; 0 presumably means "unbounded" — TODO confirm
# against BPE.NanoSocraTrainer.
DEFAULT_MAX_ITERATIONS = 0

# String values of every special token, forwarded to the trainer.
TOKEN_LIST = [token.value for token in SpecialToken]
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ProgramArgs:
    """Parsed command-line options for the BPE training script.

    Field order matches the positional construction performed in ``get_args``;
    the generated ``__init__`` has the same signature as the previous
    hand-written one, so existing callers are unaffected.
    """

    input_file: str       # path to the training corpus
    cache_dir: str        # directory for the trainer's intermediate files
    output_file: str      # destination of the vocabulary JSON
    max_vocabulary: int   # vocabulary size cap
    max_iterations: int   # iteration cap (see DEFAULT_MAX_ITERATIONS)
    merge_treshold: int   # minimum merge frequency ("treshold" spelling is public)
    chunk_size: int       # characters per training chunk
    debug_after: int      # debug reporting interval, in iterations
|
||
|
|
|
||
|
|
|
||
|
|
def get_args(args: list[str]) -> ProgramArgs:
    """Parse command-line arguments into a :class:`ProgramArgs`.

    Unknown arguments are silently ignored (``parse_known_args``), which also
    makes a full ``sys.argv`` — including the program name — tolerable input.

    :param args: raw argument strings to parse.
    :raises SystemExit: if a required option is missing (argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description="Train a BPE tokenizer and save its vocabulary."
    )
    # Required I/O locations.
    parser.add_argument("--input-file", "--input", "-i", required=True, type=str)
    parser.add_argument("--cache-dir", "--cache", "-c", required=True, type=str)
    parser.add_argument("--output-file", "--output", "-o", required=True, type=str)
    # Optional training hyper-parameters, defaulting to the module constants.
    parser.add_argument("--max-vocabulary", "--max-voc", default=DEFAULT_MAX_VOCABULARY, type=int)
    parser.add_argument("--max-iterations", "--max-iter", default=DEFAULT_MAX_ITERATIONS, type=int)
    parser.add_argument("--merge-treshold", "--tresh", default=DEFAULT_MERGE_TRESHOLD, type=int)
    parser.add_argument("--chunk-size", default=DEFAULT_CHUNK_SIZE, type=int)
    parser.add_argument("--debug-after", default=DEFAULT_DEBUG_AFTER_ITER, type=int)

    parsed_args, _unknown = parser.parse_known_args(args)

    # Keyword construction guards against field-order drift in ProgramArgs.
    return ProgramArgs(
        input_file=parsed_args.input_file,
        cache_dir=parsed_args.cache_dir,
        output_file=parsed_args.output_file,
        max_vocabulary=parsed_args.max_vocabulary,
        max_iterations=parsed_args.max_iterations,
        merge_treshold=parsed_args.merge_treshold,
        chunk_size=parsed_args.chunk_size,
        debug_after=parsed_args.debug_after,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def train(args: ProgramArgs) -> None:
    """Train a BPE tokenizer on the configured dataset and persist its vocabulary.

    Reads the corpus at ``args.input_file``, uses ``args.cache_dir`` for the
    trainer's intermediate files, and writes the learned vocabulary as JSON
    to ``args.output_file``.
    """
    trainer = BPE.NanoSocraTrainer(
        args.max_vocabulary,
        TOKEN_LIST,
        args.chunk_size,
        args.merge_treshold,
        args.max_iterations,
        args.debug_after,
    )

    dataset_path = Path(args.input_file)
    cache_dir = Path(args.cache_dir)
    vocabulary_path = Path(args.output_file)

    print("Training BPE")

    bpe_encoder = trainer.trainBPE(dataset_path, cache_dir)

    # Serialize the learned vocabulary before touching the output file, so a
    # serialization failure never leaves a truncated file behind.
    vocabulary_json = json.dumps(bpe_encoder.vocabulary)

    print(f"Saving Vocabulary in {vocabulary_path}")
    # Context manager guarantees the handle is closed even if the write fails.
    with open(vocabulary_path, "w", encoding="utf-8") as vocabulary_file:
        vocabulary_file.write(vocabulary_json)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Drop argv[0] (the script path) so only real options reach the parser;
    # previously the stray positional was only tolerated by parse_known_args.
    ARGS = get_args(sys.argv[1:])
    train(ARGS)
|