Added fit method
This commit is contained in:
parent
7020c9e683
commit
c9032cab09
@@ -1,11 +1,16 @@
|
||||
from .Encoder import Encoder
|
||||
from ..Errors import OutOfDictionaryException
|
||||
from ..Errors import OutOfDictionaryException, DuplicateWordException
|
||||
|
||||
|
||||
class NanoSocratesBatchMemoryBPE:
    """State handed to, and returned from, ``NanoSocratesBPE.fit``.

    ``fit`` updates ``frequencies`` in place with the adjacent-pair counts of
    each chunk and compares the best count against ``merge_treshold``.
    """

    def __init__(
        self,
        frequencies: dict[tuple[int, int], int],
        merge_treshold: int,
    ) -> None:
        """Store the pair counters and the merge threshold.

        Args:
            frequencies: Maps a pair of adjacent token ids to how many times
                it has been seen. The dict is stored by reference, not
                copied.
            merge_treshold: Minimum count the most frequent pair must reach
                for ``fit`` to learn it as a new word. The attribute name
                (including its spelling) is kept because callers use it.
        """
        # NOTE: an earlier zero-argument ``__init__`` used to precede this
        # one; it was dead code because this definition replaced it.
        self.frequencies = frequencies
        self.merge_treshold = merge_treshold
|
||||
|
||||
|
||||
class NanoSocratesBPE(Encoder):
|
||||
@@ -22,12 +27,66 @@ class NanoSocratesBPE(Encoder):
|
||||
for key, value in vocabulary.items():
|
||||
if value < 256:
|
||||
raise OutOfDictionaryException()
|
||||
# TODO: check if they are in order
|
||||
self.__vocabulary[key] = value
|
||||
self.__reverse_vocabulary[value] = key
|
||||
|
||||
|
||||
@property
|
||||
def vocabulary_size(self):
|
||||
return len(self.__vocabulary) + 255
|
||||
|
||||
@property
|
||||
def vocabulary(self):
|
||||
return self.__vocabulary
|
||||
|
||||
@property
|
||||
def __next_id(self):
|
||||
return self.vocabulary_size + 1
|
||||
|
||||
def fit(
    self,
    chunk_data: list[int],
    memory: NanoSocratesBatchMemoryBPE,
    last_batch: bool,
):
    """Count adjacent pairs in one chunk and, on the last batch, learn one.

    The chunk is first passed through ``self.__round_encode`` (defined
    elsewhere in this class; presumably it applies the merges learned so
    far -- confirm). Every pair of neighbouring ids in the result is then
    added to ``memory.frequencies``, which is mutated in place.

    Args:
        chunk_data: Token ids of the chunk to process.
        memory: Running pair counts and merge threshold.
        last_batch: When false, only counting is done. When true, the most
            frequent pair is learned as a new word, provided its count is
            at least ``memory.merge_treshold``.

    Returns:
        The tuple ``(self, memory, encoded_chunk)``, where ``encoded_chunk``
        is the value returned by ``self.__round_encode``.

    Raises:
        DuplicateWordException: Propagated from ``self.__learn_word`` when
            the selected pair is already in the vocabulary.
    """
    encoded_chunk = self.__round_encode(chunk_data)
    frequencies = memory.frequencies

    # Sliding window of width two over the encoded chunk; a chunk with
    # fewer than two ids yields no pairs.
    for couple in zip(encoded_chunk, encoded_chunk[1:]):
        frequencies[couple] = frequencies.get(couple, 0) + 1

    result = (self, memory, encoded_chunk)

    # Nothing is learned until the caller flags the last batch, and there
    # is nothing to learn when no pair has been counted.
    if not last_batch or not frequencies:
        return result

    # On ties the pair inserted first into the dict wins.
    best_couple = max(frequencies, key=frequencies.get)
    if frequencies[best_couple] >= memory.merge_treshold:
        self.__learn_word(best_couple)

    return result
|
||||
|
||||
|
||||
|
||||
|
||||
def encode(self, piece: str) -> list[int]:
|
||||
|
||||
@@ -104,3 +163,15 @@ class NanoSocratesBPE(Encoder):
|
||||
raise OutOfDictionaryException()
|
||||
|
||||
return CANDIDATE_DECODED
|
||||
|
||||
def __learn_word(self, words: tuple[int, int]):
|
||||
|
||||
ID = self.__next_id
|
||||
|
||||
DUPLICATE = self.__vocabulary.get(words)
|
||||
|
||||
if DUPLICATE is not None:
|
||||
raise DuplicateWordException()
|
||||
|
||||
self.__vocabulary[words] = ID
|
||||
self.__reverse_vocabulary[ID] = words
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user