Added a util to make masked inference
This commit is contained in:
parent
9c1043e0ba
commit
0007c38212
13
Project_Model/Libs/Transformer/Utils/inference_masking.py
Normal file
13
Project_Model/Libs/Transformer/Utils/inference_masking.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
def inference_masking(sequence: list[int], mask_token: int, max_vocabulary: int) -> list[int]:
|
||||||
|
|
||||||
|
current_mask_token = max_vocabulary + 1
|
||||||
|
|
||||||
|
for i in range(0, len(sequence)):
|
||||||
|
|
||||||
|
if sequence[i] != mask_token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sequence[i] = current_mask_token
|
||||||
|
current_mask_token += 1
|
||||||
|
|
||||||
|
return sequence
|
||||||
Loading…
x
Reference in New Issue
Block a user