V0.0.1 Athene
This commit is contained in:
@@ -10,7 +10,7 @@ class SpannedMasker:
|
||||
max_vocabulary: int,
|
||||
forbidden_tokens: set[int],
|
||||
change_token_probability: float = 0.15,
|
||||
average_span: int = 1,
|
||||
average_span: int = 2,
|
||||
seed: int = random.randint(0, sys.maxsize),
|
||||
|
||||
) -> None:
|
||||
@@ -103,7 +103,7 @@ class SpannedMasker:
|
||||
|
||||
if self.__is_illegal_token(INNER_TOKEN, forbidden_tokens):
|
||||
continue
|
||||
|
||||
|
||||
MASK[mask_index] = True
|
||||
mask_index += 1
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from .post_tokenization import truncate_sequence, pad_sequence, normalize_sequen
|
||||
from .inference_masking import inference_masking
|
||||
from .truncate_rdf_list import truncate_rdf_list
|
||||
from .decode_out import tensor2token
|
||||
from .decoder_input import get_decoder_input
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -17,4 +18,5 @@ __all__ = [
|
||||
"inference_masking",
|
||||
"truncate_rdf_list",
|
||||
"tensor2token",
|
||||
"get_decoder_input"
|
||||
]
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from Project_Model.Libs.Transformer import normalize_sequence
|
||||
from ..Utils import normalize_sequence
|
||||
# from Project_Model.Libs.Embedder import NanoSocratesEmbedder as Embedder
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user