Fixes for evaluation
This commit is contained in:
99
Project_Model/Libs/TransformerUtils/metrics.py
Normal file
99
Project_Model/Libs/TransformerUtils/metrics.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import evaluate as eval
|
||||
|
||||
BLEU = eval.load("bleu")
|
||||
ROUGE = eval.load("rouge")
|
||||
METEOR = eval.load("meteor")
|
||||
|
||||
def precision(ref: list[int], pred: list[int]):
|
||||
metric = eval.load("precision")
|
||||
return metric.compute(predictions=pred, references=ref, average="weighted", zero_division=0)
|
||||
|
||||
|
||||
def recall(ref: list[int], pred: list[int]):
|
||||
metric = eval.load("recall")
|
||||
return metric.compute(predictions=pred, references=ref, average="weighted", zero_division=0)
|
||||
|
||||
|
||||
def accuracy(ref: list[int], pred: list[int]):
|
||||
metric = eval.load("accuracy")
|
||||
return metric.compute(predictions=pred, references=ref)
|
||||
|
||||
|
||||
def meteor(ref: list[str], pred: list[str]):
|
||||
metric = METEOR
|
||||
return metric.compute(predictions=pred, references=ref)
|
||||
|
||||
|
||||
def bleu(ref: list[str], pred: list[str]):
|
||||
metric = BLEU
|
||||
return metric.compute(predictions=pred, references=ref)
|
||||
|
||||
|
||||
def rouge(ref: list[str], pred: list[str]):
|
||||
metric = ROUGE
|
||||
return metric.compute(predictions=pred, references=ref)
|
||||
|
||||
|
||||
def f1(precision: float, recall: float):
|
||||
return (2 * recall * precision) / (precision + recall)
|
||||
|
||||
|
||||
def average(array: list[float]):
|
||||
return sum(array) / len(array)
|
||||
|
||||
|
||||
def rdf2txt(ref: list[str], pred: list[str]):
|
||||
|
||||
b_m = bleu(ref, pred)
|
||||
r_m = rouge(ref, pred)
|
||||
m_m = meteor(ref, pred)
|
||||
|
||||
return (b_m, r_m, m_m)
|
||||
|
||||
def txt2rdf(ref: list[int], pred: list[int]):
|
||||
|
||||
p_m = precision(ref, pred)
|
||||
r_m = recall(ref, pred)
|
||||
|
||||
return (p_m, r_m)
|
||||
|
||||
def rdf_completion_1(ref: list[int], pred: list[int]):
|
||||
|
||||
a_m = accuracy(ref, pred)
|
||||
|
||||
return a_m
|
||||
|
||||
|
||||
def rdf_completion_2(ref: list[int], pred: list[int]):
|
||||
|
||||
p_m = precision(ref, pred)
|
||||
r_m = recall(ref, pred)
|
||||
|
||||
return (p_m, r_m)
|
||||
|
||||
|
||||
def remove_padding(seq: list[int], pad_token: int, end_token: int):
|
||||
clean_seq = list(filter(lambda x: x != pad_token, seq))
|
||||
|
||||
if clean_seq[-1] == end_token:
|
||||
return clean_seq
|
||||
|
||||
clean_seq.append(
|
||||
end_token
|
||||
)
|
||||
|
||||
return clean_seq
|
||||
|
||||
|
||||
def balance_paddings(seq_1: list[int], seq_2: list[int], pad_token: int):
|
||||
SEQ_1_LEN = len(seq_1)
|
||||
SEQ_2_LEN = len(seq_2)
|
||||
|
||||
if SEQ_1_LEN > SEQ_2_LEN:
|
||||
PAD = [pad_token] * (SEQ_1_LEN - SEQ_2_LEN)
|
||||
seq_2.extend(PAD)
|
||||
|
||||
if SEQ_2_LEN > SEQ_1_LEN:
|
||||
seq_2 = seq_2[:SEQ_1_LEN]
|
||||
|
||||
return (seq_1, seq_2)
|
||||
Reference in New Issue
Block a user