Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standard Scaler fit-transform interface #179

Merged
merged 9 commits into from
Dec 14, 2023
10 changes: 2 additions & 8 deletions lib/scholar/preprocessing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,8 @@ defmodule Scholar.Preprocessing do
>
"""
deftransform standard_scale(tensor, opts \\ []) do
standard_scale_n(tensor, NimbleOptions.validate!(opts, @general_schema))
end

defnp standard_scale_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)
(tensor - mean_reduced) / Nx.select(std == 0, 1.0, std)
opts = NimbleOptions.validate!(opts, @general_schema)
Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts)
end

@doc """
Expand Down
145 changes: 145 additions & 0 deletions lib/scholar/preprocessing/standard_scaler.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
defmodule Scholar.Preprocessing.StandardScaler do
@moduledoc """
Standardizes the tensor by removing the mean and scaling to unit variance.

#{~S'''
Formula for input tensor $x$:
$$
z = \frac{x - \mu}{\sigma}
$$
Where $\mu$ is the mean of the samples, and $\sigma$ is the standard deviation.
Standardization can be helpful in cases where the data follows
a Gaussian distribution (or Normal distribution) without outliers.
'''}
josevalim marked this conversation as resolved.
Show resolved Hide resolved

Centering and scaling happen independently on each feature by computing the relevant
statistics on the samples in the training set. Mean and standard deviation are then
stored to be used on new samples.
"""

import Nx.Defn

@derive {Nx.Container, containers: [:standard_deviation, :mean]}
defstruct [:standard_deviation, :mean]

opts_schema = [
axes: [
type: {:custom, Scholar.Options, :axes, []},
doc: """
Axes to calculate the distance over. By default the distance
is calculated between the whole tensors.
"""
]
]

@opts_schema NimbleOptions.new!(opts_schema)

@doc """
Compute the standard deviation and mean of samples to be used for later scaling.

## Options

#{NimbleOptions.docs(@opts_schema)}

## Return values

Returns a struct with the following parameters:
josevalim marked this conversation as resolved.
Show resolved Hide resolved

* `standard_deviation`: the calculated standard deviation of samples.

* `mean`: the calculated mean of samples.

## Examples

iex> Scholar.Preprocessing.StandardScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]))
%Scholar.Preprocessing.StandardScaler{
standard_deviation: #Nx.Tensor<
f32[1][1]
[
[1.0657403469085693]
]
>,
mean: #Nx.Tensor<
f32[1][1]
[
[0.4444444477558136]
]
>
}
msluszniak marked this conversation as resolved.
Show resolved Hide resolved
"""
deftransform fit(tensor, opts \\ []) do
NimbleOptions.validate!(opts, @opts_schema)
{std, mean} = fit_n(tensor, opts)

%__MODULE__{standard_deviation: std, mean: mean}
end

defnp fit_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)

{std, mean_reduced}
end

@doc """
Performs the standardization of the tensor using a fitted scaler.

## Examples

iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> scaler = Scholar.Preprocessing.StandardScaler.fit(t)
%Scholar.Preprocessing.StandardScaler{
standard_deviation: #Nx.Tensor<
f32[1][1]
[
[1.0657403469085693]
]
>,
mean: #Nx.Tensor<
f32[1][1]
[
[0.4444444477558136]
]
>
}
msluszniak marked this conversation as resolved.
Show resolved Hide resolved
iex> Scholar.Preprocessing.StandardScaler.transform(scaler, t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn transform(%__MODULE__{standard_deviation: std, mean: mean}, tensor) do
scale(tensor, std, mean)
end

@doc """
Standardizes the tensor by removing the mean and scaling to unit variance.

## Examples

iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> Scholar.Preprocessing.StandardScaler.fit_transform(t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn fit_transform(tensor, opts \\ []) do
tensor
|> fit(opts)
|> transform(tensor)
end

defnp scale(tensor, std, mean) do
(tensor - mean) / Nx.select(std == 0, 1.0, std)
end
end
27 changes: 27 additions & 0 deletions test/scholar/preprocessing/standard_scaler_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defmodule Scholar.Preprocessing.StandardScalerTest do
use Scholar.Case, async: true
alias Scholar.Preprocessing.StandardScaler

doctest StandardScaler

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.15.6, 26.1, true)

doctest Scholar.Preprocessing.StandardScaler.transform/2 (2) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.15.6, 26.1, true)

doctest Scholar.Preprocessing.StandardScaler.fit/2 (1) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.14.5, 25.3)

doctest Scholar.Preprocessing.StandardScaler.fit/2 (1) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.14.5, 25.3)

doctest Scholar.Preprocessing.StandardScaler.transform/2 (2) (Scholar.Preprocessing.StandardScalerTest)

describe "fit_transform/2" do
test "applies standard scaling to data" do
data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])

expected =
Nx.tensor([
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
])

assert_all_close(StandardScaler.fit_transform(data), expected)
end

test "leaves data as it is when variance is zero" do
data = 42.0
expected = Nx.tensor(data)
assert StandardScaler.fit_transform(data) == expected
end
end
end
Loading