from collections import deque

from .Encoder import Encoder

from ..Errors import OutOfDictionaryException, DuplicateWordException


# ABOUT THE DICTIONARY:
#
# The string is converted into UTF-8 bytes, that is: each char is represented
# by a sequence of 1 to 4 bytes.
#
# Each byte is cast to an integer, such that if an integer's value is lower
# than 256 it represents a raw UTF-8 byte; otherwise it is a token ID.
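#
# A worked example of the scheme (illustrative IDs, assuming the pair
# (104, 105) was the first merge learned and therefore got ID 256):
#
#   "hi"  --utf-8-->  [104, 105]  --merge (104, 105) -> 256-->  [256]
#
# Decoding walks the same mapping backwards: 256 -> (104, 105) -> "hi".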


class NanoSocratesBatchMemoryBPE:
    """Memory for batch training. Keeps token-pair frequencies and the merge
    threshold (merge_treshold)."""

    def __init__(
        self, frequencies: dict[tuple[int, int], int], merge_treshold: int
    ) -> None:
        self.frequencies = frequencies
        self.merge_treshold = merge_treshold


class NanoSocratesBPE(Encoder):
    """Byte-pair encoder over UTF-8 bytes; token IDs below 256 are raw bytes."""

    def __init__(self, vocabulary: dict[tuple[int, int], int] | None = None) -> None:
        super().__init__()

        self.__vocabulary: dict[tuple[int, int], int] = {}
        self.__reverse_vocabulary: dict[int, tuple[int, int]] = {}

        if vocabulary is None:
            return

        for key, value in vocabulary.items():
            if value < 256:
                # values under 256 are reserved for raw UTF-8 bytes
                raise OutOfDictionaryException()
            # TODO: check if they are in order
            self.__vocabulary[key] = value
            self.__reverse_vocabulary[value] = key

    @property
    def vocabulary_size(self):
        return len(self.__vocabulary) + 256

    @property
    def vocabulary(self):
        return self.__vocabulary

    @property
    def __next_id(self) -> int:
        """
        Gets the next free token ID.

        Returns:
            int: the ID to assign to the next learned pair
        """
        return self.vocabulary_size

    # TODO: implement fit
    def fit(
        self,
        chunk_data: list[int],
        memory: NanoSocratesBatchMemoryBPE,
        last_batch: bool,
    ) -> tuple["NanoSocratesBPE", NanoSocratesBatchMemoryBPE, list[int]]:
        """Accumulate pair frequencies from a chunk; on the last batch, merge
        the most frequent pair if it reaches the merge threshold."""

        ENCODED_CHUNK = self.encode_intermediate(chunk_data)
        DATA_LEN_BEFORE_LAST = len(ENCODED_CHUNK) - 1

        # update the frequency of each adjacent pair of elements
        for i in range(0, DATA_LEN_BEFORE_LAST):
            CANDIDATE_COUPLE = (ENCODED_CHUNK[i], ENCODED_CHUNK[i + 1])

            # initialize the frequency on first sight, then increment it
            frequency = memory.frequencies.get(CANDIDATE_COUPLE, 0)
            memory.frequencies[CANDIDATE_COUPLE] = frequency + 1

        if not last_batch:
            return (self, memory, ENCODED_CHUNK)

        if len(memory.frequencies) < 1:
            return (self, memory, ENCODED_CHUNK)

        FREQUENCIES = memory.frequencies
        MAX_COUPLE = max(FREQUENCIES.items(), key=lambda item: item[1])[0]
        FREQUENCY = FREQUENCIES[MAX_COUPLE]

        if FREQUENCY < memory.merge_treshold:
            return (self, memory, ENCODED_CHUNK)

        self.__learn_word(MAX_COUPLE)

        return (self, memory, ENCODED_CHUNK)
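
    # A minimal sketch of how a training loop might drive fit(), assuming the
    # corpus is already split into byte chunks (`bpe` and `chunks` are
    # hypothetical names, not part of this module):
    #
    #   memory = NanoSocratesBatchMemoryBPE(frequencies={}, merge_treshold=2)
    #   for i, chunk in enumerate(chunks):
    #       _, memory, _ = bpe.fit(chunk, memory, last_batch=(i == len(chunks) - 1))
    #
    # Pair counts accumulate across chunks; only on the last batch is the most
    # frequent pair compared against merge_treshold and possibly merged.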

    def encode(self, piece: str) -> list[int]:
        """Encode a string into token IDs: first convert it into UTF-8 bytes,
        then pass the list of integers to encode_intermediate().

        Args:
            piece (str): the string to encode

        Returns:
            list[int]: the encoded token IDs
        """
        converted_piece = list(piece.encode("utf-8"))
        return self.encode_intermediate(converted_piece)

    def encode_intermediate(self, piece: list[int]) -> list[int]:
        """Encode a piece (as a list of integers) until no more merges apply.

        Args:
            piece (list[int]): piece to encode

        Returns:
            list[int]: the encoded piece
        """
        current_piece = piece
        new_piece = self.__round_encode(current_piece)

        # as long as the last round shrank the piece, keep encoding
        while len(current_piece) != len(new_piece):
            current_piece = new_piece
            new_piece = self.__round_encode(current_piece)

        return current_piece
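
    # An illustrative trace (hypothetical vocabulary with (65, 66) -> 256 and
    # (256, 256) -> 257), showing why several rounds may be needed:
    #
    #   round 1: [65, 66, 65, 66] -> [256, 256]
    #   round 2: [256, 256]       -> [257]
    #   round 3: [257] comes back unchanged, so the loop stops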

    def __round_encode(self, piece: list[int]):
        """A single encoding round that traverses the whole piece. Multiple
        rounds are needed for a full encode:

        1) "ABAB" -> "XX"
        2) "XX" -> "Y"

        Args:
            piece (list[int]): the piece to encode, as a list of integers

        Returns:
            list[int]: the piece after one encoding round
        """

        if len(piece) == 1:
            return piece

        PIECE_LENGTH = len(piece) - 1  # index of the last element
        NEW_PIECE: list[int] = []

        index = 0
        while index < PIECE_LENGTH:

            # take a pair of consecutive elements
            CANDIDATE_WORD = (
                piece[index],
                piece[index + 1],
            )
            CANDIDATE_TOKEN = self.__vocabulary.get(CANDIDATE_WORD)

            # if there is no token for this pair, keep the first element
            if CANDIDATE_TOKEN is None:
                NEW_PIECE.append(piece[index])
                index += 1

                # if the second element of the pair is the last element of
                # the piece, append it too
                if index == PIECE_LENGTH:
                    NEW_PIECE.append(piece[index])

                continue

            # a token exists for this pair: substitute the pair with it
            NEW_PIECE.append(CANDIDATE_TOKEN)
            index += 2

            if index == PIECE_LENGTH:
                NEW_PIECE.append(piece[index])

        return NEW_PIECE
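
    # Edge-case sketch: with only the hypothetical merge (65, 66) -> 256 in
    # the vocabulary, [65, 66, 67] becomes [256, 67] in one round; the
    # trailing 67 has no partner left, so the last-element checks append it
    # unchanged.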

    def decode(self, token_ids: list[int]) -> str:

        # deque: double-ended queue, used here as a stack of pending tokens
        token_stack: deque[int] = deque(token_ids)
        UTF_8_STRING_ARR: bytearray = bytearray()

        while len(token_stack) > 0:
            TOKEN_ID = token_stack.popleft()

            # IDs below 256 are raw UTF-8 bytes: emit them directly
            if TOKEN_ID < 256:
                UTF_8_STRING_ARR.append(TOKEN_ID)
                continue

            # otherwise expand the token into its pair and push the pair back
            left_token, right_token = self.__token_decode(TOKEN_ID)

            token_stack.appendleft(right_token)
            token_stack.appendleft(left_token)

        return UTF_8_STRING_ARR.decode("utf-8", errors="ignore")
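
    # An illustrative expansion (same hypothetical vocabulary as above, with
    # (65, 66) -> 256 and (256, 256) -> 257):
    #
    #   stack: [257] -> [256, 256] -> [65, 66, 256] -> ... -> bytes "ABAB"
    #
    # Each composite ID is popped, split into its pair, and the pair is
    # pushed back onto the front of the deque until only raw bytes remain.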

    def __token_decode(self, token_id: int) -> tuple[int, int]:

        CANDIDATE_DECODED = self.__reverse_vocabulary.get(token_id)

        if CANDIDATE_DECODED is None:
            raise OutOfDictionaryException()

        return CANDIDATE_DECODED

    def __learn_word(self, words: tuple[int, int]):
        """Learn a new pair of elements into the vocabulary.

        Args:
            words (tuple[int, int]): the pair of elements to substitute with
                a new token ID

        Raises:
            DuplicateWordException: raised if the pair is already in the
                dictionary
        """
        ID = self.__next_id

        DUPLICATE = self.__vocabulary.get(words)

        if DUPLICATE is not None:
            raise DuplicateWordException()

        self.__vocabulary[words] = ID
        self.__reverse_vocabulary[ID] = words
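

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes the package context so the relative
    # imports resolve, e.g. run via `python -m <package>.<module>`). The
    # vocabulary below is hypothetical: it merges the UTF-8 pair for "hi"
    # into token ID 256.
    bpe = NanoSocratesBPE({(104, 105): 256})

    encoded = bpe.encode("hi")
    print(encoded)              # expected: [256]
    print(bpe.decode(encoded))  # expected: "hi"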