From cba0089129cd1c2d628036927fbd9395b03df79a Mon Sep 17 00:00:00 2001 From: norm4nn Date: Thu, 14 Nov 2024 22:06:53 +0100 Subject: [PATCH] fixed assumed_centered? true tests --- lib/scholar/covariance/ledoit_wolf.ex | 9 +++++---- lib/scholar/covariance/shrunk_covariance.ex | 2 +- lib/scholar/covariance/utils.ex | 20 ++----------------- test/scholar/covariance/ledoit_wolf_test.exs | 8 +++++--- .../covariance/shrunk_covariance_test.exs | 8 +++++--- 5 files changed, 18 insertions(+), 29 deletions(-) diff --git a/lib/scholar/covariance/ledoit_wolf.ex b/lib/scholar/covariance/ledoit_wolf.ex index a9debc24..a5a6cb1f 100644 --- a/lib/scholar/covariance/ledoit_wolf.ex +++ b/lib/scholar/covariance/ledoit_wolf.ex @@ -93,14 +93,15 @@ defmodule Scholar.Covariance.LedoitWolf do iex> key = Nx.Random.key(0) iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32) + iex> {x, _} = Scholar.Covariance.Utils.center(x) iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true) iex> cov.covariance #Nx.Tensor< f32[3][3] [ - [3.8574986457824707, 2.2048025131225586, 2.1504499912261963], - [2.2048025131225586, 2.4572863578796387, 1.7215262651443481], - [2.1504499912261963, 1.7215262651443481, 2.154898166656494] + [2.5945029258728027, 1.5078359842300415, 1.1623677015304565], + [1.5078359842300415, 2.106797456741333, 1.1812156438827515], + [1.1623677015304565, 1.1812156438827515, 1.4606266021728516] ] > """ @@ -148,7 +149,7 @@ defmodule Scholar.Covariance.LedoitWolf do defnp ledoit_wolf_shrinkage_complex(x) do {num_samples, num_features} = Nx.shape(x) - emp_cov = Scholar.Covariance.Utils.empirical_covariance(x) + emp_cov = Nx.covariance(x) emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov) mu = Nx.sum(emp_cov_trace) / num_features diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex index 770c1ca7..49687560 100644 --- a/lib/scholar/covariance/shrunk_covariance.ex +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -97,7 +97,7 @@ defmodule Scholar.Covariance.ShrunkCovariance do {x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?]) covariance = - Scholar.Covariance.Utils.empirical_covariance(x) + Nx.covariance(x) |> shrunk_covariance(shrinkage) %__MODULE__{ diff --git a/lib/scholar/covariance/utils.ex b/lib/scholar/covariance/utils.ex index e5b62be6..3c32f1b3 100644 --- a/lib/scholar/covariance/utils.ex +++ b/lib/scholar/covariance/utils.ex @@ -3,34 +3,18 @@ defmodule Scholar.Covariance.Utils do import Nx.Defn require Nx - defn center(x, assume_centered \\ false) do + defn center(x, assume_centered \\ Nx.tensor(0)) do x = case Nx.shape(x) do {_} -> Nx.new_axis(x, 1) _ -> x end - location = - if assume_centered do - 0 - else - Nx.mean(x, axes: [0]) - end + location = Nx.select(assume_centered == Nx.tensor(1), 0, Nx.mean(x, axes: [0])) {x - location, location} end - defn empirical_covariance(x) do - n = Nx.axis_size(x, 0) - - covariance = Nx.dot(x, [0], x, [0]) / n - - case Nx.shape(covariance) do - {} -> Nx.reshape(covariance, {1, 1}) - _ -> covariance - end - end - defn trace(x) do x |> Nx.take_diagonal() diff --git a/test/scholar/covariance/ledoit_wolf_test.exs b/test/scholar/covariance/ledoit_wolf_test.exs index d7906b07..560994b7 100644 --- a/test/scholar/covariance/ledoit_wolf_test.exs +++ b/test/scholar/covariance/ledoit_wolf_test.exs @@ -52,14 +52,16 @@ defmodule Scholar.Covariance.LedoitWolfTest do type: :f32 ) + {x, _} = Scholar.Covariance.Utils.center(x) + model = LedoitWolf.fit(x, assume_centered?: true) assert_all_close( model.covariance, Nx.tensor([ - [1.852303147315979, 0.0, 0.0], - [0.0, 1.852303147315979, 0.0], - [0.0, 0.0, 1.852303147315979] + [1.439786434173584, -0.0, 0.0], + [-0.0, 1.439786434173584, 0.0], + [0.0, 0.0, 1.439786434173584] ]), atol: 1.0e-3 ) diff --git a/test/scholar/covariance/shrunk_covariance_test.exs b/test/scholar/covariance/shrunk_covariance_test.exs index e11bc676..cc18f5ca 100644 --- a/test/scholar/covariance/shrunk_covariance_test.exs +++ b/test/scholar/covariance/shrunk_covariance_test.exs @@ -50,14 +50,16 @@ defmodule Scholar.Covariance.ShrunkCovarianceTest do type: :f32 ) + {x, _} = Scholar.Covariance.Utils.center(x) + model = ShrunkCovariance.fit(x, assume_centered?: true) assert_all_close( model.covariance, Nx.tensor([ - [3.0643274784088135, 0.27685147523880005, 0.4822050631046295], - [0.27685147523880005, 1.5171942710876465, 0.03596973791718483], - [0.4822050631046295, 0.03596973791718483, 0.975387692451477] + [2.0949244499206543, -0.13400490581989288, 0.5413897037506104], + [-0.13400490581989288, 1.2940725088119507, 0.0621684193611145], + [0.5413897037506104, 0.0621684193611145, 0.9303621053695679] ]), atol: 1.0e-3 )