Added a util to make masked inference
This commit is contained in:
parent
9c1043e0ba
commit
0007c38212
13
Project_Model/Libs/Transformer/Utils/inference_masking.py
Normal file
13
Project_Model/Libs/Transformer/Utils/inference_masking.py
Normal file
@ -0,0 +1,13 @@
|
||||
def inference_masking(sequence: list[int], mask_token: int, max_vocabulary: int) -> list[int]:
|
||||
|
||||
current_mask_token = max_vocabulary + 1
|
||||
|
||||
for i in range(0, len(sequence)):
|
||||
|
||||
if sequence[i] != mask_token:
|
||||
continue
|
||||
|
||||
sequence[i] = current_mask_token
|
||||
current_mask_token += 1
|
||||
|
||||
return sequence
|
||||
Loading…
x
Reference in New Issue
Block a user