from collections import deque

from .Encoder import Encoder
from ..Errors import OutOfDictionaryException, DuplicateWordException


# ABOUT THE DICTIONARY:
# The string is converted into UTF-8 bytes, that is: each char is represented
# by a sequence of 1 to 4 bytes. Each byte is cast to an integer, so that an
# integer lower than 256 represents a raw UTF-8 byte, while any other integer
# is a token ID.


class NanoSocratesBatchMemoryBPE:
    """Memory carried across training batches.

    Keeps the token-couple frequencies and the merge threshold.
    """

    def __init__(
        self, frequencies: dict[tuple[int, int], int], merge_treshold: int
    ) -> None:
        # Occurrence count of every adjacent couple seen so far.
        self.frequencies = frequencies
        # Minimum count the most frequent couple needs in order to be learned.
        self.merge_treshold = merge_treshold


class NanoSocratesBPE(Encoder):
    """Byte-pair encoder whose vocabulary maps a couple of IDs to a token ID."""

    def __init__(self, vocabulary: dict[tuple[int, int], int] | None = None) -> None:
        """Create the encoder, optionally loading an existing vocabulary.

        Args:
            vocabulary (dict[tuple[int, int], int] | None): couple -> token ID
                mapping to load; when None the vocabulary starts empty.

        Raises:
            OutOfDictionaryException: if a token ID is lower than 256, since
                those values are reserved for raw bytes.
        """
        super().__init__()
        self.__vocabulary: dict[tuple[int, int], int] = {}
        self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}

        if vocabulary is None:
            return

        for key, value in vocabulary.items():
            if value < 256:
                # values under 256 are used for unpaired chars
                raise OutOfDictionaryException()
            # TODO: check if they are in order
            self.__vocabulary[key] = value
            self.__reverse_vocabulary[value] = key

    @property
    def vocabulary_size(self) -> int:
        """Number of learned couples plus the 256 raw byte values."""
        return len(self.__vocabulary) + 256

    @property
    def vocabulary(self) -> dict[tuple[int, int], int]:
        """The couple -> token ID mapping (the internal dict, not a copy)."""
        return self.__vocabulary

    @property
    def __next_id(self) -> int:
        """Get the ID that the next learned couple will receive.

        Returns:
            int: the current vocabulary size
        """
        return self.vocabulary_size

    def fit(
        self,
        chunk_data: list[int],
        memory: NanoSocratesBatchMemoryBPE,
        last_batch: bool,
    ) -> tuple["NanoSocratesBPE", NanoSocratesBatchMemoryBPE, list[int]]:
        """Count couple frequencies in a chunk and, on the last batch, learn one couple.

        The chunk is first encoded with the current vocabulary, then every
        adjacent couple of the encoded chunk increments its counter in
        ``memory.frequencies``. Only when ``last_batch`` is true, the most
        frequent couple is added to the vocabulary, provided its count is
        not lower than ``memory.merge_treshold``.

        Args:
            chunk_data (list[int]): chunk to train on, as a list of integers
            memory (NanoSocratesBatchMemoryBPE): frequencies and threshold,
                updated in place
            last_batch (bool): whether this chunk closes the batch

        Returns:
            tuple: (this encoder, the memory, the encoded chunk)

        Raises:
            DuplicateWordException: propagated from learning the couple
        """
        encoded_chunk = self.encode_intermediate(chunk_data)
        frequencies = memory.frequencies

        # update the frequency of each couple of adjacent elements
        for couple in zip(encoded_chunk, encoded_chunk[1:]):
            frequencies[couple] = frequencies.get(couple, 0) + 1

        if not last_batch or len(frequencies) < 1:
            return (self, memory, encoded_chunk)

        # on ties max() keeps the first couple in the dict iteration order
        max_couple, max_frequency = max(
            frequencies.items(), key=lambda item: item[1]
        )

        if max_frequency < memory.merge_treshold:
            return (self, memory, encoded_chunk)

        self.__learn_word(max_couple)
        return (self, memory, encoded_chunk)

    def encode(self, piece: str) -> list[int]:
        """Encode a string into token IDs.

        The string is first converted into its UTF-8 bytes, then the list of
        integers is passed to encode_intermediate().

        Args:
            piece (str): text to encode

        Returns:
            list[int]: the token IDs
        """
        converted_piece = list(piece.encode("utf-8"))
        return self.encode_intermediate(converted_piece)

    def encode_intermediate(self, piece: list[int]) -> list[int]:
        """Encode a piece (as a list of integers) as far as possible.

        Encoding rounds are repeated until a round no longer shortens the
        piece. When the very first round merges nothing, the input list
        itself is returned.

        Args:
            piece (list[int]): piece to encode

        Returns:
            list[int]: piece encoded
        """
        current_piece = piece
        new_piece = self.__round_encode(current_piece)

        # as long as a round shortens the piece, keep encoding
        while len(current_piece) != len(new_piece):
            current_piece = new_piece
            new_piece = self.__round_encode(current_piece)

        return current_piece

    def __round_encode(self, piece: list[int]) -> list[int]:
        """Run a single left-to-right encoding round over the whole piece.

        Multiple rounds are needed for a full encode:
            1) "ABAB" -> "XX"
            2) "XX" -> "Y"

        Args:
            piece (list[int]): the object to encode as a list of integers

        Returns:
            list[int]: the object encoded one time
        """
        if len(piece) == 1:
            return piece

        last_index = len(piece) - 1
        new_piece: list[int] = []

        index = 0
        while index < last_index:
            # take a couple of consecutive elements
            candidate_word = (piece[index], piece[index + 1])
            candidate_token = self.__vocabulary.get(candidate_word)

            if candidate_token is None:
                # no token substitutes the couple: keep the first element
                new_piece.append(piece[index])
                index += 1
            else:
                # a token substitutes both elements of the couple
                new_piece.append(candidate_token)
                index += 2

            # the last element has no right neighbour left to pair with
            if index == last_index:
                new_piece.append(piece[index])

        return new_piece

    def decode(self, token_ids: list[int]) -> str:
        """Decode a list of token IDs back into a string.

        Token IDs are expanded into their couples until only values lower
        than 256 remain; those bytes are then decoded as UTF-8.

        Args:
            token_ids (list[int]): the token IDs to decode

        Returns:
            str: the decoded text

        Raises:
            OutOfDictionaryException: if a token ID of 256 or more is not in
                the vocabulary
        """
        # deque: double ended queue
        token_stack: deque[int] = deque(token_ids)
        utf_8_bytes = bytearray()

        while len(token_stack) > 0:
            token_id = token_stack.popleft()

            if token_id < 256:
                utf_8_bytes.append(token_id)
                continue

            left_token, right_token = self.__token_decode(token_id)

            # push right first so that the left element is popped first
            token_stack.appendleft(right_token)
            token_stack.appendleft(left_token)

        return utf_8_bytes.decode("utf-8")

    def __token_decode(self, token_id: int) -> tuple[int, int]:
        """Get the couple that a token ID stands for.

        Args:
            token_id (int): the token ID to expand

        Returns:
            tuple[int, int]: the couple substituted by the token ID

        Raises:
            OutOfDictionaryException: if the token ID is not in the vocabulary
        """
        candidate_decoded = self.__reverse_vocabulary.get(token_id)

        if candidate_decoded is None:
            raise OutOfDictionaryException()

        return candidate_decoded

    def __learn_word(self, words: tuple[int, int]):
        """Learn a new couple of elements in the vocabulary.

        Args:
            words (tuple[int, int]): the couple of elements to substitute
                with a new token ID

        Raises:
            DuplicateWordException: if the couple is already in the
                vocabulary, or if the new token ID is already in use
        """
        new_id = self.__next_id

        if words in self.__vocabulary:
            raise DuplicateWordException()

        # The next ID is derived from the vocabulary size, so a loaded
        # vocabulary with non-contiguous IDs may already use it; refuse
        # rather than overwrite the reverse entry.
        if new_id in self.__reverse_vocabulary:
            raise DuplicateWordException()

        self.__vocabulary[words] = new_id
        self.__reverse_vocabulary[new_id] = words