diff --git a/Project_Model/Libs/Transformer/Utils/inference_masking.py b/Project_Model/Libs/Transformer/Utils/inference_masking.py new file mode 100644 index 0000000..4dc4345 --- /dev/null +++ b/Project_Model/Libs/Transformer/Utils/inference_masking.py @@ -0,0 +1,13 @@ +def inference_masking(sequence: list[int], mask_token: int, max_vocabulary: int) -> list[int]: + + current_mask_token = max_vocabulary + 1 + + for i in range(0, len(sequence)): + + if sequence[i] != mask_token: + continue + + sequence[i] = current_mask_token + current_mask_token += 1 + + return sequence