from .Encoder import Encoder
from ..Errors import OutOfDictionaryException


class NanoSocratesBatchMemoryBPE:
    """Placeholder for a batch/memory-oriented BPE variant.

    NOTE(review): this class currently has no behavior; its intended design
    is not visible from this file.
    """

    def __init__(self) -> None:
        pass


class NanoSocratesBPE(Encoder):
    """Pair-merge encoder driven by a fixed merge vocabulary.

    Ids below 256 are treated as single characters (``decode`` turns them
    back into text with ``chr``). Every other id is a merge token that
    expands to a ``(left_id, right_id)`` pair.
    """

    def __init__(
        self, vocabulary: dict[tuple[int, int], int] | None = None
    ) -> None:
        """Build the encoder from an optional merge vocabulary.

        Args:
            vocabulary: maps a ``(left_id, right_id)`` pair to the id of the
                token produced by merging that pair. ``None`` builds an
                encoder with no merges.

        Raises:
            OutOfDictionaryException: if any merge token id is below 256,
                i.e. it would overlap the single-character id range.
        """
        super().__init__()
        self.__vocabulary: dict[tuple[int, int], int] = {}
        self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}

        if vocabulary is None:
            return

        for pair, token_id in vocabulary.items():
            if token_id < 256:
                raise OutOfDictionaryException()
            self.__vocabulary[pair] = token_id
            self.__reverse_vocabulary[token_id] = pair

    # TODO: implement fit
    def fit(self) -> None:
        """Learn a merge vocabulary from data (not implemented yet).

        Raises:
            NotImplementedError: always, until training is implemented.
        """
        raise NotImplementedError("NanoSocratesBPE.fit is not implemented yet")

    def encode(self, piece: str) -> list[int]:
        """Encode ``piece`` into a list of token ids.

        Each character is first mapped to its code point with ``ord``. Merge
        passes (see ``__round_encode``) are then repeated until a pass no
        longer shortens the sequence.

        NOTE(review): characters with ``ord(c) >= 256`` produce ids inside
        the merge-token range. Such ids can be matched as parts of merge
        pairs. ``decode`` then either raises for them or, if one happens to
        equal a real merge id, expands it into the wrong text. Confirm
        whether inputs are guaranteed to stay below U+0100.
        """
        current_piece = [ord(char) for char in piece]
        new_piece = self.__round_encode(current_piece)
        # A pass that merges nothing keeps the same length: fixed point reached.
        while len(current_piece) != len(new_piece):
            current_piece = new_piece
            new_piece = self.__round_encode(current_piece)
        return current_piece

    def __round_encode(self, piece: list[int]) -> list[int]:
        """Run one greedy left-to-right merge pass over ``piece``.

        Adjacent pairs are scanned from the left. A pair found in the
        vocabulary is replaced by its merge token and both elements are
        consumed; otherwise the left element is kept unchanged.

        NOTE(review): merges are applied by position, not by merge rank or
        learning order as in canonical BPE. Confirm this matches how the
        vocabulary is meant to be applied.
        """
        if len(piece) < 2:
            return piece

        last_index = len(piece) - 1
        merged: list[int] = []
        index = 0
        while index < last_index:
            token = self.__vocabulary.get((piece[index], piece[index + 1]))
            if token is None:
                merged.append(piece[index])
                index += 1
            else:
                merged.append(token)
                index += 2

        # If no merge consumed the final element, it still has to be emitted.
        # (Previously it was dropped whenever the second-to-last pair merged.)
        if index == last_index:
            merged.append(piece[last_index])
        return merged

    def decode(self, token_id: int) -> str:
        """Expand a single token id back into its text.

        Merge tokens are split into their ``(left, right)`` pairs until only
        ids below 256 remain; those are converted with ``chr``.

        Raises:
            OutOfDictionaryException: if an id >= 256 is not a known merge token.
        """
        token_stack: list[int] = [token_id]
        decoded_chars: list[str] = []

        while token_stack:
            current_id = token_stack.pop()

            if current_id < 256:
                decoded_chars.append(chr(current_id))
                continue

            left_token, right_token = self.__token_decode(current_id)
            # Push right first so the left half is popped (and emitted) first.
            token_stack.append(right_token)
            token_stack.append(left_token)

        return "".join(decoded_chars)

    def __token_decode(self, token_id: int) -> tuple[int, int]:
        """Return the ``(left, right)`` pair that merge token ``token_id`` stands for.

        Raises:
            OutOfDictionaryException: if ``token_id`` is not in the vocabulary.
        """
        candidate_decoded = self.__reverse_vocabulary.get(token_id)

        if candidate_decoded is None:
            raise OutOfDictionaryException()

        return candidate_decoded