Fixed decoding phase
This commit is contained in:
parent
eadba1fb82
commit
1eae8582b2
@ -1,3 +1,4 @@
|
||||
from collections import deque
|
||||
from .Encoder import Encoder
|
||||
from ..Errors import OutOfDictionaryException, DuplicateWordException
|
||||
|
||||
@ -140,31 +141,30 @@ class NanoSocratesBPE(Encoder):
|
||||
return NEW_PIECE
|
||||
|
||||
# TODO: Remake decode to take a list of token IDs
|
||||
def decode(self, token_id: int) -> str:
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
|
||||
token_stack: list[int] = [token_id]
|
||||
DECODED_STRING_ARR: list[str] = []
|
||||
token_stack: deque[int] = deque(token_ids)
|
||||
UTF_8_STRING_ARR: bytearray = bytearray()
|
||||
|
||||
while len(token_stack) > 0:
|
||||
TOKEN_ID = token_stack.pop()
|
||||
TOKEN_ID = token_stack.popleft()
|
||||
|
||||
if TOKEN_ID < 256:
|
||||
DECODED_CHAR = chr(TOKEN_ID)
|
||||
DECODED_STRING_ARR.append(
|
||||
DECODED_CHAR
|
||||
UTF_8_STRING_ARR.append(
|
||||
TOKEN_ID
|
||||
)
|
||||
continue
|
||||
|
||||
left_token, right_token = self.__token_decode(TOKEN_ID)
|
||||
|
||||
token_stack.append(
|
||||
token_stack.appendleft(
|
||||
right_token
|
||||
)
|
||||
token_stack.append(
|
||||
token_stack.appendleft(
|
||||
left_token
|
||||
)
|
||||
|
||||
return "".join(DECODED_STRING_ARR)
|
||||
return UTF_8_STRING_ARR.decode("utf-8")
|
||||
|
||||
def __token_decode(self, token_id: int) -> tuple[int, int]:
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user