From 6787b2f0a6446db0364746dc71b994f20213d6b8 Mon Sep 17 00:00:00 2001 From: severinsimmler Date: Tue, 15 Mar 2022 16:41:20 +0100 Subject: [PATCH] fix: evaluation --- chaine/optimization/metrics.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/chaine/optimization/metrics.py b/chaine/optimization/metrics.py index 060c532..4d6a9fa 100644 --- a/chaine/optimization/metrics.py +++ b/chaine/optimization/metrics.py @@ -108,17 +108,18 @@ def evaluate_predictions(true: list[list[str]], pred: list[list[str]]) -> dict[s true_labels = [l.removeprefix("B-").removeprefix("I-") for l in true_labels] predicted_labels = [l.removeprefix("B-").removeprefix("I-") for l in predicted_labels] + if len(true_labels) != len(predicted_labels): + raise ValueError(f"Different lengths: '{true_labels}' vs. '{predicted_labels}'") + for true_label, predicted_label in zip(true_labels, predicted_labels): - if true_label == "O" and true_label == predicted_label: - counts["tn"] += 1 - elif true_label == "O" and true_label != predicted_label: - counts["fp"] += 1 - elif true_label != "O" and true_label == predicted_label: + if true_label != "O" and predicted_label == true_label: counts["tp"] += 1 - elif true_label != "O" and predicted_label == "O": - counts["fn"] += 1 - elif true_label != "O" and true_label != predicted_label: + if predicted_label != "O" and predicted_label != true_label: counts["fp"] += 1 + if true_label == "O" and predicted_label == "O": + counts["tn"] += 1 + if true_label != "O" and predicted_label == "O": + counts["fn"] += 1 # calculate precision, recall and f1 score return {