Added BPE
TODO: - complete the fit method
This commit is contained in:
parent
b46df4f91a
commit
e433941405
106
Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py
Normal file
106
Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py
Normal file
@ -0,0 +1,106 @@
|
||||
from .Encoder import Encoder
|
||||
from ..Errors import OutOfDictionaryException
|
||||
|
||||
|
||||
class NanoSocratesBatchMemoryBPE:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class NanoSocratesBPE(Encoder):
|
||||
|
||||
def __init__(self, vocabulary: dict[tuple[int, int], int] | None = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.__vocabulary: dict[tuple[int, int], int] = {}
|
||||
self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}
|
||||
|
||||
if vocabulary is None:
|
||||
return
|
||||
|
||||
for key, value in vocabulary.items():
|
||||
if value < 256:
|
||||
raise OutOfDictionaryException()
|
||||
self.__vocabulary[key] = value
|
||||
self.__reverse_vocabulary[value] = key
|
||||
|
||||
# TODO: implement fit
|
||||
def fit():
|
||||
pass
|
||||
|
||||
def encode(self, piece: str) -> list[int]:
|
||||
|
||||
current_piece = list(map(ord, piece))
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
while len(current_piece) != len(new_piece):
|
||||
current_piece = new_piece
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
return current_piece
|
||||
|
||||
def __round_encode(self, piece: list[int]):
|
||||
|
||||
if len(piece) == 1:
|
||||
return piece
|
||||
|
||||
PIECE_LENGTH = len(piece) - 1
|
||||
NEW_PIECE = []
|
||||
|
||||
index = 0
|
||||
while index < PIECE_LENGTH:
|
||||
|
||||
CANDIDATE_WORD = (piece[index], piece[index + 1])
|
||||
CANDIDATE_TOKEN = self.__vocabulary.get(CANDIDATE_WORD)
|
||||
|
||||
if CANDIDATE_TOKEN is None:
|
||||
NEW_PIECE.append(piece[index])
|
||||
index += 1
|
||||
|
||||
if index == PIECE_LENGTH:
|
||||
NEW_PIECE.append(piece[index])
|
||||
|
||||
continue
|
||||
|
||||
NEW_PIECE.append(CANDIDATE_TOKEN)
|
||||
index += 2
|
||||
|
||||
|
||||
return NEW_PIECE
|
||||
|
||||
# TODO: decode
|
||||
def decode(self, token_id: int) -> str:
|
||||
|
||||
token_stack: list[int] = [token_id]
|
||||
DECODED_STRING_ARR: list[str] = []
|
||||
|
||||
while len(token_stack) > 0:
|
||||
TOKEN_ID = token_stack.pop()
|
||||
|
||||
if TOKEN_ID < 256:
|
||||
DECODED_CHAR = chr(TOKEN_ID)
|
||||
DECODED_STRING_ARR.append(
|
||||
DECODED_CHAR
|
||||
)
|
||||
continue
|
||||
|
||||
left_token, right_token = self.__token_decode(TOKEN_ID)
|
||||
|
||||
token_stack.append(
|
||||
right_token
|
||||
)
|
||||
token_stack.append(
|
||||
left_token
|
||||
)
|
||||
|
||||
return "".join(DECODED_STRING_ARR)
|
||||
|
||||
def __token_decode(self, token_id: int) -> tuple[int, int]:
|
||||
|
||||
CANDIDATE_DECODED = self.__reverse_vocabulary.get(token_id)
|
||||
|
||||
if CANDIDATE_DECODED is None:
|
||||
raise OutOfDictionaryException()
|
||||
|
||||
return CANDIDATE_DECODED
|
||||
6
Project_Model/Libs/BPE/Enums/TokenType.py
Normal file
6
Project_Model/Libs/BPE/Enums/TokenType.py
Normal file
@ -0,0 +1,6 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
class TokenType(Enum):
|
||||
|
||||
SPECIAL = auto()
|
||||
BPE = auto()
|
||||
52
Project_Model/Tests/bpe_test.py
Normal file
52
Project_Model/Tests/bpe_test.py
Normal file
@ -0,0 +1,52 @@
|
||||
from Project_Model.Libs.BPE.Enums import TokenType
|
||||
import Project_Model.Libs.BPE as BPE
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class TestBPE:
|
||||
|
||||
def test_bpe_encoding_simple(self):
|
||||
|
||||
TEXT = "abababab"
|
||||
|
||||
# ab = 256
|
||||
# 256, 256 = 257
|
||||
# 257, 257 = 258
|
||||
|
||||
VOCABULARY = {(ord("a"), ord("b")): 256, (256, 256): 257, (257, 257): 258}
|
||||
EXPECTED = [258]
|
||||
|
||||
BPE_ENCODER = BPE.NanoSocratesBPE(VOCABULARY)
|
||||
|
||||
ENCODED = BPE_ENCODER.encode(TEXT)
|
||||
|
||||
assert len(ENCODED) == len(EXPECTED)
|
||||
|
||||
for encoded, expected in zip(ENCODED, EXPECTED):
|
||||
assert encoded == expected
|
||||
|
||||
def test_bpe_decoding_simple(self):
|
||||
|
||||
|
||||
INPUT = 258
|
||||
|
||||
# ab = 256
|
||||
# 256, 256 = 257
|
||||
# 257, 257 = 258
|
||||
|
||||
VOCABULARY = {(ord("a"), ord("b")): 256, (256, 256): 257, (257, 257): 258}
|
||||
EXPECTED = "abababab"
|
||||
|
||||
BPE_ENCODER = BPE.NanoSocratesBPE(VOCABULARY)
|
||||
|
||||
DECODED = BPE_ENCODER.decode(INPUT)
|
||||
|
||||
assert len(DECODED) == len(EXPECTED)
|
||||
|
||||
for encoded, expected in zip(DECODED, EXPECTED):
|
||||
assert encoded == expected
|
||||
|
||||
# Useful to debug weird cases
|
||||
if __name__ == "__main__":
|
||||
TestBPE().test_bpe_decoding_simple()
|
||||
Loading…
x
Reference in New Issue
Block a user