Fixed decoding phase

This commit is contained in:
Christian Risi 2025-10-02 09:33:58 +02:00
parent eadba1fb82
commit 1eae8582b2

View File

@ -1,3 +1,4 @@
from collections import deque
from .Encoder import Encoder from .Encoder import Encoder
from ..Errors import OutOfDictionaryException, DuplicateWordException from ..Errors import OutOfDictionaryException, DuplicateWordException
@ -140,31 +141,30 @@ class NanoSocratesBPE(Encoder):
return NEW_PIECE return NEW_PIECE
# decode takes a list of token IDs and returns the decoded text
def decode(self, token_ids: list[int], errors: str = "strict") -> str:
    """Decode a sequence of BPE token IDs back into text.

    Token IDs below 256 are treated as raw byte values. Any other ID is
    split into its (left, right) pair through ``self.__token_decode`` and
    the two halves are expanded in turn, left half first, until only byte
    values remain. The collected bytes are then decoded as UTF-8 in one
    pass, so characters spanning several tokens are reassembled correctly.

    Args:
        token_ids: Token IDs to decode, in reading order. The list is
            copied and never mutated.
        errors: Error handler forwarded to ``bytearray.decode``. The
            default ``"strict"`` raises on malformed UTF-8; ``"replace"``
            or ``"ignore"`` tolerate sequences that stop in the middle of
            a multi-byte character.

    Returns:
        The decoded string; ``""`` when ``token_ids`` is empty.

    Raises:
        UnicodeDecodeError: With ``errors="strict"``, if the expanded
            bytes are not valid UTF-8.
        ValueError: If a negative token ID is given (it cannot be stored
            as a byte).

    NOTE(review): whatever ``self.__token_decode`` raises for an unknown
    ID propagates unchanged -- confirm against its implementation.
    """
    # Work queue: the front always holds the next token in reading order.
    pending: deque[int] = deque(token_ids)
    raw_bytes = bytearray()

    while pending:
        token = pending.popleft()

        if token < 256:
            # Base vocabulary entry: the ID is the byte value itself.
            raw_bytes.append(token)
            continue

        left_token, right_token = self.__token_decode(token)
        # Push right first so that left ends up in front and is expanded
        # next, preserving the original byte order.
        pending.appendleft(right_token)
        pending.appendleft(left_token)

    return raw_bytes.decode("utf-8", errors=errors)
def __token_decode(self, token_id: int) -> tuple[int, int]: def __token_decode(self, token_id: int) -> tuple[int, int]: