Fixed evaluation
This commit is contained in:
parent
892f91aad7
commit
fe62b1edd5
@ -11,13 +11,13 @@ DEVICE = torch_shims.get_default_device()
|
|||||||
torch.set_default_device(DEVICE)
|
torch.set_default_device(DEVICE)
|
||||||
|
|
||||||
# Get paths
|
# Get paths
|
||||||
MODEL_DIR = "Assets/Model/curated"
|
# MODEL_DIR = "Assets/Model/curated"
|
||||||
# MODEL_DIR= "Assets/Dataset/Tmp"
|
MODEL_DIR= "Assets/Dataset/Tmp"
|
||||||
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
|
VOCABULARY_PATH = Path("Assets/Model/small/bpe-small-16.json")
|
||||||
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
|
TRAIN_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/train.csv")
|
||||||
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
|
VALIDATION_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/evaluation.csv")
|
||||||
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
|
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/small/holdout/test.csv")
|
||||||
# TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
|
TEST_DATASET_PATH = Path("Assets/Dataset/1-hop/toy/rdf_text.csv")
|
||||||
MODEL_PATH = Path(f"{MODEL_DIR}/NanoSocrates.zip")
|
MODEL_PATH = Path(f"{MODEL_DIR}/NanoSocrates.zip")
|
||||||
|
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ with torch.no_grad():
|
|||||||
|
|
||||||
if tasktype == Batch.TaskType.TEXT2RDF:
|
if tasktype == Batch.TaskType.TEXT2RDF:
|
||||||
ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
|
ref = TUtils.remove_padding(exp_tokens, PAD_TOKEN, END_TOKEN)
|
||||||
pred = TUtils.remove_padding(tokens, PAD_TOKEN, END_TOKEN)
|
pred = TUtils.remove_padding(tokens[1:], PAD_TOKEN, END_TOKEN)
|
||||||
ref, pred = TUtils.balance_paddings(ref, pred, PAD_TOKEN)
|
ref, pred = TUtils.balance_paddings(ref, pred, PAD_TOKEN)
|
||||||
|
|
||||||
precision, recall = TUtils.txt2rdf(ref, pred)
|
precision, recall = TUtils.txt2rdf(ref, pred)
|
||||||
|
|||||||
@ -35,7 +35,8 @@ def rouge(ref: list[str], pred: list[str]):
|
|||||||
|
|
||||||
|
|
||||||
def f1(precision: float, recall: float):
|
def f1(precision: float, recall: float):
|
||||||
return (2 * recall * precision) / (precision + recall)
|
divisor = max((precision + recall), 1E-5)
|
||||||
|
return (2 * recall * precision) / divisor
|
||||||
|
|
||||||
|
|
||||||
def average(array: list[float]):
|
def average(array: list[float]):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user