Fixed decoding phase

This commit is contained in:
Christian Risi 2025-10-02 09:33:58 +02:00
parent eadba1fb82
commit 1eae8582b2

View File

@@ -1,3 +1,4 @@
from collections import deque
from .Encoder import Encoder
from ..Errors import OutOfDictionaryException, DuplicateWordException
@@ -140,31 +141,30 @@ class NanoSocratesBPE(Encoder):
return NEW_PIECE
def decode(self, token_ids: list[int]) -> str:
    """Decode a sequence of token IDs back into text.

    Token IDs below 256 are treated as raw byte values. Any other ID is
    split into a ``(left, right)`` pair through ``self.__token_decode``
    and the two halves are expanded in turn, left first, until only byte
    values remain. The collected bytes are decoded as UTF-8 in one step
    at the end, so multi-byte characters spread over several tokens are
    reassembled correctly.

    Args:
        token_ids: Token IDs in reading order. An empty list yields "".

    Returns:
        The decoded string.

    Raises:
        UnicodeDecodeError: If the expanded byte sequence is not valid
            UTF-8.
        Anything raised by ``self.__token_decode`` for an ID it cannot
        split (presumably an out-of-dictionary error; confirm against
        that helper).
    """
    # Work queue consumed from the left so the output keeps input order.
    pending: deque[int] = deque(token_ids)
    utf_8_bytes: bytearray = bytearray()

    while pending:
        token_id = pending.popleft()

        if token_id < 256:
            # Raw byte: emit as-is; decoding to text happens only once,
            # after all bytes are known.
            utf_8_bytes.append(token_id)
            continue

        left_token, right_token = self.__token_decode(token_id)

        # Push right first, then left, so that left is expanded next and
        # right follows it, ahead of the tokens still waiting in the queue.
        pending.appendleft(right_token)
        pending.appendleft(left_token)

    return utf_8_bytes.decode("utf-8")
def __token_decode(self, token_id: int) -> tuple[int, int]: