Skip to content

Commit

Permalink
fixed assumed_centered? true tests
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Nov 14, 2024
1 parent df80bc3 commit cba0089
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 29 deletions.
9 changes: 5 additions & 4 deletions lib/scholar/covariance/ledoit_wolf.ex
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ defmodule Scholar.Covariance.LedoitWolf do
iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32)
iex> {x, _} = Scholar.Covariance.Utils.center(x)
iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true)
iex> cov.covariance
#Nx.Tensor<
f32[3][3]
[
[3.8574986457824707, 2.2048025131225586, 2.1504499912261963],
[2.2048025131225586, 2.4572863578796387, 1.7215262651443481],
[2.1504499912261963, 1.7215262651443481, 2.154898166656494]
[2.5945029258728027, 1.5078359842300415, 1.1623677015304565],
[1.5078359842300415, 2.106797456741333, 1.1812156438827515],
[1.1623677015304565, 1.1812156438827515, 1.4606266021728516]
]
>
"""
Expand Down Expand Up @@ -148,7 +149,7 @@ defmodule Scholar.Covariance.LedoitWolf do

defnp ledoit_wolf_shrinkage_complex(x) do
{num_samples, num_features} = Nx.shape(x)
emp_cov = Scholar.Covariance.Utils.empirical_covariance(x)
emp_cov = Nx.covariance(x)

emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
mu = Nx.sum(emp_cov_trace) / num_features
Expand Down
2 changes: 1 addition & 1 deletion lib/scholar/covariance/shrunk_covariance.ex
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ defmodule Scholar.Covariance.ShrunkCovariance do
{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])

covariance =
Scholar.Covariance.Utils.empirical_covariance(x)
Nx.covariance(x)
|> shrunk_covariance(shrinkage)

%__MODULE__{
Expand Down
20 changes: 2 additions & 18 deletions lib/scholar/covariance/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,18 @@ defmodule Scholar.Covariance.Utils do
import Nx.Defn
require Nx

defn center(x, assume_centered \\ false) do
defn center(x, assume_centered \\ Nx.tensor(0)) do
x =
case Nx.shape(x) do
{_} -> Nx.new_axis(x, 1)
_ -> x
end

location =
if assume_centered do
0
else
Nx.mean(x, axes: [0])
end
location = Nx.select(assume_centered == Nx.tensor(1), 0, Nx.mean(x, axes: [0]))

{x - location, location}
end

defn empirical_covariance(x) do
n = Nx.axis_size(x, 0)

covariance = Nx.dot(x, [0], x, [0]) / n

case Nx.shape(covariance) do
{} -> Nx.reshape(covariance, {1, 1})
_ -> covariance
end
end

defn trace(x) do
x
|> Nx.take_diagonal()
Expand Down
8 changes: 5 additions & 3 deletions test/scholar/covariance/ledoit_wolf_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@ defmodule Scholar.Covariance.LedoitWolfTest do
type: :f32
)

{x, _} = Scholar.Covariance.Utils.center(x)

model = LedoitWolf.fit(x, assume_centered?: true)

assert_all_close(
model.covariance,
Nx.tensor([
[1.852303147315979, 0.0, 0.0],
[0.0, 1.852303147315979, 0.0],
[0.0, 0.0, 1.852303147315979]
[1.439786434173584, -0.0, 0.0],
[-0.0, 1.439786434173584, 0.0],
[0.0, 0.0, 1.439786434173584]
]),
atol: 1.0e-3
)
Expand Down
8 changes: 5 additions & 3 deletions test/scholar/covariance/shrunk_covariance_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,16 @@ defmodule Scholar.Covariance.ShrunkCovarianceTest do
type: :f32
)

{x, _} = Scholar.Covariance.Utils.center(x)

model = ShrunkCovariance.fit(x, assume_centered?: true)

assert_all_close(
model.covariance,
Nx.tensor([
[3.0643274784088135, 0.27685147523880005, 0.4822050631046295],
[0.27685147523880005, 1.5171942710876465, 0.03596973791718483],
[0.4822050631046295, 0.03596973791718483, 0.975387692451477]
[2.0949244499206543, -0.13400490581989288, 0.5413897037506104],
[-0.13400490581989288, 1.2940725088119507, 0.0621684193611145],
[0.5413897037506104, 0.0621684193611145, 0.9303621053695679]
]),
atol: 1.0e-3
)
Expand Down

0 comments on commit cba0089

Please sign in to comment.