Skip to content

Commit

Permalink
Revert "GH-15857: AutoML pipeline support (#16041)"
Browse files Browse the repository at this point in the history
This reverts commit 17fa9ee.
  • Loading branch information
mn-mikke committed Feb 27, 2024
1 parent 248cffb commit 46a3cd0
Show file tree
Hide file tree
Showing 17 changed files with 51 additions and 483 deletions.
61 changes: 6 additions & 55 deletions h2o-automl/src/main/java/ai/h2o/automl/AutoML.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@
import ai.h2o.automl.leaderboard.ModelProvider;
import ai.h2o.automl.leaderboard.ModelStep;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.leaderboard.*;
import hex.pipeline.DataTransformer;
import hex.pipeline.PipelineModel.PipelineParameters;
import hex.splitframe.ShuffleSplitFrame;
import org.apache.log4j.Logger;
import water.*;
import water.automl.api.schemas3.AutoMLV99;
import water.exceptions.H2OAutoMLException;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.nbhm.NonBlockingHashMap;
import water.util.*;

Expand Down Expand Up @@ -63,7 +61,7 @@ public enum Constraint {

private static final boolean verifyImmutability = true; // check that trainingFrame hasn't been messed with
private static final ThreadLocal<SimpleDateFormat> timestampFormatForKeys = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_HHmmss"));
private static final Logger log = Logger.getLogger(AutoML.class);
private static final Logger log = LoggerFactory.getLogger(AutoML.class);

private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(AutoML automl) {
final Key<AutoML> amlKey = automl._key;
Expand Down Expand Up @@ -168,11 +166,9 @@ public double[] getClassDistribution() {
private Vec[] _originalTrainingFrameVecs;
private String[] _originalTrainingFrameNames;
private long[] _originalTrainingFrameChecksums;
private transient Map<Key, String> _trackedKeys = new NonBlockingHashMap<>();
private transient NonBlockingHashMap<Key, String> _trackedKeys = new NonBlockingHashMap<>();
private transient ModelingStep[] _executionPlan;
private transient PreprocessingStep[] _preprocessing;
private transient PipelineParameters _pipelineParams;
private transient Map<String, Object[]> _pipelineHyperParams;
transient StepResultState[] _stepsResults;

private boolean _useAutoBlending;
Expand Down Expand Up @@ -222,7 +218,6 @@ public AutoML(Key<AutoML> key, Date startTime, AutoMLBuildSpec buildSpec) {
prepareData();
initLeaderboard();
initPreprocessing();
initPipeline();
_modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown);
} catch (Exception e) {
delete(); //cleanup potentially leaked keys
Expand Down Expand Up @@ -392,53 +387,11 @@ private void initLeaderboard() {
}
_leaderboard.setExtensionsProvider(createLeaderboardExtensionProvider(this));
}

private void initPipeline() {
final AutoMLBuildModels build = _buildSpec.build_models;
_pipelineParams = build.preprocessing == null || !build._pipelineEnabled ? null : new PipelineParameters();
if (_pipelineParams == null) return;
List<DataTransformer> transformers = new ArrayList<>();
Map<String, Object[]> hyperParams = new NonBlockingHashMap<>();
for (PreprocessingStepDefinition def : build.preprocessing) {
PreprocessingStep step = def.newPreprocessingStep(this);
transformers.addAll(Arrays.asList(step.pipelineTransformers()));
Map<String, Object[]> hp = step.pipelineTransformersHyperParams();
if (hp != null) hyperParams.putAll(hp);
}
if (transformers.isEmpty()) {
_pipelineParams = null;
_pipelineHyperParams = null;
} else {
_pipelineParams._transformers = transformers.toArray(new DataTransformer[0]);
_pipelineHyperParams = hyperParams;
}

//TODO: given that a transformer can reference a model (e.g. TE),
// and multiple transformers can refer
// to the same model,
// then we should be careful when deleting a transformer (resp. an entire pipeline)
// as we may delete sth that is still in use by another transformer (resp. pipeline).
// --> ref count?

//TODO: in AutoML, the same transformations are likely to occur on multiple (sometimes all) models,
// especially if the transformers parameters are not tuned.
// But it also depends if the transformers are context(CV)-sensitive (e.g. Target Encoding).
// See `CachingTransformer` for some thoughts about this.
}

PipelineParameters getPipelineParams() {
return _pipelineParams;
}

Map<String, Object[]> getPipelineHyperParams() {
return _pipelineHyperParams;
}

private void initPreprocessing() {
final AutoMLBuildModels build = _buildSpec.build_models;
_preprocessing = build.preprocessing == null || build._pipelineEnabled
_preprocessing = _buildSpec.build_models.preprocessing == null
? null
: Arrays.stream(build.preprocessing)
: Arrays.stream(_buildSpec.build_models.preprocessing)
.map(def -> def.newPreprocessingStep(this))
.toArray(PreprocessingStep[]::new);
}
Expand Down Expand Up @@ -538,11 +491,9 @@ public void run() {
eventLog().info(Stage.Workflow, "AutoML build started: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.start_time()))
.setNamedValue("start_epoch", _runCountdown.start_time(), EventLogEntry.epochFormat.get());
try {
Scope.enter();
learn();
} finally {
stop();
Scope.exit();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ public static final class AutoMLBuildModels extends Iced {
public double exploitation_ratio = -1;
public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters();
public PreprocessingStepDefinition[] preprocessing;
public boolean _pipelineEnabled = false; // currently used for testing until ready: to be removed
}

public static final class AutoMLCustomParameters extends Iced {
Expand Down
97 changes: 23 additions & 74 deletions h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,12 @@
import hex.ModelContainer;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.*;
import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import hex.ModelParametersDelegateBuilderFactory;
import hex.pipeline.PipelineModel.PipelineParameters;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import water.*;
import water.KeyGen.ConstantKeyGen;
import water.KeyGen.PatternKeyGen;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Countdown;
Expand All @@ -44,8 +37,6 @@
* Parent class defining common properties and common logic for actual {@link AutoML} training steps.
*/
public abstract class ModelingStep<M extends Model> extends Iced<ModelingStep> {

protected static final String PIPELINE_KEY_PREFIX = "Pipeline_";

protected enum SeedPolicy {
/** No seed will be used (= random). */
Expand All @@ -70,11 +61,20 @@ protected <MP extends Model.Parameters> Job<Grid> startSearch(
assert baseParams != null;
assert hyperParams.size() > 0;
assert searchCriteria != null;
GridSearch.Builder builder = makeGridBuilder(resultKey, baseParams, hyperParams, searchCriteria);
aml().trackKeys(builder.dest());
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" hyperparameter search")
applyPreprocessing(baseParams);
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" hyperparameter search")
.setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get());
return builder.start();
return GridSearch.create(
resultKey,
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
baseParams,
hyperParams,
new SimpleParametersBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
.withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures)
.start();
}

@SuppressWarnings("unchecked")
Expand All @@ -84,8 +84,11 @@ protected <MP extends Model.Parameters> Job<M> startModel(
) {
assert resultKey != null;
assert params != null;
ModelBuilder builder = makeBuilder(resultKey, params);
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" model training")
Job<M> job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description);
applyPreprocessing(params);
ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, (Key<Model>) resultKey);
builder._parms = params;
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" model training")
.setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get());
builder.init(false); // validate parameters
if (builder._messages.length > 0) {
Expand All @@ -99,38 +102,6 @@ protected <MP extends Model.Parameters> Job<M> startModel(
}
return builder.trainModelOnH2ONode();
}

protected <MP extends Model.Parameters> GridSearch.Builder makeGridBuilder(Key<Grid> resultKey,
MP baseParams,
Map<String, Object[]> hyperParams,
HyperSpaceSearchCriteria searchCriteria) {
applyPreprocessing(baseParams);
Model.Parameters finalParams = applyPipeline(resultKey, baseParams, hyperParams);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);
return GridSearch.create(
resultKey,
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
finalParams,
hyperParams,
new ModelParametersDelegateBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
.withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures);
}


