Added fit method

This commit is contained in:
Christian Risi 2025-09-30 13:33:28 +02:00
parent 7020c9e683
commit c9032cab09

View File

@ -1,11 +1,16 @@
from .Encoder import Encoder from .Encoder import Encoder
from ..Errors import OutOfDictionaryException from ..Errors import OutOfDictionaryException, DuplicateWordException
class NanoSocratesBatchMemoryBPE:
    """State shared between successive ``fit`` calls of the BPE trainer.

    Attributes are plain and public so the trainer can read and update
    them in place.
    """

    def __init__(
        self, frequencies: dict[tuple[int, int], int], merge_treshold: int
    ) -> None:
        """Store the running pair counts and the merge threshold.

        Args:
            frequencies: occurrence count for each pair of adjacent token
                ids; kept by reference, not copied.
            merge_treshold: count that the most frequent pair is compared
                against before it is learned as a new word.
        """
        # NOTE: the "treshold" spelling is part of the public interface
        # (keyword argument and attribute name), so it is kept as-is.
        self.merge_treshold = merge_treshold
        self.frequencies = frequencies
class NanoSocratesBPE(Encoder): class NanoSocratesBPE(Encoder):
@ -22,12 +27,66 @@ class NanoSocratesBPE(Encoder):
for key, value in vocabulary.items(): for key, value in vocabulary.items():
if value < 256: if value < 256:
raise OutOfDictionaryException() raise OutOfDictionaryException()
# TODO: check if they are in order
self.__vocabulary[key] = value self.__vocabulary[key] = value
self.__reverse_vocabulary[value] = key self.__reverse_vocabulary[value] = key
@property
def vocabulary_size(self):
return len(self.__vocabulary) + 255
@property
def vocabulary(self):
return self.__vocabulary
@property
def __next_id(self):
return self.vocabulary_size + 1
# TODO: implement fit # TODO: implement fit
def fit(
    self,
    chunk_data: list[int],
    memory: NanoSocratesBatchMemoryBPE,
    last_batch: bool
):
    """Accumulate pair counts for one chunk; on the last batch learn a word.

    The chunk is first passed through ``__round_encode``. Every pair of
    adjacent ids in the result is then counted into ``memory.frequencies``,
    which is updated in place.

    On the last batch, the most frequent pair seen so far is learned as a
    new word, provided its count is at least ``memory.merge_treshold``.
    Ties go to the pair that was inserted first.

    Args:
        chunk_data: token ids of the current chunk.
        memory: running state shared across batches; mutated in place.
        last_batch: whether this chunk closes the current round.

    Returns:
        The tuple ``(self, memory, encoded_chunk)``.
    """
    encoded = self.__round_encode(chunk_data)
    pair_counts = memory.frequencies

    # Pair each id with its successor; zip stops at the shorter sequence,
    # so chunks with fewer than two ids contribute nothing.
    for pair in zip(encoded, encoded[1:]):
        pair_counts[pair] = pair_counts.get(pair, 0) + 1

    outcome = (self, memory, encoded)

    # Only the last batch may trigger a merge, and only if anything was seen.
    if not last_batch or not pair_counts:
        return outcome

    best_pair = max(pair_counts, key=pair_counts.get)
    if pair_counts[best_pair] >= memory.merge_treshold:
        self.__learn_word(best_pair)

    return outcome
def encode(self, piece: str) -> list[int]: def encode(self, piece: str) -> list[int]:
@ -104,3 +163,15 @@ class NanoSocratesBPE(Encoder):
raise OutOfDictionaryException() raise OutOfDictionaryException()
return CANDIDATE_DECODED return CANDIDATE_DECODED
def __learn_word(self, words: tuple[int, int]):
ID = self.__next_id
DUPLICATE = self.__vocabulary.get(words)
if DUPLICATE is not None:
raise DuplicateWordException()
self.__vocabulary[words] = ID
self.__reverse_vocabulary[ID] = words