in NanoSocratesBPE: encode() method rewritten and tested
This commit is contained in:
parent
856bd8909c
commit
0eef2148a9
@ -4,12 +4,15 @@ from ..Errors import OutOfDictionaryException, DuplicateWordException
|
||||
|
||||
|
||||
class NanoSocratesBatchMemoryBPE:
|
||||
""" Memory for batch training. Keeps token-pair frequencies and the merge_treshold
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequencies: dict[tuple[int, int], int],
|
||||
merge_treshold: int
|
||||
) -> None:
|
||||
|
||||
self.frequencies = frequencies
|
||||
self.merge_treshold = merge_treshold
|
||||
|
||||
@ -42,7 +45,12 @@ class NanoSocratesBPE(Encoder):
|
||||
return self.__vocabulary
|
||||
|
||||
@property
|
||||
def __next_id(self):
|
||||
def __next_id(self) -> int:
|
||||
"""
|
||||
Gets the next id
|
||||
Returns:
|
||||
int:
|
||||
"""
|
||||
return self.vocabulary_size + 1
|
||||
|
||||
# TODO: implement fit
|
||||
@ -90,20 +98,26 @@ class NanoSocratesBPE(Encoder):
|
||||
|
||||
|
||||
def encode(self, piece: str) -> list[int]:
|
||||
"""Encode a string into token IDs: it is first converted to UTF-8 bytes, then the resulting list of integers is passed to encode_intermediate()
|
||||
Args:
|
||||
piece (str):
|
||||
Returns:
|
||||
list[int]:
|
||||
"""
|
||||
converted_piece = list(piece.encode("utf-8"))
|
||||
return self.encode_intermediate(converted_piece)
|
||||
|
||||
current_piece = list(piece.encode("utf-8"))
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
while len(current_piece) != len(new_piece):
|
||||
current_piece = new_piece
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
return current_piece
|
||||
|
||||
def encode_intermediate(self, piece: list[int]):
|
||||
def encode_intermediate(self, piece: list[int]) -> list[int]:
|
||||
""" Encode a piece (as a list of integers) repeatedly until its length stops changing
|
||||
Args:
|
||||
piece (list[int]): piece to encode
|
||||
Returns:
|
||||
list[int]: piece encoded
|
||||
"""
|
||||
current_piece = piece
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
|
||||
# keep encoding as long as the last round changed the length of the piece
|
||||
while len(current_piece) != len(new_piece):
|
||||
current_piece = new_piece
|
||||
new_piece = self.__round_encode(current_piece)
|
||||
@ -112,6 +126,14 @@ class NanoSocratesBPE(Encoder):
|
||||
|
||||
|
||||
def __round_encode(self, piece: list[int]):
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
piece (list[int]): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
if len(piece) == 1:
|
||||
return piece
|
||||
@ -143,6 +165,7 @@ class NanoSocratesBPE(Encoder):
|
||||
# TODO: Remake decode to take a list of token IDs
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
|
||||
# deque: double ended queue
|
||||
token_stack: deque[int] = deque(token_ids)
|
||||
UTF_8_STRING_ARR: bytearray = bytearray()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user