Skip to content

Commit

Permalink
MAINT parameter validation for KBinsDiscretizer (scikit-learn#23804)
Browse files Browse the repository at this point in the history
Co-authored-by: Sangam Swadi K <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: jeremiedbb <[email protected]>
  • Loading branch information
4 people authored Jul 6, 2022
1 parent a0623ce commit 652a627
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 106 deletions.
55 changes: 18 additions & 37 deletions sklearn/preprocessing/_discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
# License: BSD


import numbers
from numbers import Integral
import numpy as np
import warnings

from . import OneHotEncoder

from ..base import BaseEstimator, TransformerMixin
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils.validation import check_array
from ..utils.validation import check_is_fitted
from ..utils.validation import check_random_state
from ..utils.validation import _check_feature_names_in
from ..utils.validation import check_scalar
from ..utils import _safe_indexing


Expand Down Expand Up @@ -152,6 +152,19 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
[ 0.5, 3.5, -1.5, 1.5]])
"""

_parameter_constraints = {
"n_bins": [Interval(Integral, 2, None, closed="left"), "array-like"],
"encode": [StrOptions({"onehot", "onehot-dense", "ordinal"})],
"strategy": [StrOptions({"uniform", "quantile", "kmeans"})],
"dtype": [type, None], # TODO: TypeOptions constraint,
"subsample": [
Interval(Integral, 1, None, closed="left"),
None,
Hidden(StrOptions({"warn"})),
],
"random_state": ["random_state"],
}

def __init__(
self,
n_bins=5,
Expand Down Expand Up @@ -187,6 +200,7 @@ def fit(self, X, y=None):
self : object
Returns the instance itself.
"""
self._validate_params()
X = self._validate_data(X, dtype="numeric")

supported_dtype = (np.float64, np.float32)
Expand Down Expand Up @@ -214,37 +228,18 @@ def fit(self, X, y=None):
FutureWarning,
)
else:
self.subsample = check_scalar(
self.subsample, "subsample", numbers.Integral, min_val=1
)
rng = check_random_state(self.random_state)
if n_samples > self.subsample:
subsample_idx = rng.choice(
n_samples, size=self.subsample, replace=False
)
X = _safe_indexing(X, subsample_idx)
elif self.strategy != "quantile" and isinstance(
self.subsample, numbers.Integral
):
elif self.strategy != "quantile" and isinstance(self.subsample, Integral):
raise ValueError(
f"Invalid parameter for `strategy`: {self.strategy}. "
'`subsample` must be used with `strategy="quantile"`.'
)

valid_encode = ("onehot", "onehot-dense", "ordinal")
if self.encode not in valid_encode:
raise ValueError(
"Valid options for 'encode' are {}. Got encode={!r} instead.".format(
valid_encode, self.encode
)
)
valid_strategy = ("uniform", "quantile", "kmeans")
if self.strategy not in valid_strategy:
raise ValueError(
"Valid options for 'strategy' are {}. "
"Got strategy={!r} instead.".format(valid_strategy, self.strategy)
)

n_features = X.shape[1]
n_bins = self._validate_n_bins(n_features)

Expand Down Expand Up @@ -313,21 +308,7 @@ def fit(self, X, y=None):
def _validate_n_bins(self, n_features):
"""Returns n_bins_, the number of bins per feature."""
orig_bins = self.n_bins
if isinstance(orig_bins, numbers.Number):
if not isinstance(orig_bins, numbers.Integral):
raise ValueError(
"{} received an invalid n_bins type. "
"Received {}, expected int.".format(
KBinsDiscretizer.__name__, type(orig_bins).__name__
)
)
if orig_bins < 2:
raise ValueError(
"{} received an invalid number "
"of bins. Received {}, expected at least 2.".format(
KBinsDiscretizer.__name__, orig_bins
)
)
if isinstance(orig_bins, Integral):
return np.full(n_features, orig_bins, dtype=int)

n_bins = check_array(orig_bins, dtype=int, copy=True, ensure_2d=False)
Expand Down
78 changes: 10 additions & 68 deletions sklearn/preprocessing/tests/test_discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,6 @@ def test_valid_n_bins():
assert KBinsDiscretizer(n_bins=2).fit(X).n_bins_.dtype == np.dtype(int)


