from .Encoder import Encoder
from ..Errors import OutOfDictionaryException, DuplicateWordException


class NanoSocratesBatchMemoryBPE:
    """State carried between successive `NanoSocratesBPE.fit` calls.

    Attributes are plain and mutable: `fit` updates `frequencies` in place.
    """

    def __init__(
        self, frequencies: dict[tuple[int, int], int], merge_treshold: int
    ) -> None:
        # Occurrence count of every adjacent token pair seen so far.
        self.frequencies = frequencies
        # Minimum count the most frequent pair must reach to be learned.
        self.merge_treshold = merge_treshold


class NanoSocratesBPE(Encoder):
    """Byte-pair encoder working on UTF-8 bytes.

    Token ids below 256 stand for raw byte values; every id from 256 upwards
    names a learned pair `(left_token, right_token)`.
    """

    def __init__(self, vocabulary: dict[tuple[int, int], int] | None = None) -> None:
        """Build the encoder, optionally preloading learned merges.

        Args:
            vocabulary: mapping from a token pair to the id of the merged
                token, or None to start with no merges.

        Raises:
            OutOfDictionaryException: if any id in `vocabulary` is below 256,
                i.e. collides with the byte range.
        """
        super().__init__()

        self.__vocabulary: dict[tuple[int, int], int] = {}
        self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}

        if vocabulary is None:
            return

        for pair, token_id in vocabulary.items():
            if token_id < 256:
                raise OutOfDictionaryException()
            # TODO: check if they are in order
            self.__vocabulary[pair] = token_id
            self.__reverse_vocabulary[token_id] = pair

    @property
    def vocabulary_size(self):
        """Number of learned merges plus 255.

        With no merges this is 255 (the highest byte value), and `__next_id`
        relies on that offset so that the first learned id is 256.
        NOTE(review): this is one less than the count of distinct token ids
        (256 bytes + merges) - confirm what callers expect before changing.
        """
        return len(self.__vocabulary) + 255

    @property
    def vocabulary(self):
        """The live pair -> merged-token-id mapping (not a copy)."""
        return self.__vocabulary

    @property
    def __next_id(self):
        """Id that the next learned pair will receive."""
        return self.vocabulary_size + 1

    # TODO: implement fit
    def fit(
        self,
        chunk_data: list[int],
        memory: NanoSocratesBatchMemoryBPE,
        last_batch: bool,
    ):
        """Count adjacent pairs of a chunk and, on the last batch, learn one.

        The chunk is first encoded with the merges already known, then every
        adjacent pair of the encoded chunk is added to `memory.frequencies`.
        Only when `last_batch` is true the most frequent pair is learned,
        provided its count reaches `memory.merge_treshold`.

        Args:
            chunk_data: token ids (bytes and/or already merged tokens).
            memory: pair counters and threshold; updated in place.
            last_batch: whether this call closes the current counting pass.

        Returns:
            The tuple `(self, memory, encoded_chunk)`.

        Raises:
            DuplicateWordException: if the selected pair is already in the
                vocabulary.
        """
        encoded_chunk = self.encode_intermediate(chunk_data)
        frequencies = memory.frequencies

        # zip against the list shifted by one walks every adjacent pair.
        for couple in zip(encoded_chunk, encoded_chunk[1:]):
            frequencies[couple] = frequencies.get(couple, 0) + 1

        result = (self, memory, encoded_chunk)

        if not last_batch or not frequencies:
            return result

        # max() keeps the first key among equals, so ties go to the pair
        # that was inserted first.
        best_couple = max(frequencies, key=frequencies.get)

        if frequencies[best_couple] < memory.merge_treshold:
            return result

        self.__learn_word(best_couple)
        return result

    def encode(self, piece: str) -> list[int]:
        """Encode text: UTF-8 bytes first, then all known merges applied."""
        return self.encode_intermediate(list(piece.encode("utf-8")))

    def encode_intermediate(self, piece: list[int]):
        """Apply merge rounds to a token list until no pair is merged.

        A round that merges nothing returns a list of the same length, which
        is the stop condition.
        """
        current_piece = piece
        merged_piece = self.__round_encode(current_piece)

        while len(merged_piece) != len(current_piece):
            current_piece = merged_piece
            merged_piece = self.__round_encode(current_piece)

        return current_piece

    def __round_encode(self, piece: list[int]):
        """Run one left-to-right pass replacing known pairs by their token.

        Pairs do not overlap: after a merge the scan resumes past both
        members of the pair.
        """
        # Fewer than two tokens: there is no pair to look at.
        if len(piece) < 2:
            return piece

        last_index = len(piece) - 1
        new_piece: list[int] = []
        index = 0

        while index < last_index:
            token = self.__vocabulary.get((piece[index], piece[index + 1]))

            if token is None:
                new_piece.append(piece[index])
                index += 1
            else:
                new_piece.append(token)
                index += 2

        # The loop only emits tokens that start a candidate pair, so the
        # final token is still pending unless the last pair was merged.
        # This must hold for both branches above: previously a merge ending
        # right before the final token caused that token to be dropped.
        if index == last_index:
            new_piece.append(piece[index])

        return new_piece

    # TODO: Remake decode to take a list of token IDs
    def decode(self, token_id: int) -> str:
        """Expand one token id into the text it stands for.

        Merged tokens are unfolded depth-first, left member first. Byte
        tokens are converted with `chr`, i.e. each byte becomes the code
        point of the same value.
        NOTE(review): `encode` produces UTF-8 bytes, so non-ASCII text does
        not come back unchanged through this method - confirm whether callers
        rely on the current mapping before switching to a UTF-8 decode.

        Raises:
            OutOfDictionaryException: if an id >= 256 is not in the
                vocabulary.
        """
        pending: list[int] = [token_id]
        decoded_chars: list[str] = []

        while pending:
            current = pending.pop()

            if current < 256:
                decoded_chars.append(chr(current))
                continue

            left_token, right_token = self.__token_decode(current)
            # Push right first so the left member is popped (emitted) first.
            pending.append(right_token)
            pending.append(left_token)

        return "".join(decoded_chars)

    def __token_decode(self, token_id: int) -> tuple[int, int]:
        """Return the pair a merged token id stands for.

        Raises:
            OutOfDictionaryException: if `token_id` was never learned.
        """
        pair = self.__reverse_vocabulary.get(token_id)

        if pair is None:
            raise OutOfDictionaryException()

        return pair

    def __learn_word(self, words: tuple[int, int]):
        """Register `words` as a new merge under the next free id.

        Raises:
            DuplicateWordException: if the pair is already in the vocabulary.
        """
        if self.__vocabulary.get(words) is not None:
            raise DuplicateWordException()

        new_id = self.__next_id
        self.__vocabulary[words] = new_id
        self.__reverse_vocabulary[new_id] = words