Fix of bugs and semantics

This commit is contained in:
Christian Risi
2025-10-03 13:26:58 +02:00
parent 6b9cb7cd35
commit c5c0c61f79
5 changed files with 134 additions and 129 deletions

View File

@@ -2,20 +2,18 @@ from collections import deque
from .Encoder import Encoder
from ..Errors import OutOfDictionaryException, DuplicateWordException
# ABOUT THE DICTIONARY:
# the string is converted into utf-char bytes, that is: each char is represented with a set of bytes from 1 to 4.
# each byte gets cast into an integer; such that, if an integer has a value lower than 256,
# then it is representing a utf-char-byte, otherwise it is a token-ID.
class NanoSocratesBatchMemoryBPE:
""" Memory to batch training. Keeps token couple frequencies, and merge_treshold
"""
"""Memory to batch training. Keeps token couple frequencies, and merge_treshold"""
def __init__(
self,
frequencies: dict[tuple[int, int], int],
merge_treshold: int
self, frequencies: dict[tuple[int, int], int], merge_treshold: int
) -> None:
self.frequencies = frequencies
self.merge_treshold = merge_treshold
@@ -39,7 +37,6 @@ class NanoSocratesBPE(Encoder):
self.__vocabulary[key] = value
self.__reverse_vocabulary[value] = key
@property
def vocabulary_size(self):
return len(self.__vocabulary) + 256
@@ -62,7 +59,7 @@ class NanoSocratesBPE(Encoder):
self,
chunk_data: list[int],
memory: NanoSocratesBatchMemoryBPE,
last_batch: bool
last_batch: bool,
):
ENCODED_CHUNK = self.encode_intermediate(chunk_data)
@@ -70,7 +67,7 @@ class NanoSocratesBPE(Encoder):
# update frequency of each couple of element
for i in range(0, DATA_LEN_BEFORE_LAST):
CANDIDATE_COUPLE = (ENCODED_CHUNK[i], ENCODED_CHUNK[i+1])
CANDIDATE_COUPLE = (ENCODED_CHUNK[i], ENCODED_CHUNK[i + 1])
frequency = memory.frequencies.get(CANDIDATE_COUPLE)
@@ -82,7 +79,6 @@ class NanoSocratesBPE(Encoder):
frequency += 1
memory.frequencies[CANDIDATE_COUPLE] = frequency
if not last_batch:
return (self, memory, ENCODED_CHUNK)
@@ -100,9 +96,6 @@ class NanoSocratesBPE(Encoder):
return (self, memory, ENCODED_CHUNK)
def encode(self, piece: str) -> list[int]:
"""Encode a String into token IDs, it firt convert it into utf-8, then pass the list of integer to encode_intermediate()
Args:
@@ -114,12 +107,12 @@ class NanoSocratesBPE(Encoder):
return self.encode_intermediate(converted_piece)
def encode_intermediate(self, piece: list[int]) -> list[int]:
""" Encode a piece (as list of integer) till its maximum
"""Encode a piece (as list of integer) till its maximum
Args:
piece (list[int]): piece to encode
Returns:
list[int]: piece encoded
"""
list[int]: piece encoded
"""
current_piece = piece
new_piece = self.__round_encode(current_piece)
@@ -130,9 +123,8 @@ class NanoSocratesBPE(Encoder):
return current_piece
def __round_encode(self, piece: list[int]):
""" A single round of encode that traverse all the object. Multiple round are needed for a full encode: \n
"""A single round of encode that traverse all the object. Multiple round are needed for a full encode: \n
1) "ABAB" -> "XX"
2) "XX" -> "Y"
Args:
@@ -146,22 +138,25 @@ class NanoSocratesBPE(Encoder):
return piece
PIECE_LENGTH = len(piece) - 1
NEW_PIECE : list[int]= []
NEW_PIECE: list[int] = []
index = 0
while index < PIECE_LENGTH:
CANDIDATE_WORD = (piece[index], piece[index + 1]) # take a tuple of consecutive element [int]
CANDIDATE_WORD = (
piece[index],
piece[index + 1],
) # take a tuple of consecutive element [int]
CANDIDATE_TOKEN = self.__vocabulary.get(CANDIDATE_WORD)
# if no token to substitute the tuple, append the first element
if CANDIDATE_TOKEN is None:
NEW_PIECE.append(piece[index])
NEW_PIECE.append(piece[index])
index += 1
# if the latter element of the tuple is the last element of the piece, append it
if index == PIECE_LENGTH:
NEW_PIECE.append(piece[index])
NEW_PIECE.append(piece[index])
continue
@@ -169,13 +164,10 @@ class NanoSocratesBPE(Encoder):
NEW_PIECE.append(CANDIDATE_TOKEN)
index += 2
return NEW_PIECE
# TODO: Remake decode to take a list of token IDs
def decode(self, token_ids: list[int]) -> str:
# deque: double ended queue
token_stack: deque[int] = deque(token_ids)
@@ -185,19 +177,13 @@ class NanoSocratesBPE(Encoder):
TOKEN_ID = token_stack.popleft()
if TOKEN_ID < 256:
UTF_8_STRING_ARR.append(
TOKEN_ID
)
UTF_8_STRING_ARR.append(TOKEN_ID)
continue
left_token, right_token = self.__token_decode(TOKEN_ID)
token_stack.appendleft(
right_token
)
token_stack.appendleft(
left_token
)
token_stack.appendleft(right_token)
token_stack.appendleft(left_token)
return UTF_8_STRING_ARR.decode("utf-8")
@@ -211,7 +197,7 @@ class NanoSocratesBPE(Encoder):
return CANDIDATE_DECODED
def __learn_word(self, words: tuple[int, int]):
""" learn a new couple of object in the vocabulary
"""learn a new couple of object in the vocabulary
Args:
words (tuple[int, int]): the Pair of element to substitute with a new tokenID