107 lines
2.6 KiB
Python
107 lines
2.6 KiB
Python
from .Encoder import Encoder
|
|
from ..Errors import OutOfDictionaryException
|
|
|
|
|
|
class NanoSocratesBatchMemoryBPE:
    """Placeholder class with no state or behaviour yet.

    NOTE(review): the name suggests a batched / memory-backed BPE variant;
    nothing in this class implements that yet - TODO confirm intended design.
    """

    def __init__(self) -> None:
        """Create an instance; no attributes are initialised."""
|
|
|
|
|
|
class NanoSocratesBPE(Encoder):
    """Pair-merging encoder over single-character ids.

    Text is first mapped to one id per character with ``ord``. Ids below
    ``_FIRST_MERGED_ID`` are decoded back with ``chr``. ``vocabulary`` maps
    a pair of adjacent ids to the id of the merged token, and merged ids must
    be ``>= _FIRST_MERGED_ID``.

    NOTE(review): ``ord`` returns values >= 256 for characters outside
    Latin-1, and those overlap the merged-token id range. Encoding such text
    yields ids that ``decode`` will misinterpret or reject. Confirm whether
    input is guaranteed to be Latin-1 (or switch to UTF-8 bytes).
    """

    # Ids below this value are single characters (``ord``/``chr``).
    # Merged tokens must use ids at or above it.
    _FIRST_MERGED_ID: int = 256

    def __init__(
        self, vocabulary: dict[tuple[int, int], int] | None = None
    ) -> None:
        """Build the encoder from an optional merge vocabulary.

        Args:
            vocabulary: mapping ``(left_id, right_id) -> merged_id``. ``None``
                creates an encoder that performs no merges.

        Raises:
            OutOfDictionaryException: if any merged id is below 256, the range
                used for single characters.
        """
        super().__init__()

        # pair -> merged id, consulted while encoding
        self.__vocabulary: dict[tuple[int, int], int] = {}
        # merged id -> pair, consulted while decoding
        self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}

        if vocabulary is None:
            return

        for pair, merged_id in vocabulary.items():
            if merged_id < self._FIRST_MERGED_ID:
                raise OutOfDictionaryException()
            self.__vocabulary[pair] = merged_id
            self.__reverse_vocabulary[merged_id] = pair

    def fit(self) -> None:
        """Learn a merge vocabulary from data (not implemented yet).

        Raises:
            NotImplementedError: always; training is still a TODO.
        """
        # Fail loudly instead of silently leaving the vocabulary untouched.
        raise NotImplementedError("NanoSocratesBPE.fit is not implemented yet")

    def encode(self, piece: str) -> list[int]:
        """Encode ``piece`` into a list of token ids.

        Each character becomes its ``ord`` value. Merge passes are then
        repeated until a pass performs no merge (the length stops shrinking).

        Args:
            piece: text to encode; ``""`` yields ``[]``.

        Returns:
            The fully merged list of token ids.
        """
        tokens = list(map(ord, piece))

        while True:
            merged = self.__round_encode(tokens)
            # A pass that merges nothing returns a list of the same length.
            if len(merged) == len(tokens):
                return tokens
            tokens = merged

    def __round_encode(self, piece: list[int]) -> list[int]:
        """Run one left-to-right merge pass over ``piece``.

        Each adjacent pair found in the vocabulary is replaced by its merged
        id. The pass then moves past both elements, so merges within a pass
        never overlap. Every element that is not merged, including one left
        over after a merge near the end, is copied through unchanged.

        NOTE(review): pairs are merged greedily in reading order, not by
        merge priority (e.g. lowest merged id first). Confirm this matches
        how the vocabulary is meant to be applied.
        """
        merged: list[int] = []
        last_index = len(piece) - 1
        index = 0

        while index <= last_index:
            if index < last_index:
                merged_id = self.__vocabulary.get((piece[index], piece[index + 1]))
                if merged_id is not None:
                    merged.append(merged_id)
                    index += 2
                    continue
            # No mergeable pair starts here (unknown pair, or last element).
            merged.append(piece[index])
            index += 1

        return merged

    def decode(self, token_id: int) -> str:
        """Expand ``token_id`` back into the text it stands for.

        Merged ids are split into their ``(left, right)`` pair repeatedly,
        using an explicit stack rather than recursion, until only ids below
        ``_FIRST_MERGED_ID`` remain. Those ids are converted with ``chr``.

        Raises:
            OutOfDictionaryException: if a merged id is not in the vocabulary.
        """
        pending: list[int] = [token_id]
        characters: list[str] = []

        while pending:
            current = pending.pop()

            if current < self._FIRST_MERGED_ID:
                characters.append(chr(current))
                continue

            left, right = self.__token_decode(current)
            # Push right first so the left half is popped and emitted first.
            pending.append(right)
            pending.append(left)

        return "".join(characters)

    def __token_decode(self, token_id: int) -> tuple[int, int]:
        """Return the ``(left, right)`` pair that ``token_id`` was merged from.

        Raises:
            OutOfDictionaryException: if ``token_id`` is not a known merged id.
        """
        pair = self.__reverse_vocabulary.get(token_id)
        if pair is None:
            raise OutOfDictionaryException()
        return pair
|