in NanoSocratesBPE: encode() method rewritten and tested

This commit is contained in:
GassiGiuseppe 2025-10-02 12:12:44 +02:00
parent 856bd8909c
commit 0eef2148a9

View File

@ -4,12 +4,15 @@ from ..Errors import OutOfDictionaryException, DuplicateWordException
class NanoSocratesBatchMemoryBPE: class NanoSocratesBatchMemoryBPE:
""" Memory for batch training. Keeps token-pair frequencies and merge_treshold
"""
def __init__( def __init__(
self, self,
frequencies: dict[tuple[int, int], int], frequencies: dict[tuple[int, int], int],
merge_treshold: int merge_treshold: int
) -> None: ) -> None:
self.frequencies = frequencies self.frequencies = frequencies
self.merge_treshold = merge_treshold self.merge_treshold = merge_treshold
@ -42,7 +45,12 @@ class NanoSocratesBPE(Encoder):
return self.__vocabulary return self.__vocabulary
@property @property
def __next_id(self): def __next_id(self) -> int:
"""
Gets the next id
Returns:
int:
"""
return self.vocabulary_size + 1 return self.vocabulary_size + 1
# TODO: implement fit # TODO: implement fit
@ -90,20 +98,26 @@ class NanoSocratesBPE(Encoder):
def encode(self, piece: str) -> list[int]: def encode(self, piece: str) -> list[int]:
"""Encode a string into token IDs: it first converts it to UTF-8 bytes, then passes the list of integers to encode_intermediate()
Args:
piece (str):
Returns:
list[int]:
"""
converted_piece = list(piece.encode("utf-8"))
return self.encode_intermediate(converted_piece)
current_piece = list(piece.encode("utf-8")) def encode_intermediate(self, piece: list[int]) -> list[int]:
new_piece = self.__round_encode(current_piece) """ Encode a piece (as list of integer) till its maximum
Args:
while len(current_piece) != len(new_piece): piece (list[int]): piece to encode
current_piece = new_piece Returns:
new_piece = self.__round_encode(current_piece) list[int]: piece encoded
"""
return current_piece
def encode_intermediate(self, piece: list[int]):
current_piece = piece current_piece = piece
new_piece = self.__round_encode(current_piece) new_piece = self.__round_encode(current_piece)
# keep encoding while the lengths of current_piece and new_piece differ
while len(current_piece) != len(new_piece): while len(current_piece) != len(new_piece):
current_piece = new_piece current_piece = new_piece
new_piece = self.__round_encode(current_piece) new_piece = self.__round_encode(current_piece)
@ -112,6 +126,14 @@ class NanoSocratesBPE(Encoder):
def __round_encode(self, piece: list[int]): def __round_encode(self, piece: list[int]):
"""_summary_
Args:
piece (list[int]): _description_
Returns:
_type_: _description_
"""
if len(piece) == 1: if len(piece) == 1:
return piece return piece
@ -143,6 +165,7 @@ class NanoSocratesBPE(Encoder):
# TODO: Remake decode to take a list of token IDs # TODO: Remake decode to take a list of token IDs
def decode(self, token_ids: list[int]) -> str: def decode(self, token_ids: list[int]) -> str:
# deque: double ended queue
token_stack: deque[int] = deque(token_ids) token_stack: deque[int] = deque(token_ids)
UTF_8_STRING_ARR: bytearray = bytearray() UTF_8_STRING_ARR: bytearray = bytearray()