diff --git a/Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py b/Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py index 4d44884..6428cb7 100644 --- a/Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py +++ b/Project_Model/Libs/BPE/Classes/NanoSocratesBPE.py @@ -1,3 +1,4 @@ +from collections import deque from .Encoder import Encoder from ..Errors import OutOfDictionaryException, DuplicateWordException @@ -140,31 +141,30 @@ class NanoSocratesBPE(Encoder): return NEW_PIECE # TODO: Remake decode to take a list of token IDs - def decode(self, token_id: int) -> str: + def decode(self, token_ids: list[int]) -> str: - token_stack: list[int] = [token_id] - DECODED_STRING_ARR: list[str] = [] + token_stack: deque[int] = deque(token_ids) + UTF_8_STRING_ARR: bytearray = bytearray() while len(token_stack) > 0: - TOKEN_ID = token_stack.pop() + TOKEN_ID = token_stack.popleft() if TOKEN_ID < 256: - DECODED_CHAR = chr(TOKEN_ID) - DECODED_STRING_ARR.append( - DECODED_CHAR + UTF_8_STRING_ARR.append( + TOKEN_ID ) continue left_token, right_token = self.__token_decode(TOKEN_ID) - token_stack.append( + token_stack.appendleft( right_token ) - token_stack.append( + token_stack.appendleft( left_token ) - return "".join(DECODED_STRING_ARR) + return UTF_8_STRING_ARR.decode("utf-8") def __token_decode(self, token_id: int) -> tuple[int, int]: