Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standard Scaler fit-transform interface #179

Merged
merged 9 commits into from
Dec 14, 2023
10 changes: 2 additions & 8 deletions lib/scholar/preprocessing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,8 @@ defmodule Scholar.Preprocessing do
>
"""
deftransform standard_scale(tensor, opts \\ []) do
standard_scale_n(tensor, NimbleOptions.validate!(opts, @general_schema))
end

defnp standard_scale_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)
(tensor - mean_reduced) / Nx.select(std == 0, 1.0, std)
opts = NimbleOptions.validate!(opts, @general_schema)
Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts)
end

@doc """
Expand Down
138 changes: 138 additions & 0 deletions lib/scholar/preprocessing/standard_scaler.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
defmodule Scholar.Preprocessing.StandardScaler do
@moduledoc ~S"""
Standardizes the tensor by removing the mean and scaling to unit variance.

Formula for input tensor $x$:

$$
z = \frac{x - \mu}{\sigma}
$$

Where $\mu$ is the mean of the samples, and $\sigma$ is the standard deviation.
Standardization can be helpful in cases where the data follows
a Gaussian distribution (or Normal distribution) without outliers.

Centering and scaling happen independently on each feature by computing the relevant
statistics on the samples in the training set. Mean and standard deviation are then
stored to be used on new samples.
"""

import Nx.Defn

@derive {Nx.Container, containers: [:standard_deviation, :mean]}
defstruct [:standard_deviation, :mean]

opts_schema = [
axes: [
type: {:custom, Scholar.Options, :axes, []},
doc: """
Axes to calculate the distance over. By default the distance
is calculated between the whole tensors.
"""
]
]

@opts_schema NimbleOptions.new!(opts_schema)

@doc """
Compute the standard deviation and mean of samples to be used for later scaling.

## Options

#{NimbleOptions.docs(@opts_schema)}

## Return values

Returns a struct with the following parameters:

* `standard_deviation`: the calculated standard deviation of samples.

* `mean`: the calculated mean of samples.

## Examples

iex> Scholar.Preprocessing.StandardScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]))
%Scholar.Preprocessing.StandardScaler{
standard_deviation: Nx.tensor(
[
[1.0657403469085693]
]
),
mean: Nx.tensor(
[
[0.4444444477558136]
]
)
}
"""
deftransform fit(tensor, opts \\ []) do
NimbleOptions.validate!(opts, @opts_schema)
fit_n(tensor, opts)
end

defnp fit_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)
%__MODULE__{standard_deviation: std, mean: mean_reduced}
end

@doc """
Performs the standardization of the tensor using a fitted scaler.

## Examples

iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> scaler = Scholar.Preprocessing.StandardScaler.fit(t)
%Scholar.Preprocessing.StandardScaler{
standard_deviation: Nx.tensor(
[
[1.0657403469085693]
]
),
mean: Nx.tensor(
[
[0.4444444477558136]
]
)
}
iex> Scholar.Preprocessing.StandardScaler.transform(scaler, t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn transform(%__MODULE__{standard_deviation: std, mean: mean}, tensor) do
scale(tensor, std, mean)
end

@doc """
Standardizes the tensor by removing the mean and scaling to unit variance.

## Examples

iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> Scholar.Preprocessing.StandardScaler.fit_transform(t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn fit_transform(tensor, opts \\ []) do
tensor
|> fit(opts)
|> transform(tensor)
end

defnp scale(tensor, std, mean) do
(tensor - mean) / Nx.select(std == 0, 1.0, std)
end
end
27 changes: 27 additions & 0 deletions test/scholar/preprocessing/standard_scaler_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defmodule Scholar.Preprocessing.StandardScalerTest do
use Scholar.Case, async: true
alias Scholar.Preprocessing.StandardScaler

doctest StandardScaler

describe "fit_transform/2" do
test "applies standard scaling to data" do
data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])

expected =
Nx.tensor([
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
])

assert_all_close(StandardScaler.fit_transform(data), expected)
end

test "leaves data as it is when variance is zero" do
data = 42.0
expected = Nx.tensor(data)
assert StandardScaler.fit_transform(data) == expected
end
end
end
Loading