-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3793cd5
commit 3d9e015
Showing
2 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
defmodule Scholar.Metrics.MCC do
  @moduledoc """
  Matthews Correlation Coefficient (MCC) provides a measure of the quality of binary classifications.
  It returns a value between -1 and 1 where 1 represents a perfect prediction, 0 represents no better
  than random prediction, and -1 indicates total disagreement between prediction and observation.
  """

  import Nx.Defn

  @doc """
  Computes the Matthews Correlation Coefficient (MCC) for binary classification.

  Assumes `y_true` and `y_pred` are binary tensors (0 or 1). Entries holding any
  other value are not counted in any confusion-matrix cell.

  The coefficient is computed as

      (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))

  When the denominator is zero (for example when every prediction, or every
  label, belongs to a single class) the result is defined to be `0.0`.

  Returns a tensor of shape `{1}` and type `:f32`.

  ## Examples

      iex> Scholar.Metrics.MCC.compute(Nx.tensor([1, 0, 1, 0, 1]), Nx.tensor([1, 0, 1, 0, 1]))
      #Nx.Tensor<
        f32[1]
        [1.0]
      >

      iex> Scholar.Metrics.MCC.compute(Nx.tensor([1, 0, 0]), Nx.tensor([0, 0, 1]))
      #Nx.Tensor<
        f32[1]
        [-0.5]
      >

      iex> Scholar.Metrics.MCC.compute(Nx.tensor([0, 0, 0]), Nx.tensor([1, 1, 1]))
      #Nx.Tensor<
        f32[1]
        [0.0]
      >
  """
  defn compute(y_true, y_pred) do
    true_positives = calculate_true_positives(y_true, y_pred)
    true_negatives = calculate_true_negatives(y_true, y_pred)
    false_positives = calculate_false_positives(y_true, y_pred)
    false_negatives = calculate_false_negatives(y_true, y_pred)

    # The counts are floats (see count_matches/2), so a negative numerator is
    # represented correctly instead of wrapping around in unsigned arithmetic.
    mcc_numerator = true_positives * true_negatives - false_positives * false_negatives

    mcc_denominator =
      Nx.sqrt(
        (true_positives + false_positives) *
          (true_positives + false_negatives) *
          (true_negatives + false_positives) *
          (true_negatives + false_negatives)
      )

    # A {1}-shaped zero keeps the historical output shape of this function:
    # the comparison below broadcasts to {1}, and so does the selected result.
    zero_tensor = Nx.tensor([0.0], type: :f32)

    # A zero denominator means at least one marginal total is zero; the
    # coefficient is undefined there and is reported as 0 by convention.
    Nx.select(
      Nx.equal(mcc_denominator, zero_tensor),
      zero_tensor,
      Nx.divide(mcc_numerator, mcc_denominator)
    )
  end

  # Number of samples that are positive in both `y_true` and `y_pred`.
  defnp calculate_true_positives(y_true, y_pred) do
    count_matches(Nx.equal(y_true, 1), Nx.equal(y_pred, 1))
  end

  # Number of samples that are negative in both `y_true` and `y_pred`.
  defnp calculate_true_negatives(y_true, y_pred) do
    count_matches(Nx.equal(y_true, 0), Nx.equal(y_pred, 0))
  end

  # Number of samples that are negative in `y_true` but predicted positive.
  defnp calculate_false_positives(y_true, y_pred) do
    count_matches(Nx.equal(y_true, 0), Nx.equal(y_pred, 1))
  end

  # Number of samples that are positive in `y_true` but predicted negative.
  defnp calculate_false_negatives(y_true, y_pred) do
    count_matches(Nx.equal(y_true, 1), Nx.equal(y_pred, 0))
  end

  # Counts the positions where both masks are set and returns the count as f32.
  # The cast matters: summing the unsigned masks yields an unsigned integer,
  # and subtracting such counts would wrap around instead of going negative.
  defnp count_matches(true_mask, pred_mask) do
    true_mask
    |> Nx.logical_and(pred_mask)
    |> Nx.sum()
    |> Nx.as_type(:f32)
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
defmodule Scholar.Metrics.MCCTest do
  use ExUnit.Case, async: true
  alias Scholar.Metrics.MCC

  # Tolerance for comparing coefficients that are not exactly representable.
  @delta 1.0e-6

  # Runs MCC.compute/2 and unwraps the single-element result into a float.
  defp mcc_value(y_true, y_pred) do
    result = MCC.compute(y_true, y_pred)
    result |> Nx.reshape({1}) |> Nx.to_list() |> hd()
  end

  describe "MCC.compute/2" do
    test "returns a single-element f32 tensor" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 1, 1])

      result = MCC.compute(y_true, y_pred)

      assert Nx.shape(result) == {1}
      assert Nx.type(result) == {:f, 32}
    end

    test "returns 1 for perfect predictions" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 0, 1])

      assert mcc_value(y_true, y_pred) == 1.0
    end

    test "returns -1 for completely wrong predictions" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([0, 1, 0, 1, 0])

      assert mcc_value(y_true, y_pred) == -1.0
    end

    test "returns 0 when all predictions are positive" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 1, 1, 1, 1])

      assert mcc_value(y_true, y_pred) == 0.0
    end

    test "returns 0 when all predictions are negative" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([0, 0, 0, 0, 0])

      assert mcc_value(y_true, y_pred) == 0.0
    end

    test "computes MCC for generic case" do
      y_true = Nx.tensor([1, 0, 1, 0, 1])
      y_pred = Nx.tensor([1, 0, 1, 1, 1])

      # TP = 3, TN = 1, FP = 1, FN = 0  =>  3 / sqrt(4 * 3 * 2 * 1)
      assert_in_delta mcc_value(y_true, y_pred), 0.6123724, @delta
    end

    test "returns 0 when labels and predictions are all negative (only TN is non-zero)" do
      y_true = Nx.tensor([0, 0, 0, 0, 0])
      y_pred = Nx.tensor([0, 0, 0, 0, 0])

      assert mcc_value(y_true, y_pred) == 0.0
    end
  end
end