Skip to content

Commit

Permalink
feat(core): implement a more robust set_params()
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Sep 16, 2024
1 parent 1c8706e commit 7b00324
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 4 deletions.
106 changes: 102 additions & 4 deletions ibis_ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,44 @@ def get_params(self, deep=True) -> dict[str, Any]:
"""
return {key: getattr(self, key) for key in self._get_param_names()}

def set_params(self, **params):
    """Set the parameters of this estimator.

    Every key is validated before any attribute is assigned, so a call
    that raises leaves the step unchanged.

    Parameters
    ----------
    **params : dict
        Step parameters.

    Returns
    -------
    self : object
        Step class instance.

    Raises
    ------
    ValueError
        If a key is not among the names returned by
        ``self._get_param_names()``.

    Notes
    -----
    Derived from [1]_.

    References
    ----------
    .. [1] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    valid_params = self._get_param_names()

    # Validate first: rejecting a bad key only after earlier keys were
    # already assigned would leave the step half-updated.
    for key in params:
        if key not in valid_params:
            raise ValueError(
                f"Invalid parameter {key!r} for estimator {self}. "
                f"Valid parameters are: {valid_params!r}."
            )

    for key, value in params.items():
        setattr(self, key, value)

    return self

def __repr__(self) -> str:
    """Return the ``pprint.pformat`` rendering of this object."""
    # NOTE(review): presumably a pprint dispatch entry for this type is
    # registered elsewhere; otherwise pformat would call back into
    # __repr__ - confirm.
    rendered = pprint.pformat(self)
    return rendered

Expand Down Expand Up @@ -453,7 +491,7 @@ def _name_estimators(estimators):

class Recipe:
def __init__(self, *steps: Step):
    """Create a recipe from the given steps.

    Parameters
    ----------
    *steps : Step
        Steps of the recipe, kept in the order given.
    """
    # Keep a list rather than the varargs tuple: ``set_params`` replaces
    # individual steps by index, which a tuple would not allow.
    self.steps = list(steps)
    self._output_format = "default"

def __repr__(self):
Expand Down Expand Up @@ -502,9 +540,69 @@ def get_params(self, deep=True) -> dict[str, Any]:
out[f"{name}__{key}"] = value
return out

def set_params(self, **params):
    """Set the parameters of this estimator.

    Valid parameter keys can be listed with ``get_params()``. Note that
    you can directly set the parameters of the estimators contained in
    `steps`.

    Parameters
    ----------
    **params : dict
        Parameters of this estimator or parameters of estimators contained
        in `steps`. Parameters of the steps may be set using its name and
        the parameter name separated by a '__'.

    Returns
    -------
    self : object
        Recipe class instance.

    Raises
    ------
    ValueError
        If a key is not ``steps``, not the name of a current step, and not
        of the form ``<name>__<parameter>`` with ``<name>`` a key of
        ``get_params(deep=True)``.

    Notes
    -----
    Derived from [1]_ and [2]_.

    References
    ----------
    .. [1] https://github.com/scikit-learn/scikit-learn/blob/ff1c6f3/sklearn/utils/metaestimators.py#L51-L70
    .. [2] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    # Ensure strict ordering of parameter setting:
    # 1. All steps
    if "steps" in params:
        # Copy into a list, as ``__init__`` does: step 2 assigns by index,
        # which would fail on a tuple and would otherwise mutate the
        # caller's own list.
        self.steps = list(params.pop("steps"))

    # 2. Replace items with estimators in params
    estimator_name_indexes = {
        name: index
        for index, (name, _) in enumerate(_name_estimators(self.steps))
    }
    for name in list(params):
        if "__" not in name and name in estimator_name_indexes:
            self.steps[estimator_name_indexes[name]] = params.pop(name)

    # 3. Step parameters and other initialisation arguments
    # Looked up after the replacements above so that nested parameters are
    # applied to the replacement steps, not the ones they replaced.
    valid_params = self.get_params(deep=True)

    nested_params = defaultdict(dict)  # grouped by prefix
    for key, value in params.items():
        # ``partition`` never fails to unpack; a key without the "__"
        # delimiter gets the same explicit error as an unknown prefix
        # instead of an opaque tuple-unpacking error.
        prefix, delim, sub_key = key.partition("__")
        if not delim or prefix not in valid_params:
            raise ValueError(
                f"Invalid parameter {prefix!r} for estimator {self}. "
                "Valid parameters are: ['steps']."
            )

        nested_params[prefix][sub_key] = value

    # Hand each step all of its parameters in a single call.
    for prefix, sub_params in nested_params.items():
        valid_params[prefix].set_params(**sub_params)

    return self

def set_output(
self,
Expand Down
51 changes: 51 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import ibis
import ibis.expr.types as ir
import numpy as np
Expand Down Expand Up @@ -372,6 +374,55 @@ def test_get_params():
assert "expanddatetime__components" not in rec.get_params(deep=False)


def test_set_params():
    """Unknown step parameters and unknown step names are both rejected."""
    recipe = ml.Recipe(ml.ExpandDateTime(ml.timestamp()))

    # Nonexistent parameter in step
    step_error = "Invalid parameter 'nonexistent_param' for estimator ExpandDateTime"
    with pytest.raises(ValueError, match=step_error):
        recipe.set_params(expanddatetime__nonexistent_param=True)

    # Nonexistent parameter of pipeline
    recipe_error = "Invalid parameter 'expandtimestamp' for estimator Recipe"
    with pytest.raises(ValueError, match=recipe_error):
        recipe.set_params(expandtimestamp__nonexistent_param=True)


def test_set_params_passes_all_parameters():
    """All parameters of one nested step arrive in a single set_params call."""
    recipe = ml.Recipe(ml.ExpandDateTime(ml.timestamp()))
    nested_kwargs = {
        "expanddatetime__inputs": ["x", "y"],
        "expanddatetime__components": ["day", "year", "hour"],
    }

    with patch.object(ml.ExpandDateTime, "set_params") as mock_set_params:
        recipe.set_params(**nested_kwargs)

    mock_set_params.assert_called_once_with(
        inputs=["x", "y"], components=["day", "year", "hour"]
    )


def test_set_params_updates_valid_params():
    """Nested parameters are applied to the replacement step, not the old one.

    The same call both swaps out the ``mutateat-1`` step and sets
    ``mutateat-1__inputs``; the new ``inputs`` must land on the replacement.
    """
    first_step = ml.MutateAt("dep_time", ibis._.hour() * 60 + ibis._.minute())  # noqa: SLF001
    second_step = ml.MutateAt(ml.timestamp(), ibis._.epoch_seconds())  # noqa: SLF001
    recipe = ml.Recipe(first_step, second_step)

    new_first_step = ml.MutateAt("arr_time", ibis._.hour() * 60 + ibis._.minute())  # noqa: SLF001
    updates = {
        "mutateat-1": new_first_step,
        "mutateat-1__inputs": ml.cols("arrival"),
    }
    recipe.set_params(**updates)

    assert first_step.inputs == ml.cols("dep_time")
    assert new_first_step.inputs == ml.cols("arrival")
    assert recipe.steps[0] is new_first_step


@pytest.mark.parametrize(
("step", "url"),
[
Expand Down

0 comments on commit 7b00324

Please sign in to comment.