Skip to content

Commit

Permalink
fix(EstimatorReport): Make a deep copy of fitted estimator in constru…
Browse files Browse the repository at this point in the history
…ctor to avoid side-effect (#1085)

Co-authored-by: Guillaume Lemaitre <[email protected]>
  • Loading branch information
augustebaum and glemaitre authored Jan 10, 2025
1 parent 7e03c9c commit b177651
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
24 changes: 21 additions & 3 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import time
import warnings
from itertools import product

import joblib
Expand Down Expand Up @@ -27,7 +29,8 @@ class EstimatorReport(_HelpMixin, DirNamesMixin):
Parameters
----------
estimator : estimator object
Estimator to make report from.
Estimator to make the report from. When the estimator is not fitted,
it is deep-copied to avoid side-effects. If it is fitted, it is cloned instead.
fit : {"auto", True, False}, default="auto"
Whether to fit the estimator on the training data. If "auto", the estimator
Expand Down Expand Up @@ -79,6 +82,21 @@ def _fit_estimator(estimator, X_train, y_train):
)
return clone(estimator).fit(X_train, y_train)

@classmethod
def _copy_estimator(cls, estimator):
try:
return copy.deepcopy(estimator)
except Exception as e:
warnings.warn(
"Deepcopy failed; using estimator as-is. "
"Be aware that modifying the estimator outside of "
f"{cls.__name__} will modify the internal estimator. "
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
f"Original error: {e}",
stacklevel=1,
)
return estimator

def __init__(
self,
estimator,
Expand All @@ -92,13 +110,13 @@ def __init__(
if fit == "auto":
try:
check_is_fitted(estimator)
self._estimator = estimator
self._estimator = self._copy_estimator(estimator)
except NotFittedError:
self._estimator = self._fit_estimator(estimator, X_train, y_train)
elif fit is True:
self._estimator = self._fit_estimator(estimator, X_train, y_train)
else: # fit is False
self._estimator = estimator
self._estimator = self._copy_estimator(estimator)

# private storage to be able to invalidate the cache when the user alters
# those attributes
Expand Down
53 changes: 51 additions & 2 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
estimator, X, y = binary_classification_data
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, RandomForestClassifier)
assert report.X_train is None
assert report.y_train is None
assert report.X_test is X
Expand All @@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
estimator, X, y = binary_classification_data_pipeline
report = EstimatorReport(estimator, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, Pipeline)
assert report.estimator_name == estimator[-1].__class__.__name__
assert report.X_train is None
assert report.y_train is None
Expand Down Expand Up @@ -925,3 +927,50 @@ def test_estimator_report_get_X_y_and_data_source_hash(data_source):
assert X is X_test
assert y is y_test
assert data_source_hash == joblib.hash((X_test, y_test))


@pytest.mark.parametrize("prefit_estimator", [True, False])
def test_estimator_has_side_effects(prefit_estimator):
"""Re-fitting the estimator outside the EstimatorReport
should not have an effect on the EstimatorReport's internal estimator."""
X, y = make_classification(n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression()
if prefit_estimator:
estimator.fit(X_train, y_train)

report = EstimatorReport(
estimator,
X_train=X_train,
X_test=X_test,
y_train=y_train,
y_test=y_test,
)

predictions_before = report.estimator.predict_proba(X_test)
estimator.fit(X_test, y_test)
predictions_after = report.estimator.predict_proba(X_test)
np.testing.assert_array_equal(predictions_before, predictions_after)


def test_estimator_has_no_deep_copy():
"""Check that we raise a warning if the deep copy failed with a fitted
estimator."""
X, y = make_classification(n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression()
# Make it so deepcopy does not work
estimator.__reduce_ex__ = None
estimator.__reduce__ = None

with pytest.warns(UserWarning, match="Deepcopy failed"):
EstimatorReport(
estimator,
fit=False,
X_train=X_train,
X_test=X_test,
y_train=y_train,
y_test=y_test,
)

0 comments on commit b177651

Please sign in to comment.