in NanoSocratesBPE: encode() method rewritten and tested
This commit is contained in:
parent
856bd8909c
commit
0eef2148a9
@ -4,12 +4,15 @@ from ..Errors import OutOfDictionaryException, DuplicateWordException
|
|||||||
|
|
||||||
|
|
||||||
class NanoSocratesBatchMemoryBPE:
|
class NanoSocratesBatchMemoryBPE:
|
||||||
|
""" Memory to batch training. Keeps token couple frequencies, and merge_treshold
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
frequencies: dict[tuple[int, int], int],
|
frequencies: dict[tuple[int, int], int],
|
||||||
merge_treshold: int
|
merge_treshold: int
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.frequencies = frequencies
|
self.frequencies = frequencies
|
||||||
self.merge_treshold = merge_treshold
|
self.merge_treshold = merge_treshold
|
||||||
|
|
||||||
@ -42,7 +45,12 @@ class NanoSocratesBPE(Encoder):
|
|||||||
return self.__vocabulary
|
return self.__vocabulary
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __next_id(self):
|
def __next_id(self) -> int:
|
||||||
|
"""
|
||||||
|
Gets the next it
|
||||||
|
Returns:
|
||||||
|
int:
|
||||||
|
"""
|
||||||
return self.vocabulary_size + 1
|
return self.vocabulary_size + 1
|
||||||
|
|
||||||
# TODO: implement fit
|
# TODO: implement fit
|
||||||
@ -90,20 +98,26 @@ class NanoSocratesBPE(Encoder):
|
|||||||
|
|
||||||
|
|
||||||
def encode(self, piece: str) -> list[int]:
|
def encode(self, piece: str) -> list[int]:
|
||||||
|
"""Encode a String into token IDs, it firt convert it into utf-8, then pass the list of integer to encode_intermediate()
|
||||||
|
Args:
|
||||||
|
piece (str):
|
||||||
|
Returns:
|
||||||
|
list[int]:
|
||||||
|
"""
|
||||||
|
converted_piece = list(piece.encode("utf-8"))
|
||||||
|
return self.encode_intermediate(converted_piece)
|
||||||
|
|
||||||
current_piece = list(piece.encode("utf-8"))
|
def encode_intermediate(self, piece: list[int]) -> list[int]:
|
||||||
new_piece = self.__round_encode(current_piece)
|
""" Encode a piece (as list of integer) till its maximum
|
||||||
|
Args:
|
||||||
while len(current_piece) != len(new_piece):
|
piece (list[int]): piece to encode
|
||||||
current_piece = new_piece
|
Returns:
|
||||||
new_piece = self.__round_encode(current_piece)
|
list[int]: piece encoded
|
||||||
|
"""
|
||||||
return current_piece
|
|
||||||
|
|
||||||
def encode_intermediate(self, piece: list[int]):
|
|
||||||
current_piece = piece
|
current_piece = piece
|
||||||
new_piece = self.__round_encode(current_piece)
|
new_piece = self.__round_encode(current_piece)
|
||||||
|
|
||||||
|
# until current_piece is bigger then new_piece, keep encoding
|
||||||
while len(current_piece) != len(new_piece):
|
while len(current_piece) != len(new_piece):
|
||||||
current_piece = new_piece
|
current_piece = new_piece
|
||||||
new_piece = self.__round_encode(current_piece)
|
new_piece = self.__round_encode(current_piece)
|
||||||
@ -112,6 +126,14 @@ class NanoSocratesBPE(Encoder):
|
|||||||
|
|
||||||
|
|
||||||
def __round_encode(self, piece: list[int]):
|
def __round_encode(self, piece: list[int]):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
piece (list[int]): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
|
||||||
if len(piece) == 1:
|
if len(piece) == 1:
|
||||||
return piece
|
return piece
|
||||||
@ -143,6 +165,7 @@ class NanoSocratesBPE(Encoder):
|
|||||||
# TODO: Remake decode to take a list of token IDs
|
# TODO: Remake decode to take a list of token IDs
|
||||||
def decode(self, token_ids: list[int]) -> str:
|
def decode(self, token_ids: list[int]) -> str:
|
||||||
|
|
||||||
|
# deque: double ended queue
|
||||||
token_stack: deque[int] = deque(token_ids)
|
token_stack: deque[int] = deque(token_ids)
|
||||||
UTF_8_STRING_ARR: bytearray = bytearray()
|
UTF_8_STRING_ARR: bytearray = bytearray()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user