def test_invalid_n_bins():
est = KBinsDiscretizer(n_bins=1)
err_msg = (
"KBinsDiscretizer received an invalid number of bins. Received 1, expected at"
" least 2."
)
with pytest.raises(ValueError, match=err_msg):
est.fit_transform(X)

est = KBinsDiscretizer(n_bins=1.1)
err_msg = (
"KBinsDiscretizer received an invalid n_bins type. Received float, expected"
" int."
)
with pytest.raises(ValueError, match=err_msg):
est.fit_transform(X)


def test_invalid_n_bins_array():
# Bad shape
n_bins = np.full((2, 4), 2.0)
Expand Down Expand Up @@ -149,17 +131,6 @@ def test_numeric_stability(i):
assert_array_equal(Xt_expected, Xt)


def test_invalid_encode_option():
est = KBinsDiscretizer(n_bins=[2, 3, 3, 3], encode="invalid-encode")
err_msg = (
r"Valid options for 'encode' are "
r"\('onehot', 'onehot-dense', 'ordinal'\). "
r"Got encode='invalid-encode' instead."
)
with pytest.raises(ValueError, match=err_msg):
est.fit(X)


def test_encode_options():
est = KBinsDiscretizer(n_bins=[2, 3, 3, 3], encode="ordinal").fit(X)
Xt_1 = est.transform(X)
Expand All @@ -183,17 +154,6 @@ def test_encode_options():
)


def test_invalid_strategy_option():
est = KBinsDiscretizer(n_bins=[2, 3, 3, 3], strategy="invalid-strategy")
err_msg = (
r"Valid options for 'strategy' are "
r"\('uniform', 'quantile', 'kmeans'\). "
r"Got strategy='invalid-strategy' instead."
)
with pytest.raises(ValueError, match=err_msg):
est.fit(X)


@pytest.mark.parametrize(
"strategy, expected_2bins, expected_3bins, expected_5bins",
[
Expand Down Expand Up @@ -389,17 +349,6 @@ def test_kbinsdiscretizer_subsample_invalid_strategy():
kbd.fit(X)


def test_kbinsdiscretizer_subsample_invalid_type():
X = np.array([-2, 1.5, -4, -1]).reshape(-1, 1)
kbd = KBinsDiscretizer(
n_bins=10, encode="ordinal", strategy="quantile", subsample="full"
)

msg = "subsample must be an instance of int, not str."
with pytest.raises(TypeError, match=msg):
kbd.fit(X)


# TODO: Remove in 1.3
def test_kbinsdiscretizer_subsample_warn():
X = np.random.rand(200001, 1).reshape(-1, 1)
Expand All @@ -410,28 +359,21 @@ def test_kbinsdiscretizer_subsample_warn():
kbd.fit(X)


@pytest.mark.parametrize("subsample", [0, int(2e5)])
def test_kbinsdiscretizer_subsample_values(subsample):
# TODO(1.3) remove
def test_kbinsdiscretizer_subsample_values():
X = np.random.rand(220000, 1).reshape(-1, 1)
kbd_default = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")

kbd_with_subsampling = clone(kbd_default)
kbd_with_subsampling.set_params(subsample=subsample)
kbd_with_subsampling.set_params(subsample=int(2e5))

if subsample == 0:
with pytest.raises(ValueError, match="subsample == 0, must be >= 1."):
kbd_with_subsampling.fit(X)
else:
# TODO: Remove in 1.3
msg = "In version 1.3 onwards, subsample=2e5 will be used by default."
with pytest.warns(FutureWarning, match=msg):
kbd_default.fit(X)

kbd_with_subsampling.fit(X)
assert not np.all(
kbd_default.bin_edges_[0] == kbd_with_subsampling.bin_edges_[0]
)
assert kbd_default.bin_edges_.shape == kbd_with_subsampling.bin_edges_.shape
msg = "In version 1.3 onwards, subsample=2e5 will be used by default."
with pytest.warns(FutureWarning, match=msg):
kbd_default.fit(X)

kbd_with_subsampling.fit(X)
assert not np.all(kbd_default.bin_edges_[0] == kbd_with_subsampling.bin_edges_[0])
assert kbd_default.bin_edges_.shape == kbd_with_subsampling.bin_edges_.shape


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"Isomap",
"IsotonicRegression",
"IterativeImputer",
"KBinsDiscretizer",
"KNNImputer",
"KNeighborsTransformer",
"KernelPCA",
Expand Down

0 comments on commit 652a627

Please sign in to comment.