diff --git a/Project_Model/Libs/Transformer/Classes/SpannedMasker.py b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py index 156f512..441a3d8 100644 --- a/Project_Model/Libs/Transformer/Classes/SpannedMasker.py +++ b/Project_Model/Libs/Transformer/Classes/SpannedMasker.py @@ -90,6 +90,11 @@ class SpannedMasker: SPAN_LENGTH = min(CANDIDATE_SPAN, REMAINING_MASK) for _ in range(0, SPAN_LENGTH): + INNER_TOKEN = sequence[mask_index] + + if self.__is_illegal_token(INNER_TOKEN, forbidden_tokens): + continue + MASK[mask_index] = True mask_index += 1