Skip to content

Commit

Permalink
Revert "GH-15856: Grid pipeline support (#16040)"
Browse files Browse the repository at this point in the history
This reverts commit b7ac670.
  • Loading branch information
valenad1 committed Mar 8, 2024
1 parent c751663 commit 2f7bd43
Show file tree
Hide file tree
Showing 22 changed files with 272 additions and 482 deletions.
12 changes: 0 additions & 12 deletions h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel.GLMParameters.Family;
import hex.glm.GLMModel.GLMParameters.Link;
import hex.grid.Grid;
import hex.util.EffectiveParametersUtils;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
Expand Down Expand Up @@ -1035,17 +1034,6 @@ public DistributionFamily getDistributionFamily() {
return familyToDistribution(_family);
}

@Override
public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) {
super.addSearchWarnings(searchFailure, grid);
if (ArrayUtils.contains(grid.getHyperNames(), "alpha")) {
// maybe we should find a way to raise this warning at the very beginning of grid search, similar to validation in ModelBuilder#init().
searchFailure.addWarning("Adding alpha array to hyperparameter runs slower with gridsearch. "+
"This is due to the fact that the algo has to run initialization for every alpha value. "+
"Setting the alpha array as a model parameter will skip the initialization and run faster overall.");
}
}

public void updateTweedieParams(double tweedieVariancePower, double tweedieLinkPower, double dispersion){
_tweedie_variance_power = tweedieVariancePower;
_tweedie_link_power = tweedieLinkPower;
Expand Down
21 changes: 12 additions & 9 deletions h2o-algos/src/test/java/hex/grid/GridTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hex.faulttolerance.Recovery;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.grid.HyperSpaceWalker.BaseWalker.WalkerFactory;
import hex.tree.CompressedTree;
import hex.tree.gbm.GBMModel;
import hex.tree.uplift.UpliftDRFModel;
Expand All @@ -18,6 +19,8 @@
import water.exceptions.H2OGridException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.test.dummy.DummyAction;
import water.test.dummy.DummyModelParameters;
import water.test.dummy.MessageInstallAction;

Expand Down Expand Up @@ -69,7 +72,7 @@ public void testParallelModelTimeConstraint() {

Job<Grid> gridSearch = GridSearch.startGridSearch(
null, params, hyperParms,
new SimpleParametersBuilderFactory(),
new GridSearch.SimpleParametersBuilderFactory(),
searchCriteria, 2
);

Expand Down Expand Up @@ -107,7 +110,7 @@ public void testParallelUserStopRequest() {

Job<Grid> gridSearch = GridSearch.startGridSearch(
dest, params, hyperParms,
new SimpleParametersBuilderFactory(),
new GridSearch.SimpleParametersBuilderFactory(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
2
);
Expand Down Expand Up @@ -367,7 +370,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException
Scope.track(trainingFrame);
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new SimpleParametersBuilderFactory(),
new GridSearch.SimpleParametersBuilderFactory(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery1, 1
);
Expand Down Expand Up @@ -397,7 +400,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException
null, gridKey,
loadedGrid1.getParams(),
loadedGrid1.getHyperParams(),
new SimpleParametersBuilderFactory(),
new GridSearch.SimpleParametersBuilderFactory(),
loadedGrid1.getSearchCriteria(),
recovery2,
loadedGrid1.getParallelism()
Expand Down Expand Up @@ -454,7 +457,7 @@ public void gridSearchWithRecoverySuccess() throws IOException, InterruptedExcep
Key gridKey = Key.make("gridSearchWithRecovery_GRID");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -540,7 +543,7 @@ public void gridSearchWithRecoveryCancelGBM() throws IOException, InterruptedExc
Key gridKey = Key.make("gridSearchWithRecovery_GRID");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -583,7 +586,7 @@ public void gridSearchWithRecoveryCancelGLM() throws IOException, InterruptedExc
Key gridKey = Key.make("gridSearchWithRecoveryGlm");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -847,7 +850,7 @@ public void test_parallel_random_search_with_max_models_being_less_than_parallel
params._train = trainingFrame._key;
params._response_column = "species";

SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory();
GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory();
HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
int custom_max_model = 2;
hyperSpaceSearchCriteria.set_max_models(custom_max_model);
Expand Down Expand Up @@ -883,7 +886,7 @@ public void test_parallel_random_search_with_max_models_being_greater_than_paral
params._train = trainingFrame._key;
params._response_column = "species";

SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory();
GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory();
HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
int custom_max_model = 3;
hyperSpaceSearchCriteria.set_max_models(custom_max_model);
Expand Down
6 changes: 3 additions & 3 deletions h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void test_SequentialWalker() {
new SequentialWalker<>(
gbmParameters,
hyperParams,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria()
),
GridSearch.SEQUENTIAL_MODEL_BUILDING
Expand Down Expand Up @@ -85,7 +85,7 @@ public void test_SequentialWalker_getHyperParams() {
SequentialWalker walker = new SequentialWalker<>(
gbmParameters,
hyperParams,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria()
);
Map<String, Object[]> exp = new HashMap<>();
Expand Down Expand Up @@ -124,7 +124,7 @@ public void test_SequentialWalker_supports_early_stopping() {
new SequentialWalker<>(
gbmParameters,
hyperParams,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria(StoppingCriteria.create()
.stoppingRounds(1)
.stoppingMetric(StoppingMetric.AUC)
Expand Down
7 changes: 5 additions & 2 deletions h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import hex.ModelContainer;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.*;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
Expand Down Expand Up @@ -69,7 +72,7 @@ protected <MP extends Model.Parameters> Job<Grid> startSearch(
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
baseParams,
hyperParams,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import hex.grid.HyperSpaceSearchCriteria.SequentialSearchCriteria;
import hex.grid.HyperSpaceSearchCriteria.StoppingCriteria;
import hex.grid.SequentialWalker;
import hex.grid.SimpleParametersBuilderFactory;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModel.XGBoostParameters;
import water.Job;
Expand Down Expand Up @@ -369,7 +368,7 @@ protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
new SequentialWalker<>(
params,
hyperParams,
new SimpleParametersBuilderFactory<>(),
new GridSearch.SimpleParametersBuilderFactory<>(),
new SequentialSearchCriteria(StoppingCriteria.create()
.maxRuntimeSecs((int)maxRuntimeSecs)
.stoppingMetric(params._stopping_metric)
Expand Down
9 changes: 0 additions & 9 deletions h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.*;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.quantile.QuantileModel;
import org.joda.time.DateTime;
import water.*;
Expand Down Expand Up @@ -803,14 +802,6 @@ private Parameters getDefaults() {
}
return _defaults;
}

/**
* callback called during grid search if it failed building a model with current parameters.
* When this is called, the failure instance is already extended with the last failure details/params.
* @param searchFailure
* @param grid
*/
public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) {}
}

public ModelMetrics addModelMetrics(final ModelMetrics mm) {
Expand Down
8 changes: 2 additions & 6 deletions h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ public interface ModelParametersBuilderFactory<MP extends Model.Parameters> {
* @return this parameters builder
*/
ModelParametersBuilder<MP> get(MP initialParams);



/**
* Returns mapping from input parameter specification to
Expand All @@ -40,10 +38,8 @@ public interface ModelParametersBuilderFactory<MP extends Model.Parameters> {
*
* @param <MP> type of produced model parameters object
*/
interface ModelParametersBuilder<MP extends Model.Parameters> {

boolean isAssignable(String name);

interface ModelParametersBuilder<MP extends Model.Parameters> {

ModelParametersBuilder<MP> set(String name, Object value);

MP build();
Expand Down

This file was deleted.

103 changes: 0 additions & 103 deletions h2o-core/src/main/java/hex/ModelParametersGenericBuilderFactory.java

This file was deleted.

Loading

0 comments on commit 2f7bd43

Please sign in to comment.