protected <MP extends Model.Parameters> ModelBuilder makeBuilder(Key<M> resultKey, MP params) {
applyPreprocessing(params);
Model.Parameters finalParams = applyPipeline(resultKey, params, null);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);

Job<M> job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description);
ModelBuilder builder = ModelBuilder.make(finalParams.algoName(), job, (Key<Model>) resultKey);
builder._parms = finalParams;
builder._input_parms = finalParams.clone();
return builder;
}

private boolean validParameters(Model.Parameters parms, String[] fields) {
try {
Expand Down Expand Up @@ -389,6 +360,8 @@ protected void setCommonModelBuilderParams(Model.Parameters params) {
setClassBalancingParams(params);
params._custom_metric_func = buildSpec.build_control.custom_metric_func;

params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;

/** Using _main_model_time_budget_factor to determine if and how we should restrict the time for the main model.
Expand All @@ -401,8 +374,6 @@ protected void setCommonModelBuilderParams(Model.Parameters params) {
protected void setCrossValidationParams(Model.Parameters params) {
AutoMLBuildSpec buildSpec = aml().getBuildSpec();
params._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
params._fold_column = buildSpec.input_spec.fold_column;

if (buildSpec.input_spec.fold_column == null) {
Expand Down Expand Up @@ -442,29 +413,6 @@ protected void applyPreprocessing(Model.Parameters params) {
}
}

protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
if (aml().getPipelineParams() == null) return params;
PipelineParameters pparams = (PipelineParameters) aml().getPipelineParams().clone();
setCommonModelBuilderParams(pparams);
pparams._seed = params._seed;
pparams._max_runtime_secs = params._max_runtime_secs;
pparams._estimatorParams = params;
pparams._estimatorKeyGen = hyperParams == null
? new ConstantKeyGen(resultKey)
: new PatternKeyGen("{0}|s/"+PIPELINE_KEY_PREFIX+"//") // in case of grid, remove the Pipeline prefix to obtain the estimator key, this allows naming compatibility with the classic mode.
;
if (hyperParams != null) {
Map<String, Object[]> pipelineHyperParams = new HashMap<>();
for (Map.Entry<String, Object[]> e : hyperParams.entrySet()) {
pipelineHyperParams.put("estimator."+e.getKey(), e.getValue());
}
hyperParams.clear();
hyperParams.putAll(pipelineHyperParams);
hyperParams.putAll(aml().getPipelineHyperParams());
}
return pparams;
}

protected PreprocessingConfig getPreprocessingConfig() {
return new PreprocessingConfig();
}
Expand Down Expand Up @@ -690,6 +638,7 @@ protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters basePar
setSearchCriteria(searchCriteria, baseParms);

if (null == key) key = makeKey(_provider, true);
aml().trackKeys(key);

Log.debug("Hyperparameter search: " + _provider + ", time remaining (ms): " + aml().timeRemainingMs());
aml().eventLog().debug(Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0
Expand Down Expand Up @@ -809,7 +758,7 @@ protected Job<Models> startJob() {
final Key<Models> selectionKey = Key.make(key+"_select");
final EventLog selectionEventLog = EventLog.getOrMake(selectionKey);
// EventLog selectionEventLog = aml().eventLog();
final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog);
final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog);

{
result.delete_and_lock(job);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ StepResultState submit(ModelingStep step, Job parentJob) {
}
} catch (Exception e) {
resultState.addState(new StepResultState(step.getGlobalId(), e));
Log.err(e);
} finally {
step.onDone(job);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum ResultStatus {
enum Resolution {
sameAsMain, // resolves to the same state as the main step (ignoring other sub-step states).
optimistic, // success if any success, otherwise cancelled if any cancelled, otherwise failed if any failure, otherwise skipped.
pessimistic, // failed if any failure, otherwise cancelled if any cancelled, otherwise success it any success, otherwise skipped.
pessimistic, // failures if any failure, otherwise cancelled if any cancelled, otherwise success it any success, otherwise skipped.
}

private final String _id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,7 @@ protected PreprocessingConfig getPreprocessingConfig() {
return config;
}

@Override
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
return params; // no pipeline in SE, base models handle the transformations when making predictions.
}

@Override
@Override
@SuppressWarnings("unchecked")
public boolean canRun() {
Key<Model>[] keys = getBaseModels();
Expand Down Expand Up @@ -127,18 +122,13 @@ protected boolean hasDoppelganger(Key<Model>[] baseModelsKeys) {
protected abstract Key<Model>[] getBaseModels();

protected String getModelType(Key<Model> key) {
ModelingStep step = aml().session().getModelingStep(key);
// if (step != null) { // fixme: commenting out this for now, as it interprets XRT as a DRF (which it is) and breaks legacy tests. We might want to reconsider this distinction as XRT is often very similar to DRF and doesn't bring much diversity to SEs, and the best_of SEs currently almost always have these 2.
// return step.getAlgo().name();
// } else { // dirty case
String keyStr = key.toString();
int lookupStart = keyStr.startsWith(PIPELINE_KEY_PREFIX) ? PIPELINE_KEY_PREFIX.length() : 0;
return keyStr.substring(lookupStart, keyStr.indexOf('_', lookupStart));
// }
return keyStr.substring(0, keyStr.indexOf('_'));
}

protected boolean isStackedEnsemble(Key<Model> key) {
return Algo.StackedEnsemble.name().equals(getModelType(key));
ModelingStep step = aml().session().getModelingStep(key);
return step != null && step.getAlgo() == Algo.StackedEnsemble;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ public Map<String, Object[]> prepareSearchParameters() {
XGBoostParameters.Booster.gbtree,
XGBoostParameters.Booster.dart
});
// searchParams.put("_booster$weights", new Integer[] {2, 1});


searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1f, 10f, 100f});
searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1f});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package ai.h2o.automl.preprocessing;

import ai.h2o.automl.ModelingStep;
import hex.Model;
import hex.pipeline.DataTransformer;

import java.util.Map;

public interface PreprocessingStep<T> {

Expand Down Expand Up @@ -36,8 +34,4 @@ interface Completer extends Runnable {}
*/
void remove();

DataTransformer[] pipelineTransformers();

Map<String, Object[]> pipelineTransformersHyperParams();

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ai.h2o.automl.preprocessing;

import ai.h2o.automl.AutoML;
import hex.pipeline.DataTransformer;
import water.Iced;

public class PreprocessingStepDefinition extends Iced<PreprocessingStepDefinition> {
Expand Down
Loading

0 comments on commit 46a3cd0

Please sign in to comment.