Merge branch 'dev.train' of https://repositories.communitynotfound.work/PoliBa-DeepLearning/NanoSocrates into dev.train
This commit is contained in:
commit 86a063591e
BIN Assets/Model/curated/NanoSocrates.zip (Stored with Git LFS) Normal file
Binary file not shown.
BIN Assets/Model/curated/dec_optim(5).zip (Stored with Git LFS) Normal file
Binary file not shown.
BIN Assets/Model/curated/enc_optim(5).zip (Stored with Git LFS) Normal file
Binary file not shown.
BIN Assets/Model/curated/last_epoch(5).txt (Stored with Git LFS) Normal file
Binary file not shown.
BIN Assets/Model/curated/nano_optim(5).zip (Stored with Git LFS) Normal file
Binary file not shown.
BIN Assets/Model/curated/settings.txt (Stored with Git LFS) Normal file
Binary file not shown.
256 Playgrounds/evaluation.py Normal file
@@ -0,0 +1,256 @@
import torch
from pathlib import Path
import Project_Model.Libs.BPE as BPE
import Project_Model.Libs.Transformer as Transformer
import Project_Model.Libs.TransformerUtils as TUtils
import Project_Model.Libs.TorchShims as torch_shims
import Project_Model.Libs.Batch as Batch

# set a default device
DEVICE = torch_shims.get_default_device()
torch.set_default_device(DEVICE)

# Get paths
# MODEL_DIR = "Assets/Model/curated"
MODEL_DIR = "Assets/Dataset/Tmp"
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
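# NOTE: the reassignment above wins, so evaluation actually runs on the toy rdf_text.csv split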
MODEL_PATH = Path(f"{MODEL_DIR}/NanoSocrates.zip")


# BPE Init
SPECIAL_VOC = BPE.default_special_tokens()
VOCABULARY = BPE.load_nanos_vocabulary(VOCABULARY_PATH)
TOKENANO = BPE.TokeNanoCore(VOCABULARY, SPECIAL_VOC)


# Constants
MASK_EXTRA_SPACE = 100
REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
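# the extra 100 ids above the real vocabulary appear to serve as span-mask slots;
# decoding folds any id in that range back to <MASK> further down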
EMBEDDED_SIZE = 256
FEED_FORWARD_MULTIPLIER = 4
ATTENTION_HEADS = 4
SENTENCE_LENGTH = 256
NUMBER_OF_BLOCKS = 2

SOS_TOKEN = TOKENANO.encode("<SOS>")[0]
PAD_TOKEN = TOKENANO.encode("<PAD>")[0]
END_TOKEN = TOKENANO.encode("<EOS>")[0]
SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
REL_TOKEN = TOKENANO.encode("<PRED>")[0]
OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]
MASK_TOKEN = TOKENANO.encode("<MASK>")[0]

SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
ALLOWED_TOKENS = set([SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN])
FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS
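# every special token except the structural RDF markers (<SUBJ>/<PRED>/<OBJ>) is off-limits for the masker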


# Spanned_Masker
MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)

TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
VALIDATION_BATCHER = Batch.Batcher(
    VALIDATION_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER
)
TEST_BATCHER = Batch.Batcher(TEST_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER, debug=True)


# Model
NANOSOCRATES_TRAIN = Transformer.TrainingModel(
    TOKEN_SPACE_SIZE,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

NANOSOCRATES = Transformer.NanoSocratesCore(
    TOKEN_SPACE_SIZE,
    SENTENCE_LENGTH,
    SOS_TOKEN,
    PAD_TOKEN,
    END_TOKEN,
    EMBEDDED_SIZE,
    FEED_FORWARD_MULTIPLIER,
    ATTENTION_HEADS,
    NUMBER_OF_BLOCKS,
)

if MODEL_PATH.is_file():
    nanosocrates_dict = torch.load(MODEL_PATH, weights_only=True, map_location=DEVICE)
    NANOSOCRATES_TRAIN.load_state_dict(nanosocrates_dict)

_, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(
    NANOSOCRATES, TOKEN_SPACE_SIZE, EMBEDDED_SIZE
)

NANOSOCRATES = TUtils.train2inference(
    NANOSOCRATES_TRAIN,
    NANOSOCRATES
)

NANOSOCRATES.eval()
ENCODER_ONLY.eval()
DECODER_ONLY.eval()
NANOSOCRATES_TRAIN.eval()

task_1_metrics = []
task_2_metrics = []
task_3_metrics = []
task_4_metrics = []

example_num = 0
with torch.no_grad():
    for example in TEST_BATCHER.batch(1):

        print(f"DOING Example: {example_num}")

        src_x, tgt_y, pad_x, pad_y, tasktype = example
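        # each example carries source ids, target ids, their padding masks, and the task tag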
        enc_x = torch.tensor(src_x)

        ACTUAL_BATCH_SIZE, _ = enc_x.shape
        enc_x_pad = torch.tensor(pad_x, dtype=torch.bool)
        tgt = torch.tensor(tgt_y)
        tgt_pad = torch.tensor(pad_y, dtype=torch.bool)

        dec_x = Transformer.get_decoder_input(
            ACTUAL_BATCH_SIZE, SOS_TOKEN, PAD_TOKEN, SENTENCE_LENGTH
        )
        dec_x[:, 1:] = tgt[:, :-1]
        dec_x_pad = dec_x.eq(PAD_TOKEN)
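        # dec_x above is the standard teacher-forcing input: <SOS> followed by the target shifted right by one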

        out: torch.Tensor = NANOSOCRATES.inference((enc_x, enc_x_pad), tasktype)

        tokens: list[int] = out.tolist()[0]
        tokens.append(END_TOKEN)
        tokens = list(map(lambda x: MASK_TOKEN if x > TOKENANO.vocabulary_size else x, tokens))
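        # the map above folds mask-space ids back to <MASK>, so decode() never sees an out-of-vocabulary id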
        out_string = TOKENANO.decode(tokens)

        exp_tokens: list[int] = tgt_y[0]
        exp_tokens = list(map(lambda x: MASK_TOKEN if x > TOKENANO.vocabulary_size else x, exp_tokens))
        exp_string = TOKENANO.decode(exp_tokens)

        enc_tokens: list[int] = src_x[0]
        enc_tokens = list(map(lambda x: MASK_TOKEN if x > TOKENANO.vocabulary_size else x, enc_tokens))
        enc_string = TOKENANO.decode(enc_tokens)

        print(f"PROMPT:\n{enc_string}")
        print(f"EXPECTED:\n{exp_string}")
        print(f"ACTUAL:\n{out_string}")

        if tasktype == Batch.TaskType.RDF2TXT:
            example_num += 1
            ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
            pred = TUtils.remove_padding(tokens, PAD_TOKEN, END_TOKEN)
            ref_str = TOKENANO.decode(ref)
            pred_str = TOKENANO.decode(pred)

            bleu, rouge, meteor = TUtils.rdf2txt([ref_str], [pred_str])

            task_1_metrics.append(
                [
                    bleu["bleu"], rouge["rougeL"], meteor["meteor"]  # type: ignore
                ]
            )

        if tasktype == Batch.TaskType.TEXT2RDF:
            ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
            pred = TUtils.remove_padding(tokens[1:], PAD_TOKEN, END_TOKEN)
            ref, pred = TUtils.balance_paddings(ref, pred, PAD_TOKEN)

            precision, recall = TUtils.txt2rdf(ref, pred)

            task_2_metrics.append(
                [
                    precision["precision"], recall["recall"]  # type: ignore
                ]
            )

        if tasktype == Batch.TaskType.MASKING:
            ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
            pred = TUtils.remove_padding(tokens, PAD_TOKEN, END_TOKEN)
            ref, pred = TUtils.balance_paddings(ref, pred, PAD_TOKEN)

            accuracy = TUtils.accuracy(ref, pred)

            task_3_metrics.append(
                accuracy["accuracy"]  # type: ignore
            )

        if tasktype == Batch.TaskType.COMPLETATION:

            ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
            pred = TUtils.remove_padding(tokens, PAD_TOKEN, END_TOKEN)
            ref, pred = TUtils.balance_paddings(ref, pred, PAD_TOKEN)

            precision, recall = TUtils.txt2rdf(ref, pred)

            task_4_metrics.append(
                [
                    precision["precision"], recall["recall"]  # type: ignore
                ]
            )

bleus = [row[0] for row in task_1_metrics]
rouges = [row[1] for row in task_1_metrics]
meteors = [row[2] for row in task_1_metrics]

prec_1 = [row[0] for row in task_2_metrics]
rec_1 = [row[1] for row in task_2_metrics]

acc = task_3_metrics

prec_2 = [row[0] for row in task_4_metrics]
rec_2 = [row[1] for row in task_4_metrics]

BLEU = TUtils.average(bleus)
ROUGE = TUtils.average(rouges)
METEOR = TUtils.average(meteors)

PREC_1 = TUtils.average(prec_1)
REC_1 = TUtils.average(rec_1)
F1_1 = TUtils.f1(PREC_1, REC_1)

ACC = TUtils.average(acc)

PREC_2 = TUtils.average(prec_2)
REC_2 = TUtils.average(rec_2)
F1_2 = TUtils.f1(PREC_2, REC_2)
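# f1 is the harmonic mean 2*P*R/(P+R); e.g. P=0.5, R=0.25 gives 2*0.125/0.75 = 0.333...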

SEPARATOR = "**************************************************************************"
OUTPUT = "".join([
    f"{SEPARATOR}\n",
    "*\tRDF2TXT:\n",
    f"*\t\tBLEU: {BLEU} - ROUGE: {ROUGE} - METEOR: {METEOR}\n"
    f"{SEPARATOR}\n",
    "*\tTXT2RDF:\n",
    f"*\t\tPRECISION: {PREC_1} - RECALL: {REC_1} - F1: {F1_1}\n"
    f"{SEPARATOR}\n",
    "*\tRDF Completion 1:\n",
    f"*\t\tACCURACY: {ACC}\n"
    f"{SEPARATOR}\n",
    "*\tRDF Completion 2:\n",
    f"*\t\tPRECISION: {PREC_2} - RECALL: {REC_2} - F1: {F1_2}\n"
    f"{SEPARATOR}\n",
    ""
])

print(OUTPUT)


print("\nDEBUG")
print(task_1_metrics)
print(task_2_metrics)
print(task_3_metrics)
print(task_4_metrics)
@@ -24,9 +24,9 @@ torch.set_default_device(DEVICE)
 # Get paths
 CHECKPOINT_DIR = "Assets/Dataset/Tmp"
 VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
-TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
+TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
-VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
+VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
-TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
+TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
 CHECKPOINT_PATH = Path(f"{CHECKPOINT_DIR}/NanoSocrates.zip")

 NANO_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/nano_optim.zip")
@@ -51,19 +51,20 @@ REAL_TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size
 TOKEN_SPACE_SIZE = TOKENANO.vocabulary_size + MASK_EXTRA_SPACE
 EMBEDDED_SIZE = 256
 FEED_FORWARD_MULTIPLIER = 4
-ATTENTION_HEADS = 8
+ATTENTION_HEADS = 4
 SENTENCE_LENGTH = 256
-NUMBER_OF_BLOCKS = 4
+NUMBER_OF_BLOCKS = 2
-MAX_EPOCHS = int(3e3)
+MAX_EPOCHS = int(300)
-PRETRAIN_EPOCHS = int(300)
+PRETRAIN_EPOCHS = int(20)
-WARMUP_EPOCHS = int(1e3)
+WARMUP_EPOCHS = int(30)
-MINI_BATCH_SIZE = 80
+MINI_BATCH_SIZE = 20
-VALIDATION_STEPS = 5
+VALIDATION_STEPS = 10
-CHECKPOINT_STEPS = VALIDATION_STEPS * 4
+CHECKPOINT_STEPS = VALIDATION_STEPS
 PATIENCE = 4
 CURRENT_EPOCH = -1 if not LAST_EPOCH_PATH.is_file() else int(LAST_EPOCH_PATH.read_text())
-VERBOSE = True
+VERBOSE = False
-LEARNING_RATE = 1.5
+LEARNING_RATE = 0.05
+LABEL_SMOOTHING = 0.01

 SOS_TOKEN = TOKENANO.encode("<SOS>")[0]

@@ -72,6 +73,7 @@ END_TOKEN = TOKENANO.encode("<END>")[0]
 SUBJ_TOKEN = TOKENANO.encode("<SUBJ>")[0]
 REL_TOKEN = TOKENANO.encode("<PRED>")[0]
 OBJ_TOKEN = TOKENANO.encode("<OBJ>")[0]
+MASK_TOKEN = TOKENANO.encode("<MASK>")[0]

 SPECIAL_TOKENS: set[int] = set(TOKENANO.encode("".join(BPE.default_special_tokens())))
 ALLOWED_TOKENS = set([SUBJ_TOKEN, REL_TOKEN, OBJ_TOKEN])
@@ -79,7 +81,7 @@ FORBIDDEN_TOKENS = SPECIAL_TOKENS - ALLOWED_TOKENS


 # Spanned_Masker
-MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS)
+MASKER = Transformer.SpannedMasker(REAL_TOKEN_SPACE_SIZE, FORBIDDEN_TOKENS, average_span=4)

 TRAIN_BATCHER = Batch.Batcher(TRAIN_DATASET_PATH, SENTENCE_LENGTH, TOKENANO, MASKER)
 VALIDATION_BATCHER = Batch.Batcher(
@@ -107,9 +109,9 @@ _, ENCODER_ONLY, DECODER_ONLY = TUtils.decompose_nano_socrates(


 # Training constants
-nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+nano_cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
-encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+encoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
-decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
+decoder_ce = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
 nano_optim = torch.optim.AdamW(NANOSOCRATES.parameters(), LEARNING_RATE)
 encoder_only_optim = torch.optim.AdamW(ENCODER_ONLY.parameters(), LEARNING_RATE)
 decoder_only_optim = torch.optim.AdamW(DECODER_ONLY.parameters(), LEARNING_RATE)
@@ -179,6 +181,24 @@ while current_epoch < MAX_EPOCHS:
         dec_x[:, 1:] = tgt[:, :-1]
         dec_x_pad = dec_x.eq(PAD_TOKEN)

+        if VERBOSE:
+            for s in TUtils.decode_batch(enc_x, TOKENANO, MASK_TOKEN):
+                print("Input")
+                print(s)
+
+            for s in TUtils.decode_batch(enc_x_pad, TOKENANO, MASK_TOKEN):
+                print("Encoder Padding mask")
+                print(s)
+
+            for s in TUtils.decode_batch(tgt, TOKENANO, MASK_TOKEN):
+                print("Desired Output")
+                print(s)
+            a_dx = dec_x[:, :]
+            a_dx[:, -1] = END_TOKEN
+            for s in TUtils.decode_batch(a_dx, TOKENANO, MASK_TOKEN):
+                print("Decoder Input")
+                print(s)
+
         if VERBOSE:
             print(f"\tBATCH {batch_counter} Starting")

@@ -346,24 +366,45 @@ while current_epoch < MAX_EPOCHS:
                 average_loss_validation["txt"] = txt_avg_loss
             else:
                 patience += 1
+                if VERBOSE:
+                    print(f"losing a patience, current irritation: {patience}")
         else:

             counter = 0

             if txt_avg_loss > average_loss_validation["txt"]:
-                counter += 1
-
-            if txt_avg_loss > average_loss_validation["encoder_only"]:
+                if VERBOSE:
+                    print("txt average is higher than lowest")
                 counter += 1
+            else:
+                average_loss_validation["txt"] = txt_avg_loss

-            if txt_avg_loss > average_loss_validation["decoder_only"]:
+            if enc_avg_loss > average_loss_validation["encoder_only"]:
+                if VERBOSE:
+                    print("masking average is higher than lowest")
                 counter += 1
+            else:
+                average_loss_validation["encoder_only"] = enc_avg_loss
+
+            if dec_avg_loss > average_loss_validation["decoder_only"]:
+                if VERBOSE:
+                    print("decoding only average is higher than lowest")
+                counter += 1
+            else:
+                average_loss_validation["decoder_only"] = dec_avg_loss

             if counter > 1:
                 patience += 1
+                if VERBOSE:
+                    print(f"losing a patience, current irritation: {patience}")


             if counter == 0:
                 patience = max(0, patience - 1)
+                if VERBOSE:
+                    print(f"all good, gaining a patience, current irritation: {patience}")



     txt_train_avg_loss = sum(text_batch_losses) / len(text_batch_losses)
@@ -28,6 +28,7 @@ class Batcher:
         tokenizer: BPE.TokeNanoCore,
         masker: SpannedMasker,
         seed: int = 0,
+        debug = False
     ) -> None:
         # ABSTRACT, TRIPLE
         # tasks:
@@ -44,6 +45,7 @@ class Batcher:
         self._seed = seed
         # self._token_completation = TokenCompletationTransformer(sotl,eos)
         self._completation_task_token_truncator = truncate_rdf_list
+        self.__debug = debug

     def batch(self, batch_size) -> Generator[
         tuple[
@@ -142,6 +144,7 @@ class Batcher:
         return out_X, out_Y, padding_X, padding_Y

     def __rdf2txt_transformation(self, batch: pd.DataFrame):
+        X: list[list[int]]
         task_token = self._tokenizer.encode(SpecialToken.RDF_TO_TEXT.value)
         out = batch.rename(columns={"RDFs": "X", "Abstract": "Y"})[["X", "Y"]]
         out["X"] = [task_token + x for x in out["X"]]
@@ -157,7 +160,7 @@ class Batcher:
         X = []
         Y = []
         for rdf in batch["RDFs"]:
-            x, y = self._masker.mask_sequence(rdf)
+            x, y = self._masker.mask_sequence(rdf[:self.__max_length])
             X.append(x)
             Y.append(y)
         return self.__normalization(X, Y)
@@ -83,7 +83,14 @@ class NanoSocratesCore(torch.nn.Module):
         x, padding = args

         encoder_tensor = self.__encoder_embedder(x)

-        BATCH, SEQ_LEN, _ = x.shape
+        BATCH: int
+
+        if len(x.shape) > 2:
+            BATCH, SEQ_LEN, _ = x.shape
+        else:
+            _, SEQ_LEN = x.shape
+            BATCH = 1

         encoder_output, _ = self.__encoder((encoder_tensor, padding))
@@ -95,25 +102,32 @@ class NanoSocratesCore(torch.nn.Module):

         while continue_generating:

-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_in_x = self.__decoder_embedder(decoder_in)

             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
+                (decoder_in_x, encoder_output, encoder_output, padding, decoder_in_pad_mask, False)
             )

             logits: torch.Tensor = self.__detokener(decoder_output)

             logits = torch.softmax(logits, 2)

-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
+
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)
+
+            if token_idx == self.__sentence_len - 1:
+                continue_generating = False
+                continue

             if tokens.shape[0] == 1 and tokens[0, token_idx] == self.__eos:
                 continue_generating = False
                 continue

-            if token_idx < self.__sentence_len - 1:
-                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
-            decoder_in_pad_mask = decoder_in.eq(self.__pad)
+            token_idx += 1

         return decoder_in
@@ -130,7 +144,7 @@ class NanoSocratesCore(torch.nn.Module):

         logits = torch.softmax(logits, 2)

-        tokens = torch.argmax(logits)
+        tokens = torch.argmax(logits, 2)

         return tokens

@@ -146,31 +160,56 @@ class NanoSocratesCore(torch.nn.Module):

         while continue_generating:

-            decoder_in = self.__decoder_embedder(decoder_in)
+            decoder_x = self.__decoder_embedder(decoder_in)

             decoder_output, _, _, _, _, _ = self.__decoder(
-                (decoder_in, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, False)
+                (decoder_x, decoder_in, decoder_in, decoder_in_prefix_mask, decoder_in_pad_mask, True)
             )

             logits: torch.Tensor = self.__detokener(decoder_output)

             logits = torch.softmax(logits, 2)

-            tokens = torch.argmax(logits)
+            tokens = torch.argmax(logits, 2)
+
+            if token_idx < self.__sentence_len - 1:
+                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
+                decoder_in_pad_mask = decoder_in.eq(self.__pad)
+
+            if token_idx == self.__sentence_len - 1:
+                continue_generating = False
+                continue

             if tokens.shape[0] == 1 and tokens[0, token_idx] == self.__eos:
                 continue_generating = False
                 continue

-            if token_idx < self.__sentence_len - 1:
-                decoder_in[:, token_idx + 1] = tokens[:, token_idx]
-                decoder_in_pad_mask = decoder_in.eq(self.__pad)
+            token_idx += 1

         return decoder_in

     def take_pieces(self):

         return (
-            (self.__encoder_embedder, self.__encoder),
+            (self.__encoder_embedder, self.__encoder, self.__encoder_detokener),
             (self.__decoder_embedder, self.__decoder, self.__detokener)
         )

+    def load_pieces(
+        self,
+        encoder_embedder: Embedder.NanoSocratesEmbedder,
+        decoder_embedder: Embedder.NanoSocratesEmbedder,
+        encoder: torch.nn.Sequential,
+        decoder: torch.nn.Sequential,
+        encoder_detokener: DeToken,
+        decoder_detokener: DeToken
+    ):
+        self.__encoder_embedder = encoder_embedder
+        self.__decoder_embedder = decoder_embedder
+        self.__encoder = encoder
+        self.__decoder = decoder
+        self.__encoder_detokener = encoder_detokener
+        self.__detokener = decoder_detokener
@@ -1,9 +1,11 @@
 from .TrainingModel import TrainingModel
 from .NanoSocratEncoder import NanoSocratEncoder
 from .NanoSocraDecoder import NanoSocraDecoder
+from .NanoSocrates import NanoSocratesCore

 __all__ = [
     "TrainingModel",
     "NanoSocratEncoder",
-    "NanoSocraDecoder"
+    "NanoSocraDecoder",
+    "NanoSocratesCore"
 ]
@@ -1,6 +1,7 @@
 from enum import Enum, auto


 class TaskType(Enum):
+    TEXT2RDF = auto()
     RDF2TEXT = auto()
     MASK = auto()
     COMPLETATION = auto()
@@ -1,8 +1,14 @@
-from .model_utils import decompose_nano_socrates, create_standalone_model
+from .model_utils import decompose_nano_socrates, create_standalone_model, train2inference
 from .ModelType import ModelType
+from .decode_batch import decode_batch
+from .metrics import precision, recall, accuracy, f1, meteor, bleu, rouge, average, rdf2txt, txt2rdf, rdf_completion_1, rdf_completion_2, remove_padding, balance_paddings

 __all__ = [
     "ModelType",
     "decompose_nano_socrates",
-    "create_standalone_model"
+    "create_standalone_model",
+    "decode_batch",
+    "train2inference",
+    "precision", "recall", "accuracy", "f1", "meteor", "bleu", "rouge", "average",
+    "rdf2txt", "txt2rdf", "rdf_completion_1", "rdf_completion_2", "remove_padding", "balance_paddings"
 ]
16 Project_Model/Libs/TransformerUtils/decode_batch.py Normal file
@@ -0,0 +1,16 @@
import torch
import Project_Model.Libs.BPE as BPE


def decode_batch(batch: torch.Tensor, tokenizer: BPE.TokeNanoCore, unknown_token: int) -> list[str]:

    strings = []

    BATCH, _ = batch.shape

    for i in range(0, BATCH):

        tokens: list[int] = batch.tolist()[i]
        tokens = list(map(lambda x: unknown_token if x > tokenizer.vocabulary_size else x, tokens))
        strings.append(tokenizer.decode(tokens))

    return strings
100 Project_Model/Libs/TransformerUtils/metrics.py Normal file
@@ -0,0 +1,100 @@
import evaluate as eval

BLEU = eval.load("bleu")
ROUGE = eval.load("rouge")
METEOR = eval.load("meteor")


def precision(ref: list[int], pred: list[int]):
    metric = eval.load("precision")
    return metric.compute(predictions=pred, references=ref, average="weighted", zero_division=0)


def recall(ref: list[int], pred: list[int]):
    metric = eval.load("recall")
    return metric.compute(predictions=pred, references=ref, average="weighted", zero_division=0)


def accuracy(ref: list[int], pred: list[int]):
    metric = eval.load("accuracy")
    return metric.compute(predictions=pred, references=ref)


def meteor(ref: list[str], pred: list[str]):
    metric = METEOR
    return metric.compute(predictions=pred, references=ref)


def bleu(ref: list[str], pred: list[str]):
    metric = BLEU
    return metric.compute(predictions=pred, references=ref)


def rouge(ref: list[str], pred: list[str]):
    metric = ROUGE
    return metric.compute(predictions=pred, references=ref)


def f1(precision: float, recall: float):
    divisor = max((precision + recall), 1E-5)
    return (2 * recall * precision) / divisor
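# the 1E-5 floor only matters when precision == recall == 0; it avoids a ZeroDivisionError there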

def average(array: list[float]):
    return sum(array) / len(array)


def rdf2txt(ref: list[str], pred: list[str]):

    b_m = bleu(ref, pred)
    r_m = rouge(ref, pred)
    m_m = meteor(ref, pred)

    return (b_m, r_m, m_m)


def txt2rdf(ref: list[int], pred: list[int]):

    p_m = precision(ref, pred)
    r_m = recall(ref, pred)

    return (p_m, r_m)


def rdf_completion_1(ref: list[int], pred: list[int]):

    a_m = accuracy(ref, pred)

    return a_m


def rdf_completion_2(ref: list[int], pred: list[int]):

    p_m = precision(ref, pred)
    r_m = recall(ref, pred)

    return (p_m, r_m)


def remove_padding(seq: list[int], pad_token: int, end_token: int):
    clean_seq = list(filter(lambda x: x != pad_token, seq))

    if clean_seq[-1] == end_token:
        return clean_seq

    clean_seq.append(
        end_token
    )

    return clean_seq


def balance_paddings(seq_1: list[int], seq_2: list[int], pad_token: int):
    SEQ_1_LEN = len(seq_1)
    SEQ_2_LEN = len(seq_2)

    if SEQ_1_LEN > SEQ_2_LEN:
        PAD = [pad_token] * (SEQ_1_LEN - SEQ_2_LEN)
        seq_2.extend(PAD)

    if SEQ_2_LEN > SEQ_1_LEN:
        seq_2 = seq_2[:SEQ_1_LEN]

    return (seq_1, seq_2)
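# balance_paddings pads seq_2 up (or truncates it down) to len(seq_1), so the
# token-level metrics above can compare the two sequences position by position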
@@ -1,13 +1,13 @@
 import torch
 from Project_Model.Libs.Embedder import NanoSocratesEmbedder
-from Project_Model.Libs.Transformer import TrainingModel, NanoSocraDecoder, NanoSocratEncoder, DeToken, Encoder, Decoder
+from Project_Model.Libs.Transformer import TrainingModel, NanoSocratesCore, NanoSocraDecoder, NanoSocratEncoder, DeToken, Encoder, Decoder
 from .ModelType import ModelType



 def decompose_nano_socrates(
-    model: TrainingModel, vocabulary_size: int, embedding_size: int
+    model: TrainingModel | NanoSocratesCore, vocabulary_size: int, embedding_size: int
-) -> tuple[TrainingModel, NanoSocratEncoder, NanoSocraDecoder]:
+) -> tuple[TrainingModel | NanoSocratesCore, NanoSocratEncoder, NanoSocraDecoder]:

     encoder_pieces, decoder_pieces = model.take_pieces()
     encoder_embedder, encoder, encoder_detokener = encoder_pieces
@@ -19,6 +19,26 @@ def decompose_nano_socrates(
         NanoSocraDecoder(decoder_embedder, decoder, decoder_detokener),
     )

+
+def train2inference(
+    train_model: TrainingModel,
+    inference_model: NanoSocratesCore
+) -> NanoSocratesCore:
+
+    encoder_pieces, decoder_pieces = train_model.take_pieces()
+    enc_emb, encoder, enc_det = encoder_pieces
+    dec_emb, decoder, dec_det = decoder_pieces
+    inference_model.load_pieces(
+        enc_emb,
+        dec_emb,
+        encoder,
+        decoder,
+        enc_det,
+        dec_det
+    )
+
+    return inference_model
+
+
 def create_standalone_model(
     model_type: ModelType,