Matthews Correlation Coefficient
paulsullivanjr committed Nov 2, 2023
1 parent 3793cd5 commit 3d9e015
Showing 2 changed files with 110 additions and 0 deletions.
63 changes: 63 additions & 0 deletions lib/scholar/metrics/mcc.ex
@@ -0,0 +1,63 @@
defmodule Scholar.Metrics.MCC do
  @moduledoc """
  The Matthews Correlation Coefficient (MCC) provides a measure of the quality of binary
  classifications. It returns a value between -1 and 1, where 1 represents a perfect
  prediction, 0 a prediction no better than random, and -1 total disagreement between
  prediction and observation.
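
  In terms of the confusion-matrix counts used below (TP, TN, FP, FN), the coefficient is

      MCC = (TP * TN - FP * FN) /
            sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))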
"""

import Nx.Defn

@doc """
Computes the Matthews Correlation Coefficient (MCC) for binary classification.
Assumes `y_true` and `y_pred` are binary tensors (0 or 1).
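
  ## Example

  An illustrative call; the inputs and expected value are taken from this commit's
  tests, and the result is an `f32` tensor of shape `{1}`:

      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 1, 1])
      Scholar.Metrics.MCC.compute(y_true, y_pred)
      # => approximately 0.6124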
"""
defn compute(y_true, y_pred) do
true_positives = calculate_true_positives(y_true, y_pred)
true_negatives = calculate_true_negatives(y_true, y_pred)
false_positives = calculate_false_positives(y_true, y_pred)
false_negatives = calculate_false_negatives(y_true, y_pred)

mcc_numerator = true_positives * true_negatives - false_positives * false_negatives

mcc_denominator =
Nx.sqrt(
(true_positives + false_positives) *
(true_positives + false_negatives) *
(true_negatives + false_positives) *
(true_negatives + false_negatives)
)

zero_tensor = Nx.tensor([0.0], type: :f32)

if Nx.all(
Nx.logical_and(
Nx.equal(true_positives, zero_tensor),
Nx.equal(true_negatives, zero_tensor)
)
) do
Nx.tensor([-1.0], type: :f32)
else
Nx.select(
Nx.equal(mcc_denominator, zero_tensor),
zero_tensor,
Nx.divide(mcc_numerator, mcc_denominator)
)
end
end

  # Each helper multiplies element-wise 0/1 comparison masks and sums the result,
  # which counts the positions falling into that confusion-matrix cell.
  defnp calculate_true_positives(y_true, y_pred) do
    Nx.sum(Nx.equal(y_true, 1) * Nx.equal(y_pred, 1))
  end

  defnp calculate_true_negatives(y_true, y_pred) do
    Nx.sum(Nx.equal(y_true, 0) * Nx.equal(y_pred, 0))
  end

  defnp calculate_false_positives(y_true, y_pred) do
    Nx.sum(Nx.equal(y_true, 0) * Nx.equal(y_pred, 1))
  end

  defnp calculate_false_negatives(y_true, y_pred) do
    Nx.sum(Nx.equal(y_true, 1) * Nx.equal(y_pred, 0))
  end
end
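
A note on the edge cases above: when there are no true positives and no true negatives, compute/2 returns -1.0; otherwise, when the denominator is zero, it returns 0.0 rather than NaN. A minimal sketch of that second case, with inputs taken from the tests below:

y_true = Nx.tensor([1, 0, 1, 0, 1])
y_pred = Nx.tensor([1, 1, 1, 1, 1])
Scholar.Metrics.MCC.compute(y_true, y_pred)
# => an f32 tensor of shape {1} equal to 0.0, because the zero denominator falls back to 0.0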
47 changes: 47 additions & 0 deletions test/scholar/metrics/mcc_test.exs
@@ -0,0 +1,47 @@
defmodule Scholar.Metrics.MCCTest do
  use ExUnit.Case, async: true
  alias Scholar.Metrics.MCC

  describe "MCC.compute/2" do
    test "returns 1 for perfect predictions" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 0, 1])

      value = MCC.compute(y_true, y_pred) |> Nx.reshape({1}) |> Nx.to_list() |> hd()
      assert value == 1.0
    end

    test "returns -1 for completely wrong predictions" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([0, 1, 0, 1, 0])
      value = MCC.compute(y_true, y_pred) |> Nx.reshape({1}) |> Nx.to_list() |> hd()
      assert value == -1.0
    end

    test "returns 0 when all predictions are positive" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 1, 1, 1, 1])
      assert MCC.compute(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
    end

    test "returns 0 when all predictions are negative" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([0, 0, 0, 0, 0])
      assert MCC.compute(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
    end

    test "computes MCC for generic case" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 1, 1])
      # TP = 3, TN = 1, FP = 1, FN = 0, so MCC = 3 / sqrt(4 * 3 * 2 * 1) ≈ 0.6124.
      mcc_tensor = MCC.compute(y_true, y_pred)
      mcc_value = Nx.reshape(mcc_tensor, {1}) |> Nx.to_list() |> hd()
      assert mcc_value == 0.6123723983764648
    end

    test "returns 0 when y_true and y_pred contain only negative labels" do
      y_true = Nx.tensor([0, 0, 0, 0, 0])
      y_pred = Nx.tensor([0, 0, 0, 0, 0])
      assert MCC.compute(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
    end
  end
end
