
from pathlib import Path
from ..Classes import NanoSocratesSplitter
from ..Classes import NanoSocratesBPE
from ..Classes import NanoSocratesSpecial
from ..Utils import special_regex_maker
from ..Enums import TokenType
from ..Enums import SpecialToken


class TokeNanoCore:
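    """
    Tokenizer front end that combines a byte-pair (BPE) encoder with a
    special-token encoder, using a splitter to route each piece of text
    to the matching encoder.
    """
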
def __init__(
self,
bpe_vocabulary: dict[tuple[int, int], int],
special_token_list: list[str],
# special_vocabulary: dict[str, int]
):
self.__bpe_encoder = NanoSocratesBPE(bpe_vocabulary)
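
        # Build a regex that matches the special tokens; the splitter uses it
        # to separate special-token pieces from ordinary text.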
SPECIAL_REGEX = special_regex_maker(special_token_list)
BPE_VOCABULARY_SIZE = self.__bpe_encoder.vocabulary_size
self.__splitter = NanoSocratesSplitter(SPECIAL_REGEX, BPE_VOCABULARY_SIZE)
self.__special_encoder = NanoSocratesSpecial(
BPE_VOCABULARY_SIZE, special_token_list
)
@property
    def vocabulary_size(self) -> int:
BPE_VOC_SIZE = self.__bpe_encoder.vocabulary_size
SPECIAL_VOC_SIZE = self.__special_encoder.vocabulary_size
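
        # One extra id is reserved on top of the two vocabularies, hence the +1.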
return BPE_VOC_SIZE + SPECIAL_VOC_SIZE + 1
def encode(self, corpus: str) -> list[int]:
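        """Encode a corpus into a flat list of token ids."""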
output: list[int] = []
for piece, token_type in self.__splitter.split_text(corpus):
if token_type == TokenType.SPECIAL:
output.extend(self.__special_encoder.encode(piece))
# slow but clear
if token_type == TokenType.BPE:
output.extend(self.__bpe_encoder.encode(piece))
return output
def encode_incomplete_string(self, corpus: str) -> list[int]:
"""
Encode string which don't end with a special token
"""
corpus = corpus + SpecialToken.CORPUS_END.value
output: list[int] = []
for piece, token_type in self.__splitter.split_text(corpus):
if token_type == TokenType.SPECIAL:
output.extend(self.__special_encoder.encode(piece))
# slow but clear
if token_type == TokenType.BPE:
output.extend(self.__bpe_encoder.encode(piece))
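
        # Drop the trailing id produced by the appended CORPUS_END sentinel.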
return output[:-1]
def decode(self, corpus: list[int]) -> str:
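        """Decode a list of token ids back into text."""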
output_str = ""
for token, token_type in self.__splitter.split_tokens(corpus):
            # token is a single integer if special, a list of integers otherwise
if token_type == TokenType.SPECIAL:
output_str += self.__special_encoder.decode(
token
                )  # it accepts a single integer
# slow but clear
if token_type == TokenType.BPE:
output_str += self.__bpe_encoder.decode(
token
                )  # it accepts a list of integers
return output_str
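

# Usage sketch (illustrative only): assumes a trained BPE merge table in the
# dict[tuple[int, int], int] format expected by NanoSocratesBPE and the
# project's special tokens. The merge values below are made-up placeholders,
# and because of the relative imports this module must be run inside its
# package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    demo_merges = {(104, 101): 256, (256, 108): 257}  # hypothetical byte-pair merges
    demo_specials = [SpecialToken.CORPUS_END.value]

    tokenizer = TokeNanoCore(demo_merges, demo_specials)

    sample = "hello" + SpecialToken.CORPUS_END.value
    ids = tokenizer.encode(sample)
    print(ids)
    print(tokenizer.decode(ids))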