Fixed decoding phase
This commit is contained in:
parent
eadba1fb82
commit
1eae8582b2
@ -1,3 +1,4 @@
|
|||||||
|
from collections import deque
|
||||||
from .Encoder import Encoder
|
from .Encoder import Encoder
|
||||||
from ..Errors import OutOfDictionaryException, DuplicateWordException
|
from ..Errors import OutOfDictionaryException, DuplicateWordException
|
||||||
|
|
||||||
@ -140,31 +141,30 @@ class NanoSocratesBPE(Encoder):
|
|||||||
return NEW_PIECE
|
return NEW_PIECE
|
||||||
|
|
||||||
# TODO: Remake decode to take a list of token IDs
|
# Decodes a list of token IDs back into a UTF-8 string
|
||||||
def decode(self, token_id: int) -> str:
|
def decode(self, token_ids: list[int]) -> str:
|
||||||
|
|
||||||
token_stack: list[int] = [token_id]
|
token_stack: deque[int] = deque(token_ids)
|
||||||
DECODED_STRING_ARR: list[str] = []
|
UTF_8_STRING_ARR: bytearray = bytearray()
|
||||||
|
|
||||||
while len(token_stack) > 0:
|
while len(token_stack) > 0:
|
||||||
TOKEN_ID = token_stack.pop()
|
TOKEN_ID = token_stack.popleft()
|
||||||
|
|
||||||
if TOKEN_ID < 256:
|
if TOKEN_ID < 256:
|
||||||
DECODED_CHAR = chr(TOKEN_ID)
|
UTF_8_STRING_ARR.append(
|
||||||
DECODED_STRING_ARR.append(
|
TOKEN_ID
|
||||||
DECODED_CHAR
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
left_token, right_token = self.__token_decode(TOKEN_ID)
|
left_token, right_token = self.__token_decode(TOKEN_ID)
|
||||||
|
|
||||||
token_stack.append(
|
token_stack.appendleft(
|
||||||
right_token
|
right_token
|
||||||
)
|
)
|
||||||
token_stack.append(
|
token_stack.appendleft(
|
||||||
left_token
|
left_token
|
||||||
)
|
)
|
||||||
|
|
||||||
return "".join(DECODED_STRING_ARR)
|
return UTF_8_STRING_ARR.decode("utf-8")
|
||||||
|
|
||||||
def __token_decode(self, token_id: int) -> tuple[int, int]:
|
def __token_decode(self, token_id: int) -> tuple[int, int]:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user