Skip to content

Commit

Permalink
handle near constant column
Browse files Browse the repository at this point in the history
  • Loading branch information
jitingxu1 committed Sep 17, 2024
1 parent 72eb5b3 commit b49543b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 32 deletions.
29 changes: 13 additions & 16 deletions ibis_ml/steps/_standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from collections.abc import Iterable

_DOCS_PAGE_NAME = "standardization"
# a small epsilon value to handle near-constant columns during normalization
_APPROX_EPS = 10e-7


class ScaleMinMax(Step):
Expand Down Expand Up @@ -61,21 +63,18 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
self._fit_expr = [expr]
results = expr.execute().to_dict("records")[0]
for name in columns:
col_max = results[f"{name}_max"]
col_min = results[f"{name}_min"]
if col_max == col_min:
raise ValueError(
f"Cannot standardize {name!r} - "
"the maximum and minimum values are equal"
)
stats[name] = (col_max, col_min)
stats[name] = (results[f"{name}_max"], results[f"{name}_min"])

self.stats_ = stats

def transform_table(self, table: ir.Table) -> ir.Table:
return table.mutate(
[
((table[c] - min) / (max - min)).name(c) # type: ignore
# for near-constant column, set the scale to 1.0
(
(table[c] - min)
/ (1.0 if abs(max - min) < _APPROX_EPS else max - min)
).name(c)
for c, (max, min) in self.stats_.items()
]
)
Expand Down Expand Up @@ -128,19 +127,17 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
self._fit_expr = [table.aggregate(aggs)]
results = self._fit_expr[-1].execute().to_dict("records")[0]
for name in columns:
col_std = results[f"{name}_std"]
if col_std == 0:
raise ValueError(
f"Cannot standardize {name!r} - the standard deviation is zero"
)
stats[name] = (results[f"{name}_mean"], col_std)
stats[name] = (results[f"{name}_mean"], results[f"{name}_std"])

self.stats_ = stats

def transform_table(self, table: ir.Table) -> ir.Table:
return table.mutate(
[
((table[c] - center) / scale).name(c) # type: ignore
# for near-constant column, set the scale to 1.0
(
(table[c] - center) / (1.0 if abs(scale) < _APPROX_EPS else scale)
).name(c)
for c, (center, scale) in self.stats_.items()
]
)
25 changes: 9 additions & 16 deletions tests/test_standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,12 @@ def test_scaleminmax():
tm.assert_frame_equal(result.execute(), expected, check_exact=False)


@pytest.mark.parametrize(
("model", "msg"),
[
("ScaleStandard", "Cannot standardize 'col' - the standard deviation is zero"),
(
"ScaleMinMax",
"Cannot standardize 'col' - the maximum and minimum values are equal",
),
],
)
def test_scale_unique_col(model, msg):
table = ibis.memtable({"col": [1]})
scale_class = getattr(ml, model)
step = scale_class("col")
with pytest.raises(ValueError, match=msg):
step.fit_table(table, ml.core.Metadata())
@pytest.mark.parametrize("scaler", ["ScaleStandard", "ScaleMinMax"])
def test_constant_columns(scaler):
table = ibis.memtable({"int_col": [100], "float_col": [100.0]})
scaler_class = getattr(ml, scaler)
scale_step = scaler_class(ml.numeric())
scale_step.fit_table(table, ml.core.Metadata())
result = scale_step.transform_table(table)
expected = pd.DataFrame({"int_col": [0.0], "float_col": [0.0]})
tm.assert_frame_equal(result.execute(), expected)

0 comments on commit b49543b

Please sign in to comment.