diff --git a/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java b/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java index bf17ce7a5bc7..ccd0faa36e1d 100644 --- a/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java +++ b/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java @@ -1,9 +1,6 @@ package hex.Infogram; import hex.*; -import hex.Infogram.InfogramModel.InfogramModelOutput; -import hex.Infogram.InfogramModel.InfogramParameters; -import hex.ModelMetrics.MetricBuilder; import water.*; import water.exceptions.H2OModelBuilderIllegalArgumentException; import water.fvec.Frame; @@ -21,8 +18,8 @@ import static water.util.ArrayUtils.sort; import static water.util.ArrayUtils.sum; -public class Infogram extends ModelBuilder { +public class Infogram extends ModelBuilder { static final double NORMALIZE_ADMISSIBLE_INDEX = 1.0/Math.sqrt(2.0); boolean _buildCore; // true to find core predictors, false to find admissible predictors String[] _topKPredictors; // contain the names of top predictors to consider for infogram @@ -48,14 +45,14 @@ public class Infogram extends ModelBuilder key) { + public Infogram(hex.Infogram.InfogramModel.InfogramParameters parms, Key key) { super(parms, key); init(false); } @@ -74,23 +71,18 @@ protected int nModelsInParallel(int folds) { * This is called before cross-validation is carried out */ @Override - protected void cv_init() { - super.cv_init(); + public void computeCrossValidation() { info("cross-validation", "cross-validation infogram information is stored in frame with key" + " labeled as admissible_score_key_cv and the admissible features in admissible_features_cv."); if (error_count() > 0) { throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(Infogram.this); } - } - - @Override - protected MetricBuilder makeCVMetricBuilder(ModelBuilder cvModelBuilder, Futures fs) { - return null; //infogram does not support scoring + super.computeCrossValidation(); } // find the best alpha/lambda values used to build the main model moving forward by looking at the devianceValid @Override - protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { + public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { int nBuilders = cvModelBuilders.length; double[][] cmiRaw = new double[nBuilders][]; List> columns = new ArrayList<>(); @@ -111,12 +103,7 @@ protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) } _cvDone = true; // cv is done and we are going to build main model next } - - @Override - protected void cv_mainModelScores(int N, MetricBuilder[] mbs, ModelBuilder[] cvModelBuilders) { - //infogram does not support scoring - } - + public void calculateMeanInfogramInfo(double[][] cmiRaw, List> columns, long[] nObs) { int nFolds = cmiRaw.length; @@ -317,7 +304,7 @@ public final void buildModel() { try { boolean validPresent = _parms.valid() != null; prepareModelTrainingFrame(); // generate training frame with predictors and sensitive features (if specified) - InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new InfogramModelOutput(Infogram.this)); + InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new hex.Infogram.InfogramModel.InfogramModelOutput(Infogram.this)); _model = model.delete_and_lock(_job); _model._output._start_time = System.currentTimeMillis(); _cmiRaw = new double[_numModels]; @@ -372,7 +359,7 @@ public final void buildModel() { * relevance >= relevance_threshold. 
Derive _admissible_index as distance from point with cmi = 1 and * relevance = 1. In addition, all arrays are sorted on _admissible_index. */ - private void copyCMIRelevance(InfogramModelOutput modelOutput) { + private void copyCMIRelevance(InfogramModel.InfogramModelOutput modelOutput) { modelOutput._cmi_raw = new double[_cmi.length]; System.arraycopy(_cmiRaw, 0, modelOutput._cmi_raw, 0, modelOutput._cmi_raw.length); modelOutput._admissible_index = new double[_cmi.length]; @@ -388,7 +375,7 @@ private void copyCMIRelevance(InfogramModelOutput modelOutput) { modelOutput._admissible_index, modelOutput._admissible, modelOutput._all_predictor_names); } - public void copyCMIRelevanceValid(InfogramModelOutput modelOutput) { + public void copyCMIRelevanceValid(InfogramModel.InfogramModelOutput modelOutput) { modelOutput._cmi_raw_valid = new double[_cmiValid.length]; System.arraycopy(_cmiRawValid, 0, modelOutput._cmi_raw_valid, 0, modelOutput._cmi_raw_valid.length); modelOutput._admissible_index_valid = new double[_cmiValid.length]; diff --git a/h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java b/h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java index 6f77c5f5b879..9093bac63225 100755 --- a/h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java +++ b/h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java @@ -165,7 +165,7 @@ static DataInfo makeDataInfo(Frame train, Frame valid, DeepLearningParameters pa } } - @Override protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { + @Override public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { _parms._overwrite_with_best_model = false; if( _parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0) return; // No exciting changes to stopping conditions diff --git a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java index d09f8712919c..a49202eda204 100644 --- a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java +++ b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java @@ -580,7 +580,7 @@ public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) { @Override public ModelMetrics.MetricBuilder getMetricBuilder() { throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, " + - "retrieve your metrics by calling makeModelMetrics method."); + "retrieve your metrics by calling getOrMakeMetrics method."); } } diff --git a/h2o-algos/src/main/java/hex/glm/GLM.java b/h2o-algos/src/main/java/hex/glm/GLM.java index 7c7a9d835a64..433f8c058dc5 100644 --- a/h2o-algos/src/main/java/hex/glm/GLM.java +++ b/h2o-algos/src/main/java/hex/glm/GLM.java @@ -26,6 +26,7 @@ import hex.util.LinearAlgebraUtils; import hex.util.LinearAlgebraUtils.BMulTask; import hex.util.LinearAlgebraUtils.FindMaxIndex; +import jsr166y.CountedCompleter; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import water.*; @@ -118,8 +119,7 @@ public boolean isSupervised() { public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Regression, - ModelCategory.Binomial, - ModelCategory.Multinomial + ModelCategory.Binomial, }; } @@ -148,12 +148,13 @@ public ModelCategory[] can_build() { * (builds N+1 models, all have train+validation metrics, the main model has N-fold cross-validated validation metrics) */ @Override - protected void cv_init() { + public void computeCrossValidation() { // init computes global 
list of lambdas init(true); _cvRuns = true; if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GLM.this); + super.computeCrossValidation(); } @@ -292,7 +293,7 @@ private double[] alignSubModelsAcrossCVModels(ModelBuilder[] cvModelBuilders) { * 4. unlock the n-folds models (they are changed here, so the unlocking happens here) */ @Override - protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { + public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { setMaxRuntimeSecsForMainModel(); double bestTestDev = Double.POSITIVE_INFINITY; double[] alphasAndLambdas = alignSubModelsAcrossCVModels(cvModelBuilders); @@ -371,6 +372,12 @@ protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) break; } } + for (int i = 0; i < cvModelBuilders.length; ++i) { + GLM g = (GLM) cvModelBuilders[i]; + if (g._toRemove != null) + for (Key k : g._toRemove) + Keyed.remove(k); + } for (int i = 0; i < cvModelBuilders.length; ++i) { GLM g = (GLM) cvModelBuilders[i]; @@ -1536,11 +1543,11 @@ private void buildModel() { protected static final long WORK_TOTAL = 1000000; - @Override - protected void cleanUp() { - if (_parms._lambda_search && _parms._is_cv_model) - keepUntilCompletion(_dinfo.getWeightsVec()._key); - super.cleanUp(); + transient Key [] _toRemove; + + private Key[] removeLater(Key ...k){ + _toRemove = _toRemove == null?k:ArrayUtils.append(_toRemove,k); + return k; } @Override protected GLMDriver trainModelImpl() { return _driver = new GLMDriver(); } @@ -1569,6 +1576,23 @@ public final class GLMDriver extends Driver implements ProgressMonitor { private transient GLMTask.GLMIterationTask _gramInfluence; private transient double[][] _cholInvInfluence; + private void doCleanup() { + try { + if (_parms._lambda_search && _parms._is_cv_model) + Scope.untrack(removeLater(_dinfo.getWeightsVec()._key)); + if (_parms._HGLM) { + Key[] vecKeys = _toRemove; + for (int index = 0; index < vecKeys.length; index++) { + Vec tempVec = DKV.getGet(vecKeys[index]); + tempVec.remove(); + } + } + } catch (Exception e) { + Log.err("Error while cleaning up GLM " + _result); + Log.err(e); + } + } + private transient Cholesky _chol; private transient L1Solver _lslvr; @@ -3540,8 +3564,9 @@ private Vec[] genGLMVectors(DataInfo dinfo, double[] nb) { sumExp += Math.exp(nb[i * N + P] - maxRow); } Vec[] vecs = dinfo._adaptedFrame.anyVec().makeDoubles(2, new double[]{sumExp, maxRow}); - if (_parms._lambda_search) { - track(vecs[0]); track(vecs[1]); + if (_parms._lambda_search && _parms._is_cv_model) { + Scope.untrack(vecs[0]._key, vecs[1]._key); + removeLater(vecs[0]._key, vecs[1]._key); } return vecs; } @@ -3823,7 +3848,7 @@ private void checkCoeffsBounds() { * - column 2: zi, intermediate values * - column 3: eta = X*beta, intermediate values */ - private void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM + public void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM int moreColnum = 3 + _parms._random_columns.length; Vec[] vecs = _dinfo._adaptedFrame.anyVec().makeZeros(moreColnum); String[] colNames = new String[moreColnum]; @@ -3836,11 +3861,25 @@ private void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response vecs[index] = _parms.train().vec(randColIndices[index - 3]).makeCopy(); } _dinfo.addResponse(colNames, vecs); - Frame wdataZiEta = new Frame(Key.make("wdataZiEta"+Key.rand()), colNames, vecs); - DKV.put(wdataZiEta); - track(wdataZiEta); + for (int index = 
0; index < moreColnum; index++) { + Scope.untrack(vecs[index]._key); + removeLater(vecs[index]._key); + } + } + + @Override + public void onCompletion(CountedCompleter caller) { + doCleanup(); + super.onCompletion(caller); } + @Override + public boolean onExceptionalCompletion(Throwable t, CountedCompleter caller) { + doCleanup(); + return super.onExceptionalCompletion(t, caller); + } + + @Override public boolean progress(double[] beta, GradientInfo ginfo) { _state._iter++; diff --git a/h2o-algos/src/main/java/hex/glm/GLMModel.java b/h2o-algos/src/main/java/hex/glm/GLMModel.java index 613ca20acbeb..f4b71c1a1809 100755 --- a/h2o-algos/src/main/java/hex/glm/GLMModel.java +++ b/h2o-algos/src/main/java/hex/glm/GLMModel.java @@ -7,7 +7,6 @@ import hex.genmodel.utils.DistributionFamily; import hex.glm.GLMModel.GLMParameters.Family; import hex.glm.GLMModel.GLMParameters.Link; -import hex.grid.Grid; import hex.util.EffectiveParametersUtils; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.RealDistribution; @@ -1035,17 +1034,6 @@ public DistributionFamily getDistributionFamily() { return familyToDistribution(_family); } - @Override - public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) { - super.addSearchWarnings(searchFailure, grid); - if (ArrayUtils.contains(grid.getHyperNames(), "alpha")) { - // maybe we should find a way to raise this warning at the very beginning of grid search, similar to validation in ModelBuilder#init(). - searchFailure.addWarning("Adding alpha array to hyperparameter runs slower with gridsearch. "+ - "This is due to the fact that the algo has to run initialization for every alpha value. "+ - "Setting the alpha array as a model parameter will skip the initialization and run faster overall."); - } - } - public void updateTweedieParams(double tweedieVariancePower, double tweedieLinkPower, double dispersion){ _tweedie_variance_power = tweedieVariancePower; _tweedie_link_power = tweedieLinkPower; diff --git a/h2o-algos/src/main/java/hex/kmeans/KMeans.java b/h2o-algos/src/main/java/hex/kmeans/KMeans.java index 55a5b0db3c3c..4c69246bb41b 100755 --- a/h2o-algos/src/main/java/hex/kmeans/KMeans.java +++ b/h2o-algos/src/main/java/hex/kmeans/KMeans.java @@ -109,7 +109,7 @@ else if( user_points.numRows() != _parms._k) if (expensive && error_count() == 0) checkMemoryFootPrint(); } - protected void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){ + public void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){ super.cv_makeAggregateModelMetrics(mbs); ((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._within_sumsqe = null; ((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._size = null; diff --git a/h2o-algos/src/main/java/hex/tree/SharedTree.java b/h2o-algos/src/main/java/hex/tree/SharedTree.java index 68e716654533..43ea630ae210 100755 --- a/h2o-algos/src/main/java/hex/tree/SharedTree.java +++ b/h2o-algos/src/main/java/hex/tree/SharedTree.java @@ -1203,7 +1203,7 @@ public double initialValue() { return _parms._parallel_main_model_building; } - @Override protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { + @Override public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { // Extract stopping conditions from each CV model, and compute the best stopping answer if (!cv_initStoppingParameters()) return; // No exciting changes to stopping conditions diff --git a/h2o-algos/src/test/java/hex/grid/GridTest.java 
b/h2o-algos/src/test/java/hex/grid/GridTest.java index a5e392be1584..21c58a7f07bd 100644 --- a/h2o-algos/src/test/java/hex/grid/GridTest.java +++ b/h2o-algos/src/test/java/hex/grid/GridTest.java @@ -6,6 +6,7 @@ import hex.faulttolerance.Recovery; import hex.genmodel.utils.DistributionFamily; import hex.glm.GLMModel; +import hex.grid.HyperSpaceWalker.BaseWalker.WalkerFactory; import hex.tree.CompressedTree; import hex.tree.gbm.GBMModel; import hex.tree.uplift.UpliftDRFModel; @@ -18,6 +19,8 @@ import water.exceptions.H2OGridException; import water.fvec.Frame; import water.fvec.Vec; +import water.parser.BufferedString; +import water.test.dummy.DummyAction; import water.test.dummy.DummyModelParameters; import water.test.dummy.MessageInstallAction; @@ -69,7 +72,7 @@ public void testParallelModelTimeConstraint() { Job gridSearch = GridSearch.startGridSearch( null, params, hyperParms, - new SimpleParametersBuilderFactory(), + new GridSearch.SimpleParametersBuilderFactory(), searchCriteria, 2 ); @@ -107,7 +110,7 @@ public void testParallelUserStopRequest() { Job gridSearch = GridSearch.startGridSearch( dest, params, hyperParms, - new SimpleParametersBuilderFactory(), + new GridSearch.SimpleParametersBuilderFactory(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), 2 ); @@ -367,7 +370,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException Scope.track(trainingFrame); Job gs = GridSearch.startGridSearch( null, gridKey, params, hyperParms, - new SimpleParametersBuilderFactory(), + new GridSearch.SimpleParametersBuilderFactory(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), recovery1, 1 ); @@ -397,7 +400,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException null, gridKey, loadedGrid1.getParams(), loadedGrid1.getHyperParams(), - new SimpleParametersBuilderFactory(), + new GridSearch.SimpleParametersBuilderFactory(), loadedGrid1.getSearchCriteria(), recovery2, loadedGrid1.getParallelism() @@ -454,7 +457,7 @@ public void gridSearchWithRecoverySuccess() throws IOException, InterruptedExcep Key gridKey = Key.make("gridSearchWithRecovery_GRID"); Job gs = GridSearch.startGridSearch( null, gridKey, params, hyperParms, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING ); @@ -540,7 +543,7 @@ public void gridSearchWithRecoveryCancelGBM() throws IOException, InterruptedExc Key gridKey = Key.make("gridSearchWithRecovery_GRID"); Job gs = GridSearch.startGridSearch( null, gridKey, params, hyperParms, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING ); @@ -583,7 +586,7 @@ public void gridSearchWithRecoveryCancelGLM() throws IOException, InterruptedExc Key gridKey = Key.make("gridSearchWithRecoveryGlm"); Job gs = GridSearch.startGridSearch( null, gridKey, params, hyperParms, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING ); @@ -847,7 +850,7 @@ public void test_parallel_random_search_with_max_models_being_less_than_parallel params._train = trainingFrame._key; params._response_column = "species"; - SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory(); + 
GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory(); HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria(); int custom_max_model = 2; hyperSpaceSearchCriteria.set_max_models(custom_max_model); @@ -883,7 +886,7 @@ public void test_parallel_random_search_with_max_models_being_greater_than_paral params._train = trainingFrame._key; params._response_column = "species"; - SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory(); + GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory(); HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria(); int custom_max_model = 3; hyperSpaceSearchCriteria.set_max_models(custom_max_model); diff --git a/h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java b/h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java index 0f25712489f2..ad47775d35ec 100644 --- a/h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java +++ b/h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java @@ -47,7 +47,7 @@ public void test_SequentialWalker() { new SequentialWalker<>( gbmParameters, hyperParams, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.SequentialSearchCriteria() ), GridSearch.SEQUENTIAL_MODEL_BUILDING @@ -85,7 +85,7 @@ public void test_SequentialWalker_getHyperParams() { SequentialWalker walker = new SequentialWalker<>( gbmParameters, hyperParams, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.SequentialSearchCriteria() ); Map exp = new HashMap<>(); @@ -124,7 +124,7 @@ public void test_SequentialWalker_supports_early_stopping() { new SequentialWalker<>( gbmParameters, hyperParams, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.SequentialSearchCriteria(StoppingCriteria.create() .stoppingRounds(1) .stoppingMetric(StoppingMetric.AUC) diff --git a/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java b/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java index a7b7e9c38f9b..2d0332d60b09 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java @@ -11,22 +11,20 @@ import ai.h2o.automl.leaderboard.ModelGroup; import ai.h2o.automl.leaderboard.ModelProvider; import ai.h2o.automl.leaderboard.ModelStep; -import ai.h2o.automl.preprocessing.PipelineStep; -import ai.h2o.automl.preprocessing.PipelineStepDefinition; +import ai.h2o.automl.preprocessing.PreprocessingStep; import hex.Model; import hex.ScoreKeeper.StoppingMetric; import hex.genmodel.utils.DistributionFamily; import hex.leaderboard.*; -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineModel.PipelineParameters; import hex.splitframe.ShuffleSplitFrame; -import org.apache.log4j.Logger; import water.*; import water.automl.api.schemas3.AutoMLV99; import water.exceptions.H2OAutoMLException; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.fvec.Vec; +import water.logging.Logger; +import water.logging.LoggerFactory; import water.nbhm.NonBlockingHashMap; import water.util.*; @@ -63,7 +61,7 @@ public enum 
Constraint { private static final boolean verifyImmutability = true; // check that trainingFrame hasn't been messed with private static final ThreadLocal timestampFormatForKeys = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_HHmmss")); - private static final Logger log = Logger.getLogger(AutoML.class); + private static final Logger log = LoggerFactory.getLogger(AutoML.class); private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(AutoML automl) { final Key amlKey = automl._key; @@ -168,10 +166,9 @@ public double[] getClassDistribution() { private Vec[] _originalTrainingFrameVecs; private String[] _originalTrainingFrameNames; private long[] _originalTrainingFrameChecksums; - private transient Map _trackedKeys = new NonBlockingHashMap<>(); + private transient NonBlockingHashMap _trackedKeys = new NonBlockingHashMap<>(); private transient ModelingStep[] _executionPlan; - private transient PipelineParameters _pipelineParams; - private transient Map _pipelineHyperParams; + private transient PreprocessingStep[] _preprocessing; transient StepResultState[] _stepsResults; private boolean _useAutoBlending; @@ -220,7 +217,7 @@ public AutoML(Key key, Date startTime, AutoMLBuildSpec buildSpec) { prepareData(); initLeaderboard(); - initPipeline(); + initPreprocessing(); _modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown); } catch (Exception e) { delete(); //cleanup potentially leaked keys @@ -390,46 +387,17 @@ private void initLeaderboard() { } _leaderboard.setExtensionsProvider(createLeaderboardExtensionProvider(this)); } - - private void initPipeline() { - final AutoMLBuildModels build = _buildSpec.build_models; - _pipelineParams = build.preprocessing == null ? null : new PipelineParameters(); - if (_pipelineParams == null) return; - List transformers = new ArrayList<>(); - Map hyperParams = new NonBlockingHashMap<>(); - for (PipelineStepDefinition def : build.preprocessing) { - PipelineStep step = def.newPipelineStep(this); - transformers.addAll(Arrays.asList(step.pipelineTransformers())); - Map hp = step.pipelineTransformersHyperParams(); - if (hp != null) hyperParams.putAll(hp); - } - if (transformers.isEmpty()) { - _pipelineParams = null; - _pipelineHyperParams = null; - } else { - _pipelineParams.setTransformers(transformers.toArray(new DataTransformer[0])); - _pipelineHyperParams = hyperParams; - trackKeys(transformers.stream().map(DataTransformer::getKey).toArray(Key[]::new)); - } - - //TODO: given that a transformer can reference a model (e.g. TE), - // and multiple transformers can refer the same model, - // then we should be careful when deleting a transformer (resp. an entire pipeline) - // as we may delete sth that is still in use by another transformer (resp. pipeline). - // --> ref count? - - //TODO: in AutoML, the same transformations are likely to occur on multiple (sometimes all) models, - // especially if the transformers parameters are not tuned. - // But it also depends if the transformers are context(CV)-sensitive (e.g. Target Encoding). - // See `CachingTransformer` for some thoughts about this. - } - - PipelineParameters getPipelineParams() { - return _pipelineParams; + + private void initPreprocessing() { + _preprocessing = _buildSpec.build_models.preprocessing == null + ? 
null + : Arrays.stream(_buildSpec.build_models.preprocessing) + .map(def -> def.newPreprocessingStep(this)) + .toArray(PreprocessingStep[]::new); } - Map getPipelineHyperParams() { - return _pipelineHyperParams; + PreprocessingStep[] getPreprocessing() { + return _preprocessing; } ModelingStep[] getExecutionPlan() { @@ -523,11 +491,9 @@ public void run() { eventLog().info(Stage.Workflow, "AutoML build started: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.start_time())) .setNamedValue("start_epoch", _runCountdown.start_time(), EventLogEntry.epochFormat.get()); try { - Scope.enter(); learn(); } finally { stop(); - Scope.exit(); } } @@ -793,6 +759,9 @@ private void prepareData() { private void learn() { List completed = new ArrayList<>(); + if (_preprocessing != null) { + for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.prepare(); + } for (ModelingStep step : getExecutionPlan()) { if (!exceededSearchLimits(step)) { StepResultState state = _modelingStepsExecutor.submit(step, job()); @@ -812,6 +781,9 @@ private void learn() { } } } + if (_preprocessing != null) { + for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.dispose(); + } _actualModelingSteps = session().getModelingStepsRegistry().createDefinitionPlanFromSteps(completed.toArray(new ModelingStep[0])); eventLog().info(Stage.Workflow, "Actual modeling steps: "+Arrays.toString(_actualModelingSteps)); } @@ -871,6 +843,11 @@ protected Futures remove_impl(Futures fs, boolean cascade) { if (leaderboard() != null) leaderboard().remove(fs, cascade); if (eventLog() != null) eventLog().remove(fs, cascade); if (session() != null) session().remove(fs, cascade); + if (cascade && _preprocessing != null) { + for (PreprocessingStep preprocessingStep : _preprocessing) { + preprocessingStep.remove(); + } + } for (Key key : _trackedKeys.keySet()) Keyed.remove(key, fs, true); return super.remove_impl(fs, cascade); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java b/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java index ec2f2ed43c24..6c742337daf3 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java @@ -1,6 +1,6 @@ package ai.h2o.automl; -import ai.h2o.automl.preprocessing.PipelineStepDefinition; +import ai.h2o.automl.preprocessing.PreprocessingStepDefinition; import hex.Model; import hex.ScoreKeeper.StoppingMetric; import hex.genmodel.utils.DistributionFamily; @@ -180,7 +180,7 @@ public static final class AutoMLBuildModels extends Iced { public StepDefinition[] modeling_plan; public double exploitation_ratio = -1; public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters(); - public PipelineStepDefinition[] preprocessing; + public PreprocessingStepDefinition[] preprocessing; } public static final class AutoMLCustomParameters extends Iced { diff --git a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java index a42cc5ac26a8..7208152d6645 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java @@ -9,6 +9,8 @@ import ai.h2o.automl.events.EventLog; import ai.h2o.automl.events.EventLogEntry; import ai.h2o.automl.events.EventLogEntry.Stage; +import ai.h2o.automl.preprocessing.PreprocessingConfig; +import ai.h2o.automl.preprocessing.PreprocessingStep; import hex.Model; import hex.Model.Parameters.FoldAssignmentScheme; import 
hex.ModelBuilder; @@ -21,30 +23,23 @@ import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria; import hex.grid.HyperSpaceWalker; import hex.leaderboard.Leaderboard; -import hex.ModelParametersDelegateBuilderFactory; -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineModel.PipelineParameters; import jsr166y.CountedCompleter; import org.apache.commons.lang.builder.ToStringBuilder; import water.*; -import water.KeyGen.ConstantKeyGen; -import water.KeyGen.PatternKeyGen; import water.exceptions.H2OIllegalArgumentException; -import water.util.*; +import water.util.ArrayUtils; +import water.util.Countdown; +import water.util.EnumUtils; +import water.util.Log; import java.util.*; import java.util.function.Consumer; import java.util.function.Predicate; -import java.util.stream.Collectors; - -import static hex.pipeline.PipelineModel.ESTIMATOR_PARAM; /** * Parent class defining common properties and common logic for actual {@link AutoML} training steps. */ public abstract class ModelingStep extends Iced { - - protected static final String PIPELINE_KEY_PREFIX = "Pipeline_"; protected enum SeedPolicy { /** No seed will be used (= random). */ @@ -69,10 +64,20 @@ protected Job startSearch( assert baseParams != null; assert hyperParams.size() > 0; assert searchCriteria != null; - GridSearch.Builder builder = makeGridBuilder(resultKey, baseParams, hyperParams, searchCriteria); - aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" hyperparameter search") + applyPreprocessing(baseParams); + aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" hyperparameter search") .setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get()); - return builder.start(); + return GridSearch.create( + resultKey, + HyperSpaceWalker.BaseWalker.WalkerFactory.create( + baseParams, + hyperParams, + new GridSearch.SimpleParametersBuilderFactory<>(), + searchCriteria + )) + .withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING) + .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures) + .start(); } @SuppressWarnings("unchecked") @@ -82,8 +87,11 @@ protected Job startModel( ) { assert resultKey != null; assert params != null; - ModelBuilder builder = makeBuilder(resultKey, params); - aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" model training") + Job job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description); + applyPreprocessing(params); + ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, (Key) resultKey); + builder._parms = params; + aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" model training") .setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get()); builder.init(false); // validate parameters if (builder._messages.length > 0) { @@ -97,40 +105,6 @@ protected Job startModel( } return builder.trainModelOnH2ONode(); } - - protected GridSearch.Builder makeGridBuilder(Key resultKey, - MP baseParams, - Map hyperParams, - HyperSpaceSearchCriteria searchCriteria) { - Model.Parameters finalParams = applyPipeline(resultKey, baseParams, hyperParams); - if (finalParams instanceof PipelineParameters) { - resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey); - aml().trackKeys(((PipelineParameters)finalParams)._transformers); - } - aml().trackKeys(resultKey); - return GridSearch.create( - resultKey, - HyperSpaceWalker.BaseWalker.WalkerFactory.create( - finalParams, - hyperParams, - new 
ModelParametersDelegateBuilderFactory<>(), - searchCriteria - )) - .withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING) - .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures); - } - - - protected ModelBuilder makeBuilder(Key resultKey, MP params) { - Model.Parameters finalParams = applyPipeline(resultKey, params, null); - if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey); - - Job job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description); - ModelBuilder builder = ModelBuilder.make(finalParams.algoName(), job, (Key) resultKey); - builder._parms = finalParams; - builder._input_parms = finalParams.clone(); - return builder; - } private boolean validParameters(Model.Parameters parms, String[] fields) { try { @@ -389,6 +363,8 @@ protected void setCommonModelBuilderParams(Model.Parameters params) { setClassBalancingParams(params); params._custom_metric_func = buildSpec.build_control.custom_metric_func; + params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models; + params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment; params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir; /** Using _main_model_time_budget_factor to determine if and how we should restrict the time for the main model. @@ -401,8 +377,6 @@ protected void setCommonModelBuilderParams(Model.Parameters params) { protected void setCrossValidationParams(Model.Parameters params) { AutoMLBuildSpec buildSpec = aml().getBuildSpec(); params._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions; - params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models; - params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment; params._fold_column = buildSpec.input_spec.fold_column; if (buildSpec.input_spec.fold_column == null) { @@ -433,64 +407,19 @@ protected void setCustomParams(Model.Parameters params) { if (customParams == null) return; customParams.applyCustomParameters(_algo, params); } - - - /** - * If some algo/provider needs to modify the pipeline dynamically, it's recommended to override this. - */ - protected void filterPipelineTransformers(List transformers, Map transformerHyperParams) {} - - protected final void removeTransformersType(Class toRemove, List transformers, Map transformerHyperParams) { - List teIds = transformers.stream() - .filter(toRemove::isInstance) - .map(DataTransformer::name) - .collect(Collectors.toList()); - transformers.removeIf(dt -> teIds.contains(dt.name())); - transformerHyperParams.keySet().removeIf(k -> teIds.contains(k.split("\\.", 2)[0])); - } - - /** - * Transforms the simple model parameters and potential hyper-parameters into pipeline parameters. - * - * @param resultKey: the key of the final pipe - * @param params: parameters for the model being built in this step. - * @param hyperParams: hyper-parameters for the grid being built in this step (can be null if simple model). - * @return the final pipeline parameters that will be used to build the models in this step. 
- */ - protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map hyperParams) { - if (aml().getPipelineParams() == null) return params; - PipelineParameters pparams = aml().getPipelineParams().freshCopy(); - List transformers = new ArrayList<>(Arrays.asList(pparams.getTransformers())); // need to convert to ArrayList as `filterPipelineTransformers` may remove items below - Map transformersHyperParams = new HashMap<>(aml().getPipelineHyperParams()); - filterPipelineTransformers(transformers, transformersHyperParams); - Key[] defaultTransformersKeys = pparams._transformers; - pparams.setTransformers(transformers.toArray(new DataTransformer[0])); - if (defaultTransformersKeys.length != pparams._transformers.length) { - for (Key k : defaultTransformersKeys) { - if (!ArrayUtils.contains(pparams._transformers, k)) ((DataTransformer)k.get()).cleanup(); - } - } - if (pparams._transformers.length == 0) return params; - setCommonModelBuilderParams(pparams); - pparams._seed = params._seed; - pparams._max_runtime_secs = params._max_runtime_secs; - pparams._estimatorParams = params; - pparams._estimatorKeyGen = hyperParams == null - ? new ConstantKeyGen(resultKey) - : new PatternKeyGen("{0}|s/"+PIPELINE_KEY_PREFIX+"//") // in case of grid, remove the Pipeline prefix to obtain the estimator key, this allows naming compatibility with the classic mode. - ; - if (hyperParams != null) { - Map pipelineHyperParams = new HashMap<>(); - for (Map.Entry e : hyperParams.entrySet()) { - pipelineHyperParams.put(ESTIMATOR_PARAM+"."+e.getKey(), e.getValue()); - } - hyperParams.clear(); - hyperParams.putAll(pipelineHyperParams); - hyperParams.putAll(transformersHyperParams); - } - return pparams; + + protected void applyPreprocessing(Model.Parameters params) { + if (aml().getPreprocessing() == null) return; + for (PreprocessingStep preprocessingStep : aml().getPreprocessing()) { + PreprocessingStep.Completer complete = preprocessingStep.apply(params, getPreprocessingConfig()); + _onDone.add(j -> complete.run()); + } } + protected PreprocessingConfig getPreprocessingConfig() { + return new PreprocessingConfig(); + } + /** * Configures early-stopping for the model or set of models to be built. 
* @@ -712,6 +641,7 @@ protected Job hyperparameterSearch(Key key, Model.Parameters basePar setSearchCriteria(searchCriteria, baseParms); if (null == key) key = makeKey(_provider, true); + aml().trackKeys(key); Log.debug("Hyperparameter search: " + _provider + ", time remaining (ms): " + aml().timeRemainingMs()); aml().eventLog().debug(Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0 @@ -831,7 +761,7 @@ protected Job startJob() { final Key selectionKey = Key.make(key+"_select"); final EventLog selectionEventLog = EventLog.getOrMake(selectionKey); // EventLog selectionEventLog = aml().eventLog(); - final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog); +final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog); { result.delete_and_lock(job); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java index dcbbc4ea4c02..727d2637cf8a 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java @@ -106,7 +106,6 @@ StepResultState submit(ModelingStep step, Job parentJob) { } } catch (Exception e) { resultState.addState(new StepResultState(step.getGlobalId(), e)); - Log.err(e); } finally { step.onDone(job); } diff --git a/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java b/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java index 65f5fbaf2c05..5e408bb746e0 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java @@ -19,7 +19,7 @@ enum ResultStatus { enum Resolution { sameAsMain, // resolves to the same state as the main step (ignoring other sub-step states). optimistic, // success if any success, otherwise cancelled if any cancelled, otherwise failed if any failure, otherwise skipped. - pessimistic, // failed if any failure, otherwise cancelled if any cancelled, otherwise success if any success, otherwise skipped. + pessimistic, // failures if any failure, otherwise cancelled if any cancelled, otherwise success if any success, otherwise skipped. } private final String _id;
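The three provider diffs that follow (DeepLearning, GLM, StackedEnsemble) all customize target encoding through the same getPreprocessingConfig() hook. A minimal sketch of the pattern, assuming a hypothetical enclosing modeling-step class; the config keys and their effect are taken from TargetEncoding.apply() further down in this patch:

    // Illustrative only (not part of the patch): a modeling step tuning TE via its PreprocessingConfig.
    @Override
    protected PreprocessingConfig getPreprocessingConfig() {
        PreprocessingConfig config = super.getPreprocessingConfig();
        // DeepLearning/GLM flavor: keep the CV fold-column preparation but
        // skip attaching the TE preprocessor to the model's parameters:
        config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled());
        // StackedEnsemble flavor would instead disable TE entirely,
        // since the base models already handle the encoding:
        // config.put(TargetEncoding.CONFIG_ENABLED, false);
        return config;
    }
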
diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/DeepLearningStepsProvider.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/DeepLearningStepsProvider.java index 2be139a237a5..1f6f0de76f8b 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/DeepLearningStepsProvider.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/DeepLearningStepsProvider.java @@ -1,17 +1,13 @@ package ai.h2o.automl.modeling; import ai.h2o.automl.*; -import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformer; -import hex.Model; +import ai.h2o.automl.preprocessing.PreprocessingConfig; +import ai.h2o.automl.preprocessing.TargetEncoding; import hex.deeplearning.DeepLearningModel; import hex.deeplearning.DeepLearningModel.DeepLearningParameters; -import hex.pipeline.DataTransformer; -import water.Key; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class DeepLearningStepsProvider @@ -26,17 +22,14 @@ static abstract class DeepLearningModelStep extends ModelingStep.ModelStep hyperParams) { - return super.applyPipeline(resultKey, params, hyperParams); - } - - @Override - protected void filterPipelineTransformers(List transformers, Map transformerHyperParams) { - // legacy behavior: TE was not applied for deep learning as it is not useful for this algo. - removeTransformersType(TargetEncoderFeatureTransformer.class, transformers, transformerHyperParams); - } + + @Override + protected PreprocessingConfig getPreprocessingConfig() { + //TE useless for DNN + PreprocessingConfig config = super.getPreprocessingConfig(); + config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled()); + return config; + } } static abstract class DeepLearningGridStep extends ModelingStep.GridStep { @@ -53,6 +46,14 @@ public DeepLearningParameters prepareModelParameters() { return params; } + @Override + protected PreprocessingConfig getPreprocessingConfig() { + //TE useless for DNN + PreprocessingConfig config = super.getPreprocessingConfig(); + config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled()); + return config; + } + public Map prepareSearchParameters() { Map searchParams = new HashMap<>(); searchParams.put("_rho", new Double[] { 0.9, 0.95, 0.99 }); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/GLMStepsProvider.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/GLMStepsProvider.java index 325d4984e918..d410e2a6076a 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/GLMStepsProvider.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/GLMStepsProvider.java @@ -1,7 +1,10 @@ package ai.h2o.automl.modeling; import ai.h2o.automl.*; +import ai.h2o.automl.preprocessing.PreprocessingConfig; +import ai.h2o.automl.preprocessing.TargetEncoding; import hex.Model; +import hex.genmodel.utils.DistributionFamily; import hex.glm.GLMModel; import hex.glm.GLMModel.GLMParameters; @@ -30,6 +33,15 @@ public GLMParameters prepareModelParameters() { params._lambda_search = true; return params; } + + @Override + protected PreprocessingConfig getPreprocessingConfig() { + //GLM (the exception as usual) doesn't support target encoding if CV is enabled + // because it initializes its lambdas + other params before CV (preventing changes in the train frame during CV). 
+ PreprocessingConfig config = super.getPreprocessingConfig(); + config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled()); + return config; + } } diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java index 263c832d3fae..a16adf2b62e9 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java @@ -3,6 +3,8 @@ import ai.h2o.automl.*; import ai.h2o.automl.WorkAllocations.Work; import ai.h2o.automl.events.EventLogEntry; +import ai.h2o.automl.preprocessing.PreprocessingConfig; +import ai.h2o.automl.preprocessing.TargetEncoding; import hex.KeyValue; import hex.Model; import hex.ensemble.Metalearner; @@ -63,8 +65,11 @@ protected void setClassBalancingParams(Model.Parameters params) { } @Override - protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map hyperParams) { - return params; // no pipeline in SE, base models handle the transformations when making predictions. + protected PreprocessingConfig getPreprocessingConfig() { + //SE should not have TE applied, the base models already do it. + PreprocessingConfig config = super.getPreprocessingConfig(); + config.put(TargetEncoding.CONFIG_ENABLED, false); + return config; } @Override @@ -117,18 +122,13 @@ protected boolean hasDoppelganger(Key[] baseModelsKeys) { protected abstract Key[] getBaseModels(); protected String getModelType(Key key) { - ModelingStep step = aml().session().getModelingStep(key); -// if (step != null) { // fixme: commenting out this for now, as it interprets XRT as a DRF (which it is) and breaks legacy tests. We might want to reconsider this distinction as XRT is often very similar to DRF and doesn't bring much diversity to SEs, and the best_of SEs currently almost always have these 2. -// return step.getAlgo().name(); -// } else { // dirty case String keyStr = key.toString(); - int lookupStart = keyStr.startsWith(PIPELINE_KEY_PREFIX) ? 
PIPELINE_KEY_PREFIX.length() : 0; - return keyStr.substring(lookupStart, keyStr.indexOf('_', lookupStart)); -// } + return keyStr.substring(0, keyStr.indexOf('_')); } protected boolean isStackedEnsemble(Key key) { - return Algo.StackedEnsemble.name().equals(getModelType(key)); + ModelingStep step = aml().session().getModelingStep(key); + return step != null && step.getAlgo() == Algo.StackedEnsemble; } @Override diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java index 01033767ceb3..603aa906577b 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java @@ -9,7 +9,6 @@ import hex.grid.HyperSpaceSearchCriteria.SequentialSearchCriteria; import hex.grid.HyperSpaceSearchCriteria.StoppingCriteria; import hex.grid.SequentialWalker; -import hex.grid.SimpleParametersBuilderFactory; import hex.tree.xgboost.XGBoostModel; import hex.tree.xgboost.XGBoostModel.XGBoostParameters; import water.Job; @@ -201,8 +200,7 @@ public Map prepareSearchParameters() { XGBoostParameters.Booster.gbtree, XGBoostParameters.Booster.dart }); -// searchParams.put("_booster$weights", new Integer[] {2, 1}); - + searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1f, 10f, 100f}); searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1f}); @@ -370,7 +368,7 @@ protected Job startTraining(Key result, double maxRuntimeSecs) { new SequentialWalker<>( params, hyperParams, - new SimpleParametersBuilderFactory<>(), + new GridSearch.SimpleParametersBuilderFactory<>(), new SequentialSearchCriteria(StoppingCriteria.create() .maxRuntimeSecs((int)maxRuntimeSecs) .stoppingMetric(params._stopping_metric) diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStep.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStep.java deleted file mode 100644 index ed5fd7ebfb47..000000000000 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStep.java +++ /dev/null @@ -1,23 +0,0 @@ -package ai.h2o.automl.preprocessing; - -import hex.pipeline.DataTransformer; - -import java.util.Map; - -public interface PipelineStep { - - String getType(); - - /** - * - * @return an array of pipeline {@link hex.pipeline.DataTransformer}s needed for this pipeline step. - */ - DataTransformer[] pipelineTransformers(); - - /** - * - * @return a map of hyper-parameters for the {@link hex.pipeline.DataTransformer}s of this pipeline step - * that can be used in AutoML grids. 
- */ - Map pipelineTransformersHyperParams(); -} diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingConfig.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingConfig.java new file mode 100644 index 000000000000..b571e6755bed --- /dev/null +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingConfig.java @@ -0,0 +1,10 @@ +package ai.h2o.automl.preprocessing; + +import java.util.HashMap; + +public class PreprocessingConfig extends HashMap { + + boolean get(String key, boolean defaultValue) { + return (boolean) getOrDefault(key, defaultValue); + } +} diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java new file mode 100644 index 000000000000..e3a32a361c71 --- /dev/null +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java @@ -0,0 +1,37 @@ +package ai.h2o.automl.preprocessing; + +import ai.h2o.automl.ModelingStep; +import hex.Model; + +public interface PreprocessingStep { + + interface Completer extends Runnable {} + + String getType(); + + /** + * preprocessing steps are prepared by default before the AutoML session starts training the first model. + */ + void prepare(); + + /** + * applies this preprocessing step to the model parameters right before the model training starts. + * @param params + * @return a function used to "complete" the preprocessing step: it is called by default at the end of the job creating model(s) from the given parms. + * This can mean for example cleaning the temporary artifacts that may have been created to apply the preprocessing step. + */ + Completer apply(Model.Parameters params, PreprocessingConfig config); + + /** + * preprocessing steps are disposed by default at the end of the AutoML training session. + * Note that disposing here doesn't mean being removed from the system, + * the goal is mainly to clean resources that are not needed anymore for the current AutoML run. 
+ */ + void dispose(); + + /** + * Completely remove from the system + */ + void remove(); + +} diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStepDefinition.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java similarity index 59% rename from h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStepDefinition.java rename to h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java index 05f93e95eaa0..568599a16633 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PipelineStepDefinition.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java @@ -3,7 +3,7 @@ import ai.h2o.automl.AutoML; import water.Iced; -public class PipelineStepDefinition extends Iced { +public class PreprocessingStepDefinition extends Iced { public enum Type { TargetEncoding @@ -11,13 +11,13 @@ public enum Type { Type _type; - public PipelineStepDefinition() { /* for reflection */ } + public PreprocessingStepDefinition() { /* for reflection */ } - public PipelineStepDefinition(Type type) { + public PreprocessingStepDefinition(Type type) { _type = type; } - public PipelineStep newPipelineStep(AutoML aml) { + public PreprocessingStep newPreprocessingStep(AutoML aml) { switch (_type) { case TargetEncoding: return new TargetEncoding(aml); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java index 5f1b933be65a..b242a97f0665 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java @@ -1,13 +1,17 @@ package ai.h2o.automl.preprocessing; import ai.h2o.automl.AutoML; +import ai.h2o.automl.AutoMLBuildSpec.AutoMLBuildControl; import ai.h2o.automl.AutoMLBuildSpec.AutoMLInput; +import ai.h2o.automl.events.EventLogEntry.Stage; +import ai.h2o.targetencoding.TargetEncoder; +import ai.h2o.targetencoding.TargetEncoderModel; import ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy; import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; -import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformer; +import ai.h2o.targetencoding.TargetEncoderPreprocessor; +import hex.Model; import hex.Model.Parameters.FoldAssignmentScheme; -import hex.pipeline.DataTransformer; -import hex.pipeline.transformers.KFoldColumnGenerator; +import hex.ModelPreprocessor; import water.DKV; import water.Key; import water.fvec.Frame; @@ -18,9 +22,19 @@ import java.util.*; import java.util.function.Predicate; -public class TargetEncoding implements PipelineStep { +public class TargetEncoding implements PreprocessingStep { + + public static String CONFIG_ENABLED = "target_encoding_enabled"; + public static String CONFIG_PREPARE_CV_ONLY = "target_encoding_prepare_cv_only"; + + static String TE_FOLD_COLUMN_SUFFIX = "_te_fold"; + private static final Completer NOOP = () -> {}; private AutoML _aml; + private TargetEncoderPreprocessor _tePreprocessor; + private TargetEncoderModel _teModel; + private final List _disposables = new ArrayList<>(); + private TargetEncoderParameters _defaultParams; private boolean _encodeAllColumns = false; // if true, bypass all restrictions in columns selection. private int _columnCardinalityThreshold = 25; // the minimal cardinality for a column to be TE encoded. 
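For orientation, the PreprocessingStep lifecycle documented above plays out as in the sketch below. This is illustrative only: aml and plannedParams are assumed context, while the calls themselves mirror AutoML.learn(), ModelingStep.applyPreprocessing() and AutoML.remove_impl() from this patch:

    // Illustrative driver (not part of the patch): one AutoML session using a preprocessing step.
    PreprocessingStep step = new PreprocessingStepDefinition(PreprocessingStepDefinition.Type.TargetEncoding)
            .newPreprocessingStep(aml);
    step.prepare();                                  // once, before the first model of the session
    for (Model.Parameters params : plannedParams) {  // assumed: parameters of each planned model
        PreprocessingStep.Completer complete = step.apply(params, new PreprocessingConfig());
        // ... train the model(s) described by params ...
        complete.run();                              // revert temporary changes (e.g. the patched train frame)
    }
    step.dispose();                                  // end of the AutoML run: free run-scoped resources
    step.remove();                                   // only when the AutoML instance itself is deleted
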
@@ -31,7 +45,100 @@ public TargetEncoding(AutoML aml) { @Override public String getType() { - return PipelineStepDefinition.Type.TargetEncoding.name(); + return PreprocessingStepDefinition.Type.TargetEncoding.name(); + } + + @Override + public void prepare() { + AutoMLInput amlInput = _aml.getBuildSpec().input_spec; + AutoMLBuildControl amlBuild = _aml.getBuildSpec().build_control; + Frame amlTrain = _aml.getTrainingFrame(); + + TargetEncoderParameters params = (TargetEncoderParameters) getDefaultParams().clone(); + params._train = amlTrain._key; + params._response_column = amlInput.response_column; + params._seed = amlBuild.stopping_criteria.seed(); + + Set teColumns = selectColumnsToEncode(amlTrain, params); + if (teColumns.isEmpty()) return; + + _aml.eventLog().warn(Stage.FeatureCreation, + "Target Encoding integration in AutoML is in an experimental stage, the models obtained with this feature can not yet be downloaded as MOJO for production."); + + + if (_aml.isCVEnabled()) { + params._data_leakage_handling = DataLeakageHandlingStrategy.KFold; + params._fold_column = amlInput.fold_column; + if (params._fold_column == null) { + //generate fold column + Frame train = new Frame(params.train()); + Vec foldColumn = createFoldColumn( + params.train(), + FoldAssignmentScheme.Modulo, + amlBuild.nfolds, + params._response_column, + params._seed + ); + DKV.put(foldColumn); + params._fold_column = params._response_column+TE_FOLD_COLUMN_SUFFIX; + train.add(params._fold_column, foldColumn); + register(train, params._train.toString(), true); + params._train = train._key; + _disposables.add(() -> { + foldColumn.remove(); + DKV.remove(train._key); + }); + } + } + String[] keep = params.getNonPredictors(); + params._ignored_columns = Arrays.stream(amlTrain.names()) + .filter(col -> !teColumns.contains(col) && !ArrayUtils.contains(keep, col)) + .toArray(String[]::new); + + TargetEncoder te = new TargetEncoder(params, _aml.makeKey(getType(), null, false)); + _teModel = te.trainModel().get(); + _tePreprocessor = new TargetEncoderPreprocessor(_teModel); + } + + @Override + public Completer apply(Model.Parameters params, PreprocessingConfig config) { + if (_tePreprocessor == null || !config.get(CONFIG_ENABLED, true)) return NOOP; + + if (!config.get(CONFIG_PREPARE_CV_ONLY, false)) + params._preprocessors = (Key[])ArrayUtils.append(params._preprocessors, _tePreprocessor._key); + + Frame train = new Frame(params.train()); + String foldColumn = _teModel._parms._fold_column; + boolean addFoldColumn = foldColumn != null && train.find(foldColumn) < 0; + if (addFoldColumn) { + train.add(foldColumn, _teModel._parms._train.get().vec(foldColumn)); + register(train, params._train.toString(), true); + params._train = train._key; + params._fold_column = foldColumn; + params._nfolds = 0; // to avoid confusion or errors + params._fold_assignment = FoldAssignmentScheme.AUTO; // to avoid confusion or errors + } + + return () -> { + //revert train changes + if (addFoldColumn) { + DKV.remove(train._key); + } + }; + } + + @Override + public void dispose() { + for (Completer disposable : _disposables) disposable.run(); + } + + @Override + public void remove() { + if (_tePreprocessor != null) { + _tePreprocessor.remove(true); + _tePreprocessor = null; + _teModel = null; + } } public void setDefaultParams(TargetEncoderParameters defaultParams) { @@ -86,42 +193,12 @@ private Set selectColumnsToEncode(Frame fr, TargetEncoderParameters para return encode; } - @Override - public DataTransformer[] pipelineTransformers() { - List dts 
= new ArrayList<>(); - TargetEncoderParameters teParams = (TargetEncoderParameters) getDefaultParams().clone(); - Frame train = _aml.getTrainingFrame(); - Set teColumns = selectColumnsToEncode(train, teParams); - if (teColumns.isEmpty()) return new DataTransformer[0]; - - String[] keep = teParams.getNonPredictors(); - teParams._ignored_columns = Arrays.stream(train.names()) - .filter(col -> !teColumns.contains(col) && !ArrayUtils.contains(keep, col)) - .toArray(String[]::new); - if (_aml.isCVEnabled()) { - dts.add(new KFoldColumnGenerator() - .name("add_fold_column") - .description("If cross-validation is enabled, generates (if needed) a fold column used by Target Encoder and for the final estimator") -// .init() - ); - teParams._data_leakage_handling = DataLeakageHandlingStrategy.KFold; - } - dts.add(new TargetEncoderFeatureTransformer(teParams) - .name("default_TE") - .description("Applies Target Encoding to selected categorical features") - .enableCache() -// .init() - ); - return dts.toArray(new DataTransformer[0]); + TargetEncoderPreprocessor getTEPreprocessor() { + return _tePreprocessor; } - @Override - public Map pipelineTransformersHyperParams() { - Map hp = new HashMap<>(); - hp.put("default_TE._enabled", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); - hp.put("default_TE._keep_original_categorical_columns", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); - hp.put("default_TE._blending", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); - return hp; + TargetEncoderModel getTEModel() { + return _teModel; } private static void register(Frame fr, String keyPrefix, boolean force) { @@ -133,10 +210,10 @@ private static void register(Frame fr, String keyPrefix, boolean force) { } public static Vec createFoldColumn(Frame fr, - FoldAssignmentScheme fold_assignment, - int nfolds, - String responseColumn, - long seed) { + FoldAssignmentScheme fold_assignment, + int nfolds, + String responseColumn, + long seed) { Vec foldColumn; switch (fold_assignment) { default: diff --git a/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java b/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java index 458e28ddebbd..bfc74f362d99 100644 --- a/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java +++ b/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java @@ -284,7 +284,7 @@ public static final class AutoMLBuildModelsV99 extends SchemaV3 { +public final class PreprocessingStepDefinitionV99 extends Schema { public static final class TypeProvider extends EnumValuesProvider { public TypeProvider() { diff --git a/h2o-automl/src/main/resources/META-INF/services/water.api.Schema b/h2o-automl/src/main/resources/META-INF/services/water.api.Schema index ca831a3af5fd..b19787770d47 100644 --- a/h2o-automl/src/main/resources/META-INF/services/water.api.Schema +++ b/h2o-automl/src/main/resources/META-INF/services/water.api.Schema @@ -12,5 +12,5 @@ water.automl.api.schemas3.EventLogEntryV99 water.automl.api.schemas3.EventLogV99 water.automl.api.schemas3.StepDefinitionV99 water.automl.api.schemas3.StepDefinitionV99$StepV99 -water.automl.api.schemas3.PipelineStepDefinitionV99 +water.automl.api.schemas3.PreprocessingStepDefinitionV99 water.automl.api.schemas3.SchemaExtensions$ModelsKeyV3 diff --git a/h2o-automl/src/test/java/ai/h2o/automl/AutoMLTest.java b/h2o-automl/src/test/java/ai/h2o/automl/AutoMLTest.java index 72e6323985ba..b4395da915c4 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/AutoMLTest.java +++ 
b/h2o-automl/src/test/java/ai/h2o/automl/AutoMLTest.java @@ -233,15 +233,13 @@ public class AutoMLTest { AutoML aml = scope.track(AutoML.startAutoML(autoMLBuildSpec, true)); aml.get(); - - System.out.println(aml.leaderboard().toTwoDimTable(ModelProvider.COLUMN.getName(), ModelStep.COLUMN.getName()).toString()); + //as max_models is provided, no time budget is assigned to the models by default, // even when user also provides max_runtime_secs: in this case, the latter only acts as a global limit // and cancels the last training step. StepResultState[] steps = aml._stepsResults; assertTrue("shouldn't have managed to train all max_models", steps.length < autoMLBuildSpec.build_control.stopping_criteria.max_models()); StepResultState lastStep = steps[steps.length - 1]; - Log.info("last step = "+lastStep); assertTrue("last model training should have been cancelled", lastStep.is(StepResultState.ResultStatus.cancelled) // if timeout during model training || lastStep.is(StepResultState.ResultStatus.success)); // if timeout between models diff --git a/h2o-automl/src/test/java/ai/h2o/automl/ModelingStepTest.java b/h2o-automl/src/test/java/ai/h2o/automl/ModelingStepTest.java index 77b152f1b86d..8220ffa2f715 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/ModelingStepTest.java +++ b/h2o-automl/src/test/java/ai/h2o/automl/ModelingStepTest.java @@ -3,7 +3,6 @@ import ai.h2o.automl.dummy.DummyBuilder; import ai.h2o.automl.dummy.DummyModel; import ai.h2o.automl.dummy.DummyStepsProvider; -import ai.h2o.automl.dummy.DummyStepsProvider.DummyGridStep; import ai.h2o.automl.dummy.DummyStepsProvider.DummyModelStep; import hex.Model; import hex.ScoreKeeper; @@ -213,7 +212,24 @@ public Model.Parameters prepareModelParameters() { } } + private static class DummyGridStep extends ModelingStep.GridStep { + public DummyGridStep(IAlgo algo, String id, AutoML autoML) { + super(TestingModelSteps.NAME, algo, id, autoML); + } + + @Override + public Model.Parameters prepareModelParameters() { + return new DummyModel.DummyModelParameters(); + } + + @Override + public Map prepareSearchParameters() { + Map searchParams = new HashMap<>(); + searchParams.put("_tag", new String[] {"one", "two", "three"}); + return searchParams; + } + } private static class DummySelectionStep extends ModelingStep.SelectionStep { boolean _useSearch; diff --git a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyBuilder.java b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyBuilder.java index 65fde16ca2b5..eb4f162e8a17 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyBuilder.java +++ b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyBuilder.java @@ -1,7 +1,8 @@ package ai.h2o.automl.dummy; import ai.h2o.automl.IAlgo; -import hex.*; +import hex.ModelBuilder; +import hex.ModelCategory; import org.junit.Ignore; import water.exceptions.H2OIllegalArgumentException; @@ -57,7 +58,6 @@ class DummyDriver extends Driver { public void computeImpl() { init(true); DummyModel model = new DummyModel(_result, _parms, new DummyModel.DummyModelOutput(DummyBuilder.this)); - model._output._training_metrics = new ModelMetricsBinomial(model, model._parms.train(), 1, 1, null, 1, AUC2.emptyAUC(), 1, null, null); model.delete_and_lock(_job); model.update(_job); model.unlock(_job); diff --git a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyModel.java b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyModel.java index d9563cf55122..700693b31509 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyModel.java +++ 
b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyModel.java @@ -1,6 +1,9 @@ package ai.h2o.automl.dummy; -import hex.*; +import hex.Model; +import hex.ModelBuilder; +import hex.ModelMetrics; +import hex.ModelMetricsBinomial; import org.junit.Ignore; import water.Key; import water.util.IcedHashMap; diff --git a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyStepsProvider.java b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyStepsProvider.java index 0b972968c6ec..5daee4411918 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyStepsProvider.java +++ b/h2o-automl/src/test/java/ai/h2o/automl/dummy/DummyStepsProvider.java @@ -4,8 +4,6 @@ import hex.Model; import org.junit.Ignore; -import java.util.HashMap; -import java.util.Map; import java.util.function.Function; @Ignore("utility class") @@ -79,23 +77,4 @@ public Model.Parameters prepareModelParameters() { return new DummyModel.DummyModelParameters(); } } - - public static class DummyGridStep extends ModelingStep.GridStep { - - public DummyGridStep(IAlgo algo, String id, AutoML autoML) { - super(DummyModelSteps.NAME, algo, id, autoML); - } - - @Override - public Model.Parameters prepareModelParameters() { - return new DummyModel.DummyModelParameters(); - } - - @Override - public Map prepareSearchParameters() { - Map searchParams = new HashMap<>(); - searchParams.put("_tag", new String[] {"one", "two", "three"}); - return searchParams; - } - } - } +} diff --git a/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/PipelineIntegrationTest.java b/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/PipelineIntegrationTest.java deleted file mode 100644 index ff64ddf04170..000000000000 --- a/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/PipelineIntegrationTest.java +++ /dev/null @@ -1,155 +0,0 @@ -package ai.h2o.automl.preprocessing; - -import ai.h2o.automl.Algo; -import ai.h2o.automl.AutoML; -import ai.h2o.automl.AutoMLBuildSpec; -import ai.h2o.automl.StepDefinition; -import hex.Model; -import hex.SplitFrame; -import hex.ensemble.StackedEnsembleModel; -import hex.pipeline.PipelineModel; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.Key; -import water.Scope; -import water.TestUtil; -import water.fvec.Frame; -import water.runner.CloudSize; -import water.runner.H2ORunner; - -import static org.junit.Assert.*; -import static water.TestUtil.parseTestFile; - -@CloudSize(1) -@RunWith(H2ORunner.class) -public class PipelineIntegrationTest { - - - @Test - public void test_automl_run_with_cv_enabling_pipelines() { - try { - Scope.enter(); - AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec(); - Frame fr = Scope.track(parseTestFile("./smalldata/titanic/titanic_expanded.csv")); - SplitFrame sf = new SplitFrame(fr, new double[] { 0.7, 0.3 }, new Key[]{Key.make("titanic_train"), Key.make("titanic_test")}); - sf.exec().get(); - Frame train = Scope.track(sf._destination_frames[0].get()); - Frame test = Scope.track(sf._destination_frames[1].get()); - TestUtil.printOutFrameAsTable(test); - - autoMLBuildSpec.input_spec.training_frame = train._key; - autoMLBuildSpec.input_spec.response_column = "survived"; - autoMLBuildSpec.build_control.stopping_criteria.set_max_models(20); // sth big enough to test all algos+grids with TE - autoMLBuildSpec.build_control.stopping_criteria.set_seed(42); - autoMLBuildSpec.build_control.nfolds = 3; - autoMLBuildSpec.build_models.preprocessing = new PipelineStepDefinition[] { - new PipelineStepDefinition(PipelineStepDefinition.Type.TargetEncoding) - }; - - AutoML aml = 
AutoML.startAutoML(autoMLBuildSpec); Scope.track_generic(aml); - aml.get(); - System.out.println(aml.leaderboard().toTwoDimTable()); - for (Model m : aml.leaderboard().getModels()) { - if (m instanceof StackedEnsembleModel) { - assertFalse(m.haveMojo()); // all SEs should have at least one Pipeline model as a base model which doesn't support MOJO - assertFalse(m.havePojo()); - } else { - assertTrue(m instanceof PipelineModel); - } - } - } finally { - Scope.exit(); - } - } - - @Test - public void test_automl_run_with_cv_enabling_pipelines_scored_by_leaderboard_frame() { - try { - Scope.enter(); - AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec(); - Frame fr = Scope.track(parseTestFile("./smalldata/titanic/titanic_expanded.csv")); - SplitFrame sf = new SplitFrame(fr, new double[] { 0.7, 0.3 }, new Key[]{Key.make("titanic_train"), Key.make("titanic_test")}); - sf.exec().get(); - Frame train = Scope.track(sf._destination_frames[0].get()); - Frame test = Scope.track(sf._destination_frames[1].get()); - TestUtil.printOutFrameAsTable(test); - - autoMLBuildSpec.input_spec.training_frame = train._key; -// autoMLBuildSpec.input_spec.validation_frame = test._key; - autoMLBuildSpec.input_spec.leaderboard_frame = test._key; - autoMLBuildSpec.input_spec.response_column = "survived"; - autoMLBuildSpec.build_control.stopping_criteria.set_max_models(12); // sth big enough to have some grid - autoMLBuildSpec.build_control.stopping_criteria.set_seed(42); - autoMLBuildSpec.build_control.nfolds = 3; - autoMLBuildSpec.build_models.preprocessing = new PipelineStepDefinition[] { - new PipelineStepDefinition(PipelineStepDefinition.Type.TargetEncoding) - }; - autoMLBuildSpec.build_models.modeling_plan = new StepDefinition[] { - new StepDefinition(Algo.GLM.name()), - new StepDefinition(Algo.XGBoost.name(), StepDefinition.Alias.grids), - new StepDefinition(Algo.GBM.name(), StepDefinition.Alias.grids), - new StepDefinition(Algo.StackedEnsemble.name(), StepDefinition.Alias.defaults), - }; - - AutoML aml = AutoML.startAutoML(autoMLBuildSpec); Scope.track_generic(aml); - aml.get(); - System.out.println(aml.leaderboard().toTwoDimTable()); - for (Model m : aml.leaderboard().getModels()) { - if (m instanceof StackedEnsembleModel) { - assertFalse(m.haveMojo()); // all SEs should have at least one Pipeline model as a base model which doesn't support MOJO - assertFalse(m.havePojo()); - } else { - assertTrue(m instanceof PipelineModel); - } - } - } finally { - Scope.exit(); - } - } - - @Test - public void test_automl_run_without_cv_enabling_pipelines_scored_by_leaderboard_frame() { - try { - Scope.enter(); - AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec(); - Frame fr = Scope.track(parseTestFile("./smalldata/titanic/titanic_expanded.csv")); - SplitFrame sf = new SplitFrame(fr, new double[] { 0.7, 0.3 }, new Key[]{Key.make("titanic_train"), Key.make("titanic_test")}); - sf.exec().get(); - Frame train = Scope.track(sf._destination_frames[0].get()); - Frame test = Scope.track(sf._destination_frames[1].get()); - TestUtil.printOutFrameAsTable(test); - - autoMLBuildSpec.input_spec.training_frame = train._key; - autoMLBuildSpec.input_spec.validation_frame = test._key; - autoMLBuildSpec.input_spec.response_column = "survived"; - autoMLBuildSpec.build_control.stopping_criteria.set_max_models(12); // sth big enough to have some grid - autoMLBuildSpec.build_control.stopping_criteria.set_seed(42); - autoMLBuildSpec.build_control.nfolds = 0; - autoMLBuildSpec.build_models.preprocessing = new PipelineStepDefinition[] { - new 
PipelineStepDefinition(PipelineStepDefinition.Type.TargetEncoding) - }; - autoMLBuildSpec.build_models.modeling_plan = new StepDefinition[] { - new StepDefinition(Algo.GLM.name()), - new StepDefinition(Algo.XGBoost.name(), StepDefinition.Alias.grids), - new StepDefinition(Algo.GBM.name(), StepDefinition.Alias.grids), - new StepDefinition(Algo.StackedEnsemble.name(), StepDefinition.Alias.defaults), - }; - - AutoML aml = AutoML.startAutoML(autoMLBuildSpec); Scope.track_generic(aml); - aml.get(); - System.out.println(aml.leaderboard().toTwoDimTable()); - for (Model m : aml.leaderboard().getModels()) { - if (m instanceof StackedEnsembleModel) { - assertFalse(m.haveMojo()); // all SEs should have at least one Pipeline model as a base model which doesn't support MOJO - assertFalse(m.havePojo()); - } else { - assertTrue(m instanceof PipelineModel); - } - } - } finally { - Scope.exit(); - } - } - - -} diff --git a/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/TargetEncodingTest.java b/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/TargetEncodingTest.java index fbf7633ea495..cdc611692a38 100644 --- a/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/TargetEncodingTest.java +++ b/h2o-automl/src/test/java/ai/h2o/automl/preprocessing/TargetEncodingTest.java @@ -1,30 +1,35 @@ package ai.h2o.automl.preprocessing; import ai.h2o.automl.*; -import ai.h2o.automl.dummy.DummyBuilder; -import ai.h2o.automl.dummy.DummyStepsProvider; -import ai.h2o.automl.dummy.DummyStepsProvider.DummyGridStep; -import ai.h2o.automl.dummy.DummyStepsProvider.DummyModelStep; +import ai.h2o.automl.dummy.DummyModel; +import ai.h2o.automl.preprocessing.PreprocessingStepDefinition.Type; import ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy; import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; -import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformer; +import ai.h2o.targetencoding.TargetEncoderPreprocessor; import hex.Model; import hex.SplitFrame; import hex.deeplearning.DeepLearningModel; import hex.ensemble.StackedEnsembleModel; -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineModel; -import org.junit.*; +import hex.glm.GLMModel; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; -import water.*; +import water.DKV; +import water.Key; +import water.Keyed; +import water.Scope; import water.fvec.Frame; import water.fvec.TestFrameBuilder; import water.fvec.Vec; -import water.junit.rules.ScopeTracker; import water.runner.CloudSize; import water.runner.H2ORunner; import water.util.ArrayUtils; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + import static org.junit.Assert.*; import static water.TestUtil.*; import static water.TestUtil.ar; @@ -32,24 +37,14 @@ @CloudSize(1) @RunWith(H2ORunner.class) public class TargetEncodingTest { - - @FunctionalInterface - interface Callback { - void call(T t); - } - - @Rule - public ScopeTracker scope = new ScopeTracker(); - @BeforeClass - public static void setupDummySteps() { - DummyStepsProvider provider = new DummyStepsProvider(); - provider.modelStepsFactory = DummySteps::new; - ModelingStepsRegistry.registerProvider(provider); - } - - private AutoML runDummyAutoML(Callback configureBuildSpec) { - Frame fr = new TestFrameBuilder() + private List toDelete = new ArrayList<>(); + private AutoML aml; + private Frame fr; + + @Before + public void setup() { + fr = new TestFrameBuilder() .withName("dummy_fr") 
.withColNames("cat1", "numerical", "cat2", "target", "foldc") .withVecTypes(Vec.T_CAT, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT, Vec.T_NUM) @@ -59,194 +54,196 @@ private AutoML runDummyAutoML(Callback configureBuildSpec) { .withDataForCol(3, ar("yes", "no", "no", "yes", "yes", "no")) .withDataForCol(4, ar(1, 1, 1, 2, 2, 2)) .build(); - DKV.put(fr); + DKV.put(fr); toDelete.add(fr); AutoMLBuildSpec buildSpec = new AutoMLBuildSpec(); buildSpec.input_spec.training_frame = fr._key; buildSpec.input_spec.response_column = "target"; - buildSpec.build_models.preprocessing = new PipelineStepDefinition[] { - new TEStepDefinition() - }; - buildSpec.build_models.modeling_plan = new StepDefinition[] { - new StepDefinition("dummy", StepDefinition.Alias.defaults) - }; - configureBuildSpec.call(buildSpec); - AutoML aml = Scope.track_generic(AutoML.startAutoML(buildSpec)); - aml.get(); - return aml; + aml = new AutoML(buildSpec); + DKV.put(aml); toDelete.add(aml); + } + + @After + public void cleanup() { + toDelete.forEach(Keyed::remove); } @Test public void test_default_params() { - AutoML aml = runDummyAutoML(spec -> { - spec.build_control.nfolds = 0; //disabling CV on AutoML - }); - - Model m = aml.leaderboard().getLeader(); - DataTransformer[] transformers = ((PipelineModel) m)._output.getTransformers(); - assertNotNull(transformers); - assertEquals(1, transformers.length); - TargetEncoderFeatureTransformer teTrans = (TargetEncoderFeatureTransformer)transformers[0]; - TargetEncoderParameters teParams = teTrans.getModel()._parms; - assertNull(teParams._fold_column); - assertEquals(DataLeakageHandlingStrategy.None, teParams._data_leakage_handling); - assertFalse(teParams._keep_original_categorical_columns); - assertTrue(teParams._blending); - assertEquals(0, teParams._noise, 0); + aml.getBuildSpec().build_control.nfolds = 0; //disabling CV on AutoML + TargetEncoding te = new TargetEncoding(aml); + te.setEncodeAllColumns(true); + try { + Scope.enter(); + te.prepare(); + assertNotNull(te.getTEModel()); + assertNotNull(te.getTEPreprocessor()); + Scope.track_generic(te.getTEModel()); + Scope.track_generic(te.getTEPreprocessor()); + + TargetEncoderParameters teParams = te.getTEModel()._parms; + assertNull(teParams._fold_column); + assertEquals(DataLeakageHandlingStrategy.None, teParams._data_leakage_handling); + assertFalse(teParams._keep_original_categorical_columns); + assertTrue(teParams._blending); + assertEquals(0, teParams._noise, 0); + } finally { + te.dispose(); + Scope.exit(); + } } @Test - public void test_te_pipeline_lifecycle_automl_no_cv() { - AutoML aml = runDummyAutoML(spec -> { - spec.build_control.nfolds = 0; //disabling CV on AutoML - }); - PipelineModel m = (PipelineModel) aml.leaderboard().getLeader(); - DataTransformer[] transformers = m._output.getTransformers(); - assertNotNull(transformers); - assertEquals(1, transformers.length); - TargetEncoderFeatureTransformer teTrans = (TargetEncoderFeatureTransformer)transformers[0]; - TargetEncoderParameters teParams = teTrans.getModel()._parms; - assertNull(teParams._fold_column); - assertEquals(DataLeakageHandlingStrategy.None, teParams._data_leakage_handling); - - Model.Parameters eParams = m._output.getEstimatorModel()._parms; - assertEquals(0, eParams._nfolds); - assertNull(eParams._fold_column); + public void test_te_preprocessing_lifecycle_automl_no_cv() { + aml.getBuildSpec().build_control.nfolds = 0; //disabling CV on AutoML + TargetEncoding te = new TargetEncoding(aml); + te.setEncodeAllColumns(true); + assertNull(te.getTEModel()); + 
assertNull(te.getTEPreprocessor()); + try { + Scope.enter(); + te.prepare(); + assertNotNull(te.getTEModel()); + assertNotNull(te.getTEPreprocessor()); + Scope.track_generic(te.getTEModel()); + Scope.track_generic(te.getTEPreprocessor()); + assertNull(te.getTEModel()._parms._fold_column); + assertEquals(DataLeakageHandlingStrategy.None, te.getTEModel()._parms._data_leakage_handling); + + Model.Parameters params = new DummyModel.DummyModelParameters(); + params._train = fr._key; + params._nfolds = 0; + params._fold_column = null; + + PreprocessingStep.Completer complete = te.apply(params, new PreprocessingConfig()); + assertEquals(0, params._nfolds); + assertNull(params._fold_column); + complete.run(); + } finally { + te.dispose(); + Scope.exit(); + } } @Test - public void test_te_pipeline_lifecycle_with_automl_cv_nfolds() { + public void test_te_preprocessing_lifecycle_with_automl_cv_nfolds() { int nfolds = 3; - AutoML aml = runDummyAutoML(spec -> { - spec.build_control.nfolds = nfolds; - spec.build_control.keep_cross_validation_models = true; - }); - PipelineModel m = (PipelineModel) aml.leaderboard().getLeader(); - DataTransformer[] transformers = m._output.getTransformers(); - assertNotNull(transformers); - assertEquals(2, transformers.length); //with CV enabled and no fold column, an additional transformer is added to generate the latter, - TargetEncoderFeatureTransformer teTrans = (TargetEncoderFeatureTransformer)transformers[1]; - TargetEncoderParameters teParams = teTrans.getModel()._parms; - assertNotNull(teParams._fold_column); - assertEquals("__fold__target", teParams._fold_column); - assertTrue(teParams._fold_column.endsWith("target")); - assertEquals(DataLeakageHandlingStrategy.KFold, teParams._data_leakage_handling); - - Model eModel = m._output.getEstimatorModel(); - assertEquals(0, eModel._parms._nfolds); - assertNotNull(eModel._parms._fold_column); - - assertEquals(teParams._fold_column, eModel._parms._fold_column); - assertNotEquals(aml.getBuildSpec().input_spec.training_frame, eModel._parms._train); - Frame amlTrain = aml.getTrainingFrame(); - assertTrue(ArrayUtils.contains(eModel._output._names, eModel._parms._fold_column)); - assertFalse(ArrayUtils.contains(amlTrain.names(), eModel._parms._fold_column)); - assertEquals(nfolds, m._output._cross_validation_models.length); - assertArrayEquals(m._output._cross_validation_models, m._output.getEstimatorModel()._output._cross_validation_models); //temporary until estimator CV models can be translated into pipeline models. 
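// Note on the apply(...)/Completer pairing exercised in these tests: apply(...) mutates the
// model parameters for training and returns a callback that undoes any temporary state once
// the model is built. A minimal, self-contained sketch of that do/undo contract follows;
// the names (TempParams, applyStep) are illustrative only, not the H2O API.
final class CompleterSketch {
    interface Completer extends Runnable {}
    static final Completer NOOP = () -> {};

    static final class TempParams { String foldColumn; }

    // Mutates params for training and returns a Completer that reverts the mutation.
    static Completer applyStep(TempParams params) {
        if (params.foldColumn != null) return NOOP;   // nothing was changed, nothing to undo
        params.foldColumn = "__te_fold__";            // temporary state needed for training
        return () -> params.foldColumn = null;        // revert once training completes
    }

    public static void main(String[] args) {
        TempParams p = new TempParams();
        Completer complete = applyStep(p);
        try {
            System.out.println("training with fold column: " + p.foldColumn);
        } finally {
            complete.run();                           // always revert, as the tests here do
        }
        System.out.println("after completion: " + p.foldColumn); // prints: null
    }
}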
+ aml.getBuildSpec().build_control.nfolds = nfolds; + TargetEncoding te = new TargetEncoding(aml); + te.setEncodeAllColumns(true); + try { + Scope.enter(); + te.prepare(); + assertNotNull(te.getTEModel()); + assertNotNull(te.getTEPreprocessor()); + Scope.track_generic(te.getTEModel()); + Scope.track_generic(te.getTEPreprocessor()); + assertNotNull(te.getTEModel()._parms._fold_column); + assertTrue(te.getTEModel()._parms._fold_column.endsWith(TargetEncoding.TE_FOLD_COLUMN_SUFFIX)); + assertEquals(DataLeakageHandlingStrategy.KFold, te.getTEModel()._parms._data_leakage_handling); + + Model.Parameters params = new DummyModel.DummyModelParameters(); + params._train = fr._key; + params._nfolds = nfolds; + params._fold_column = null; + + PreprocessingStep.Completer complete = te.apply(params, new PreprocessingConfig()); + assertEquals(0, params._nfolds); + assertNotNull(params._fold_column); + assertEquals(te.getTEModel()._parms._fold_column, params._fold_column); + assertNotEquals(fr._key, params._train); + Frame newTrain = params._train.get(); + assertTrue(ArrayUtils.contains(newTrain.names(), params._fold_column)); + assertFalse(ArrayUtils.contains(fr.names(), params._fold_column)); + assertEquals(nfolds, newTrain.vec(params._fold_column).toCategoricalVec().cardinality()); + complete.run(); + } finally { + te.dispose(); + Scope.exit(); + } } - + @Test - public void test_te_pipeline_lifecycle_with_automl_cv_foldcolumn() { - String foldc = "foldc"; - AutoML aml = runDummyAutoML(spec -> { - spec.input_spec.fold_column = foldc; - spec.build_control.keep_cross_validation_models = true; - }); - PipelineModel m = (PipelineModel) aml.leaderboard().getLeader(); - DataTransformer[] transformers = m._output.getTransformers(); - assertNotNull(transformers); - assertEquals(2, transformers.length); //with CV enabled and no fold column, an additional transformer is added to generate the latter, - TargetEncoderFeatureTransformer teTrans = (TargetEncoderFeatureTransformer)transformers[1]; - TargetEncoderParameters teParams = teTrans.getModel()._parms; - assertNotNull(teParams._fold_column); - assertEquals(foldc, teParams._fold_column); - assertEquals(DataLeakageHandlingStrategy.KFold, teParams._data_leakage_handling); - - Model eModel = m._output.getEstimatorModel(); - assertEquals(0, eModel._parms._nfolds); - assertNotNull(eModel._parms._fold_column); - - assertEquals(foldc, eModel._parms._fold_column); - assertNotEquals(aml.getBuildSpec().input_spec.training_frame, eModel._parms._train); - assertEquals(2, m._output._cross_validation_models.length); // foldc has 2 distinct values - assertArrayEquals(m._output._cross_validation_models, m._output.getEstimatorModel()._output._cross_validation_models); //temporary until estimator CV models can be translated into pipeline models. 
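// Note on the cardinality assertion above: the fold column generated in prepare() uses
// FoldAssignmentScheme.Modulo, which maps row i to fold i % nfolds, so for nrows >= nfolds
// the column holds exactly nfolds distinct values. A tiny standalone illustration with
// plain arrays (not H2O Vecs):
final class ModuloFoldSketch {
    static int[] assignFolds(int nrows, int nfolds) {
        int[] fold = new int[nrows];
        for (int i = 0; i < nrows; i++) fold[i] = i % nfolds; // deterministic, seed-independent
        return fold;
    }

    public static void main(String[] args) {
        for (int v : assignFolds(6, 3)) System.out.print(v + " "); // 0 1 2 0 1 2
        System.out.println();
    }
}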
+ public void test_te_preprocessing_lifecycle_with_automl_cv_foldcolumn() { + aml.getBuildSpec().input_spec.fold_column = "foldc"; + TargetEncoding te = new TargetEncoding(aml); + te.setEncodeAllColumns(true); + try { + Scope.enter(); + te.prepare(); + assertNotNull(te.getTEModel()); + assertNotNull(te.getTEPreprocessor()); + Scope.track_generic(te.getTEModel()); + Scope.track_generic(te.getTEPreprocessor()); + assertNotNull(te.getTEModel()._parms._fold_column); + assertEquals("foldc", te.getTEModel()._parms._fold_column); + assertEquals(DataLeakageHandlingStrategy.KFold, te.getTEModel()._parms._data_leakage_handling); + + Model.Parameters params = new DummyModel.DummyModelParameters(); + params._train = fr._key; + params._nfolds = 0; + params._fold_column = "foldc"; + + PreprocessingStep.Completer complete = te.apply(params, new PreprocessingConfig()); + assertEquals(0, params._nfolds); + assertNotNull(params._fold_column); + assertEquals("foldc", params._fold_column); + assertEquals(te.getTEModel()._parms._fold_column, params._fold_column); + assertEquals(fr._key, params._train); + complete.run(); + } finally { + te.dispose(); + Scope.exit(); + } } - - + + @Test public void test_automl_run_with_target_encoding_enabled() { - AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec(); - Frame fr = parseTestFile("./smalldata/titanic/titanic_expanded.csv"); Scope.track(fr); - SplitFrame sf = new SplitFrame(fr, new double[] { 0.7, 0.3 }, new Key[]{Key.make("titanic_train"), Key.make("titanic_test")}); - sf.exec().get(); - Frame train = sf._destination_frames[0].get(); Scope.track(train); - Frame test = sf._destination_frames[1].get(); Scope.track(test); - - autoMLBuildSpec.input_spec.training_frame = train._key; - autoMLBuildSpec.input_spec.validation_frame = test._key; - autoMLBuildSpec.input_spec.leaderboard_frame = test._key; - autoMLBuildSpec.input_spec.response_column = "survived"; - autoMLBuildSpec.build_control.stopping_criteria.set_max_models(15); // sth big enough to test all algos+grids with TE - autoMLBuildSpec.build_control.stopping_criteria.set_seed(42); - autoMLBuildSpec.build_control.nfolds = 3; - autoMLBuildSpec.build_models.preprocessing = new PipelineStepDefinition[] { - new PipelineStepDefinition(PipelineStepDefinition.Type.TargetEncoding) - }; -// autoMLBuildSpec.build_models.exclude_algos = aro(Algo.DeepLearning); - - AutoML aml = AutoML.startAutoML(autoMLBuildSpec); Scope.track_generic(aml); - aml.get(); - System.out.println(aml.leaderboard().toTwoDimTable()); - for (Model m : aml.leaderboard().getModels()) { - if (m instanceof StackedEnsembleModel) { - assertFalse(m.haveMojo()); // all SEs should not support MOJO as their base models don't - assertFalse(m.havePojo()); - } else { - assertTrue(m instanceof PipelineModel); - PipelineModel p = (PipelineModel)m; - if (p._output.getEstimatorModel() instanceof DeepLearningModel) { - assertEquals(1, p._output.getTransformers().length); // TE disabled for DL, but keeping the fold column generator for CV consistency with other models when building SE. 
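// Note on DataLeakageHandlingStrategy.KFold (asserted in the tests above and below): every
// row is encoded with target statistics aggregated over the *other* folds only, so a row's
// own target never leaks into its encoding. A simplified, self-contained sketch of that
// computation; it ignores blending and noise and is not the H2O implementation.
import java.util.HashMap;
import java.util.Map;

final class KFoldTargetEncodingSketch {
    // te[i] = mean(y over rows j with cat[j]==cat[i] and fold[j]!=fold[i]), or prior if none
    static double[] encode(String[] cat, double[] y, int[] fold, int nfolds, double prior) {
        Map<String, double[]> stats = new HashMap<>(); // category -> [sum per fold | count per fold]
        for (int i = 0; i < cat.length; i++) {
            double[] s = stats.computeIfAbsent(cat[i], k -> new double[2 * nfolds]);
            s[fold[i]] += y[i];
            s[nfolds + fold[i]]++;
        }
        double[] out = new double[cat.length];
        for (int i = 0; i < cat.length; i++) {
            double[] s = stats.get(cat[i]);
            double sum = 0, cnt = 0;
            for (int f = 0; f < nfolds; f++) {
                if (f == fold[i]) continue;            // exclude the row's own fold
                sum += s[f];
                cnt += s[nfolds + f];
            }
            out[i] = cnt == 0 ? prior : sum / cnt;     // fall back to the global prior
        }
        return out;
    }

    public static void main(String[] args) {
        String[] cat = {"a", "a", "a", "b", "b", "b"};
        double[] y   = { 1 ,  0 ,  1 ,  0 ,  0 ,  1 };
        int[]   fold = { 0 ,  1 ,  1 ,  0 ,  0 ,  1 };
        double[] te = encode(cat, y, fold, 2, 0.5);
        for (double v : te) System.out.print(v + " "); // row 0 ("a", fold 0) -> mean of rows 1,2 = 0.5
    }
}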
+ try { + Scope.enter(); + AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec(); + Frame fr = parseTestFile("./smalldata/titanic/titanic_expanded.csv"); Scope.track(fr); + SplitFrame sf = new SplitFrame(fr, new double[] { 0.7, 0.3 }, new Key[]{Key.make("titanic_train"), Key.make("titanic_test")}); + sf.exec().get(); + Frame train = sf._destination_frames[0].get(); Scope.track(train); + Frame test = sf._destination_frames[1].get(); Scope.track(test); + + autoMLBuildSpec.input_spec.training_frame = train._key; +// autoMLBuildSpec.input_spec.validation_frame = test._key; + autoMLBuildSpec.input_spec.leaderboard_frame = test._key; + autoMLBuildSpec.input_spec.response_column = "survived"; + autoMLBuildSpec.build_control.stopping_criteria.set_max_models(15); // sth big enough to test all algos+grids with TE + autoMLBuildSpec.build_control.stopping_criteria.set_seed(42); + autoMLBuildSpec.build_control.nfolds = 3; + autoMLBuildSpec.build_models.preprocessing = new PreprocessingStepDefinition[] { + new PreprocessingStepDefinition(Type.TargetEncoding) + }; + + aml = AutoML.startAutoML(autoMLBuildSpec); Scope.track_generic(aml); + aml.get(); + System.out.println(aml.leaderboard().toTwoDimTable()); + for (Model m : aml.leaderboard().getModels()) { + if (m instanceof StackedEnsembleModel) { + assertNull(m._parms._preprocessors); + assertFalse(m.haveMojo()); // all SEs should have at least one XGB which doesn't support MOJO + assertFalse(m.havePojo()); + } else if (m instanceof GLMModel + || m instanceof DeepLearningModel + ) { // disabled for GLM with CV, because GLM refuses to follow the same CV flow as other algos. + assertNull(m._parms._preprocessors); + assertTrue(m.haveMojo()); + assertTrue(m.havePojo()); } else { - assertEquals(2, p._input_parms._transformers.length); - if (p._input_parms._transformers[1].get() != null) { - assertEquals(2, p._output.getTransformers().length); - assertTrue(p._output.getTransformers()[1].enabled()); - } else { - assertEquals(1, p._output.getTransformers().length); - assertTrue(p._key.toString().contains("_grid_")); // TE can be disabled during grid search as an hyperparam. 
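// Note on the MOJO/POJO assertions in this test: they match the guards added to hex.Model
// further below in this patch (havePojo()/haveMojo() return false whenever
// _parms._preprocessors is set), because a model trained with an embedded Target Encoder
// cannot be exported standalone yet. A minimal sketch of that capability-gating pattern;
// the names here are illustrative, not the H2O API.
final class ExportGuardSketch {
    static final class Params { Object[] preprocessors; }

    static boolean haveMojo(Params parms, boolean algoSupportsMojo) {
        if (parms.preprocessors != null) return false; // preprocessing is not embedded in the artifact
        return algoSupportsMojo;                       // otherwise defer to the algo's own capability
    }

    public static void main(String[] args) {
        Params plain = new Params();
        Params withTe = new Params();
        withTe.preprocessors = new Object[]{ "te-preprocessor" };
        System.out.println(haveMojo(plain, true));  // true
        System.out.println(haveMojo(withTe, true)); // false, regardless of algo support
    }
}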
- } + assertNotNull(m._parms._preprocessors); + assertEquals(1, m._parms._preprocessors.length); + assertTrue(m._parms._preprocessors[0].get() instanceof TargetEncoderPreprocessor); + assertFalse(m.haveMojo()); + assertFalse(m.havePojo()); } - assertFalse(m.haveMojo()); - assertFalse(m.havePojo()); } + } finally { + Scope.exit(); } } - - private static class DummySteps extends DummyStepsProvider.DummyModelSteps { - - public DummySteps(AutoML autoML) { - super(autoML); - defaultModels = new ModelingStep[] { - new DummyModelStep(DummyBuilder.algo, "dummy_model", aml()), - }; - - grids = new ModelingStep[] { - new DummyGridStep(DummyBuilder.algo, "dummy_grid", aml()) - }; - } - } - private static class TEStepDefinition extends PipelineStepDefinition { - - public TEStepDefinition() { - super(Type.TargetEncoding); - } - - @Override - public PipelineStep newPipelineStep(AutoML aml) { - TargetEncoding teStep = (TargetEncoding) super.newPipelineStep(aml); - teStep.setEncodeAllColumns(true); //enforce as we use small data in those tests - return teStep; - } - } - } diff --git a/h2o-bindings/bin/custom.py b/h2o-bindings/bin/custom.py index 496fad3946c3..ec1c81fa29e0 100644 --- a/h2o-bindings/bin/custom.py +++ b/h2o-bindings/bin/custom.py @@ -25,9 +25,8 @@ def get_customizations_for(language, algo, property=None, default=None): tokens = property.split('.') value = customizations for token in tokens: - if token in value: - value = value.get(token) - else: + value = value.get(token) + if value is None: return default return value else: diff --git a/h2o-bindings/bin/custom/R/gen_pipeline.py b/h2o-bindings/bin/custom/R/gen_pipeline.py deleted file mode 100644 index 09b5b702677b..000000000000 --- a/h2o-bindings/bin/custom/R/gen_pipeline.py +++ /dev/null @@ -1,33 +0,0 @@ - -extensions = dict( - required_params=[], - frame_params=[], - validate_required_params="", - set_required_params="", - module=""" -.h2o.fill_pipeline <- function(model, parameters, allparams) { - if (!is.null(model$estimator)) { - model$estimator_model <- h2o.getModel(model$estimator$name) - } else { - model$estimator_model <- NULL - } - model$transformers <- unlist(lapply(model$transformers, function(k) .h2o.fetch_datatransformer(k$name))) - # class(model) <- "H2OPipeline" - return(model) -} -.h2o.fetch_datatransformer <- function(id) { - resp <- .h2o.__remoteSend(method="GET", h2oRestApiVersion=3, page=paste0("Pipeline/DataTransformer/", id)) - tr <- new("H2ODataTransformer", id=resp$key$name, name=resp$name, description=resp$description) - return (tr) -} -""" -) - -doc = dict( - preamble=""" -Build a pipeline model given a list of transformers and a final model. - -Currently R model pipelines, as produced by AutoML for example, -are only available as read-only models that can not be constructed and trained directly by the end-user. 
-""", -) diff --git a/h2o-bindings/bin/custom/python/gen_pipeline.py b/h2o-bindings/bin/custom/python/gen_pipeline.py deleted file mode 100644 index c69ddfbe42bf..000000000000 --- a/h2o-bindings/bin/custom/python/gen_pipeline.py +++ /dev/null @@ -1,69 +0,0 @@ -supervised_learning = None # actually depends on the estimator model in the pipeline, leave it to None for now as it is needed only for training and we don't support pipeline as input yet - - -# in future update, we'll want to expose parameters applied to each transformer -def module_extensions(): - class H2ODataTransformer(Keyed, H2ODisplay): - @classmethod - def make(cls, kvs): - dt = H2ODataTransformer(**{k: v for k, v in kvs if k not in H2OSchema._ignored_schema_keys_}) - dt._json = kvs - return dt - - def __init__(self, key=None, name=None, description=None): - self._json = None - self._id = key['name'] - self._name = name - self._description = description - - @property - def key(self): - return self._id - - def _repr_(self): - return repr(self._json) - - def _str_(self, verbosity=None): - return repr_def(self) - - - # self-register transformer class: done as soon as `h2o.estimators` is loaded, which means as soon as h2o.h2o is... - register_schema_handler("DataTransformerV3", H2ODataTransformer) - - -def class_extensions(): - - @staticmethod - def get_transformer(id): - assert_is_type(id, str) - return h2o.api("GET /3/Pipeline/DataTransformer/%s" % id) - - @property - def transformers(self): - trs_json = self._model_json['output']['transformers'] - return None if (trs_json is None) else [H2OPipeline.get_transformer(k['name']) for k in trs_json] - - @property - def estimator_model(self): - m_json = self._model_json['output']['estimator'] - return None if (m_json is None or m_json['name'] is None) else h2o.get_model(m_json['name']) - - def transform(self, fr): - """ - Applies all the pipeline transformers to the given input frame. - :return: the transformed frame, as it would be passed to `estimator_model`, if calling `predict` instead. 
- """ - return H2OFrame._expr(expr=ExprNode("transform", ASTId(self.key), ASTId(fr.key)))._frame(fill_cache=True) - - -extensions = dict( - __imports__=""" -import h2o -from h2o.base import Keyed -from h2o.display import H2ODisplay, repr_def -from h2o.expr import ASTId, ExprNode -from h2o.schemas import H2OSchema, register_schema_handler -""", - __class__=class_extensions, - __module__=module_extensions, -) diff --git a/h2o-bindings/bin/gen_R.py b/h2o-bindings/bin/gen_R.py index 7143f6051377..fdb0150d4d28 100644 --- a/h2o-bindings/bin/gen_R.py +++ b/h2o-bindings/bin/gen_R.py @@ -145,8 +145,8 @@ def get_schema_params(pname): "verbose", "destination_key"] # destination_key is only for SVD bulk_params = list(zip(*filter(lambda t: not t[0] in bulk_pnames_skip, zip(sig_pnames, sig_params)))) - bulk_pnames = list(bulk_params[0]) if bulk_params else [] - sig_bulk_params = list(bulk_params[1]) if bulk_params else [] + bulk_pnames = list(bulk_params[0]) + sig_bulk_params = list(bulk_params[1]) sig_bulk_params.append("segment_columns = NULL") sig_bulk_params.append("segment_models_id = NULL") sig_bulk_params.append("parallelism = 1") diff --git a/h2o-bindings/bin/gen_python.py b/h2o-bindings/bin/gen_python.py index 9ebdaf9a8215..6fa044f77d0e 100755 --- a/h2o-bindings/bin/gen_python.py +++ b/h2o-bindings/bin/gen_python.py @@ -314,19 +314,16 @@ def extend_schema_params(param): yield " self._parms[\"%s\"] = %s" % (sname, pname) yield "" - if deprecated_params: - for old, new in deprecated_params.items(): - new_name = new[0] if isinstance(new, tuple) else new - yield " %s = deprecated_property('%s', %s)" % (old, old, new) - yield "" - + for old, new in deprecated_params.items(): + new_name = new[0] if isinstance(new, tuple) else new + yield " %s = deprecated_property('%s', %s)" % (old, old, new) + + yield "" if class_extras: yield reformat_block(code_as_str(class_extras), 4) - yield "" if module_extras: yield "" yield reformat_block(code_as_str(module_extras)) - yield "" def algo_to_classname(algo): @@ -355,7 +352,6 @@ def algo_to_classname(algo): if algo == "modelselection": return "H2OModelSelectionEstimator" if algo == "isotonicregression": return "H2OIsotonicRegressionEstimator" if algo == "adaboost": return "H2OAdaBoostEstimator" - if algo == "pipeline": return "H2OPipeline" return "H2O" + algo.capitalize() + "Estimator" @@ -460,10 +456,8 @@ def main(): modelselection="model_selection" ) algo_to_category = dict( - generic="Miscellaneous", - pipeline=None, svd="Miscellaneous", - word2vec="Miscellaneous", + word2vec="Miscellaneous" ) for name, mb in builders: module = name @@ -471,9 +465,9 @@ def main(): module = algo_to_module[name] bi.vprint("Generating model: " + name) bi.write_to_file("%s.py" % module, gen_module(mb, name)) - category = (algo_to_category[name] if name in algo_to_category - else "Supervised" if mb["supervised"] - else "Unsupervised") + category = algo_to_category[name] if name in algo_to_category \ + else "Supervised" if mb["supervised"] \ + else "Unsupervised" full_module = '.'.join(["h2o.estimators", module]) modules.append((full_module, module, algo_to_classname(name), category)) diff --git a/h2o-core/src/main/java/hex/CVModelBuilder.java b/h2o-core/src/main/java/hex/CVModelBuilder.java index 550332604e67..3287b62636df 100644 --- a/h2o-core/src/main/java/hex/CVModelBuilder.java +++ b/h2o-core/src/main/java/hex/CVModelBuilder.java @@ -43,49 +43,46 @@ public void bulkBuildModels() { stopAll(submodel_tasks); throw new Job.JobCancelledException(job); } - LOG.info("Building 
"+modelBuilders[i]._desc+"."); + LOG.info("Building cross-validation model " + (i + 1) + " / " + N + "."); prepare(modelBuilders[i]); modelBuilders[i].startClock(); submodel_tasks[i] = modelBuilders[i].submitTrainModelTask(); if (++nRunning == parallelization) { //piece-wise advance in training the models while (nRunning > 0) { final int waitForTaskIndex = i + 1 - nRunning; - final String modelDesc = modelBuilders[waitForTaskIndex]._desc; try { submodel_tasks[waitForTaskIndex].join(); finished(modelBuilders[waitForTaskIndex]); } catch (RuntimeException t) { if (rt == null) { - LOG.info("Exception from "+ modelDesc + " will be reported as main exception."); + LOG.info("Exception from CV model #" + waitForTaskIndex + " will be reported as main exception."); rt = t; } else { - LOG.warn(modelDesc + " failed, the exception will not be reported", t); + LOG.warn("CV model #" + waitForTaskIndex + " failed, the exception will not be reported", t); } } finally { - LOG.info("Completed "+modelDesc+"."); + LOG.info("Completed cross-validation model " + waitForTaskIndex + " / " + N + "."); nRunning--; // need to decrement regardless even if there is an exception, otherwise looping... } } if (rt != null) throw rt; } } - for (int i = 0; i < N; ++i) { //all sub-models must be completed before the main model can be built - final String modelDesc = modelBuilders[i]._desc; + for (int i = 0; i < N; ++i) //all sub-models must be completed before the main model can be built try { final TrainModelTaskController task = submodel_tasks[i]; assert task != null; task.join(); } catch (RuntimeException t) { if (rt == null) { - LOG.info("Exception from "+ modelDesc + " will be reported as main exception."); + LOG.info("Exception from CV model #" + i + " will be reported as main exception."); rt = t; } else { - LOG.warn(modelDesc + " failed, the exception will not be reported", t); + LOG.warn("CV model #" + i + " failed, the exception will not be reported", t); } } finally { - LOG.info("Completed "+modelDesc+"."); + LOG.info("Completed cross-validation model " + i + " / " + N + "."); } - } if (rt != null) throw rt; } diff --git a/h2o-core/src/main/java/hex/Model.java b/h2o-core/src/main/java/hex/Model.java index c73c0c91ad79..c246e5675f05 100755 --- a/h2o-core/src/main/java/hex/Model.java +++ b/h2o-core/src/main/java/hex/Model.java @@ -13,7 +13,6 @@ import hex.genmodel.easy.exception.PredictException; import hex.genmodel.easy.prediction.*; import hex.genmodel.utils.DistributionFamily; -import hex.grid.Grid; import hex.quantile.QuantileModel; import org.joda.time.DateTime; import water.*; @@ -284,6 +283,7 @@ public boolean isGeneric() { } public boolean havePojo() { + if (_parms._preprocessors != null) return false; // TE processor not included to current POJO (see PUBDEV-8508 for potential fix) final String algoName = _parms.algoName(); return ModelBuilder.getRegisteredBuilder(algoName) .map(ModelBuilder::havePojo) @@ -295,6 +295,7 @@ public boolean havePojo() { } public boolean haveMojo() { + if (_parms._preprocessors != null) return false; // until PUBDEV-7799, disable model MOJO if it was trained with embedded TE. final String algoName = _parms.algoName(); return ModelBuilder.getRegisteredBuilder(algoName) .map(ModelBuilder::haveMojo) @@ -348,16 +349,10 @@ public static class GridSortBy { // intentionally not an enum to allow 3rd party * WARNING: Model Parameters is not immutable object and ModelBuilder can modify * them! 
*/ - public abstract static class Parameters extends Iced implements AdaptFrameParameters, Parameterizable, Checksumable { + public abstract static class Parameters extends Iced implements AdaptFrameParameters { /** Maximal number of supported levels in response. */ public static final int MAX_SUPPORTED_LEVELS = 1<<20; - static final Set IGNORED_FIELDS_PARAM_HASH = new HashSet<>(Arrays.asList( - "_export_checkpoints_dir", - "_max_runtime_secs" // It is often modified during training on purpose (e.g. grid search) - )); - - /** The short name, used in making Keys. e.g. "GBM" */ abstract public String algoName(); @@ -426,6 +421,8 @@ public static CategoricalEncodingScheme fromGenModel(CategoricalEncoding encodin } } + public Key[] _preprocessors; + public long _seed = -1; public long getOrMakeRealSeed(){ while (_seed==-1) { @@ -623,9 +620,8 @@ public void read_unlock_frames(Job job) { public boolean hasCustomMetricFunc() { return _custom_metric_func != null; } - @Override public long checksum() { - return checksum(IGNORED_FIELDS_PARAM_HASH); + return checksum(null); } /** * Compute a checksum based on all non-transient non-static ice-able assignable fields (incl. inherited ones) which have @API annotations. @@ -638,7 +634,66 @@ public long checksum() { * @return checksum A 64-bit long representing the checksum of the {@link Parameters} object */ public long checksum(final Set ignoredFields) { - long xs = Checksum.checksum(this, ignoredFields); + long xs = 0x600DL; + int count = 0; + Field[] fields = Weaver.getWovenFields(this.getClass()); + Arrays.sort(fields, Comparator.comparing(Field::getName)); + for (Field f : fields) { + if (ignoredFields != null && ignoredFields.contains(f.getName())) { + // Do not include ignored fields in the final hash + continue; + } + final long P = MathUtils.PRIMES[count % MathUtils.PRIMES.length]; + Class c = f.getType(); + if (c.isArray()) { + try { + f.setAccessible(true); + if (f.get(this) != null) { + if (c.getComponentType() == Integer.TYPE){ + int[] arr = (int[]) f.get(this); + xs = xs * P + (long) Arrays.hashCode(arr); + } else if (c.getComponentType() == Float.TYPE) { + float[] arr = (float[]) f.get(this); + xs = xs * P + (long) Arrays.hashCode(arr); + } else if (c.getComponentType() == Double.TYPE) { + double[] arr = (double[]) f.get(this); + xs = xs * P + (long) Arrays.hashCode(arr); + } else if (c.getComponentType() == Long.TYPE){ + long[] arr = (long[]) f.get(this); + xs = xs * P + (long) Arrays.hashCode(arr); + } else if (c.getComponentType() == Boolean.TYPE){ + boolean[] arr = (boolean[]) f.get(this); + xs = xs * P + (long) Arrays.hashCode(arr); + } else { + Object[] arr = (Object[]) f.get(this); + xs = xs * P + (long) Arrays.deepHashCode(arr); + } //else lead to ClassCastException + } else { + xs = xs * P; + } + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (ClassCastException t) { + throw H2O.fail("Failed to calculate checksum for the parameter object", t); //no support yet for int[][] etc. + } + } else { + try { + f.setAccessible(true); + Object value = f.get(this); + if (value instanceof Enum) { + // use string hashcode for enums, otherwise the checksum would be different each run + xs = xs * P + (long)(value.toString().hashCode()); + } else if (value != null) { + xs = xs * P + (long)(value.hashCode()); + } else { + xs = xs * P + P; + } + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + count++; + } xs ^= (train() == null ? 43 : train().checksum()) * (valid() == null ? 
17 : valid().checksum()); return xs; } @@ -757,67 +812,6 @@ public void setDistributionFamily(DistributionFamily distributionFamily){ public DistributionFamily getDistributionFamily() { return _distribution; } - - @Override - public boolean hasParameter(String name) { - try { - getParameter(name); - return true; - } catch (Exception e) { - return false; - } - } - - @Override - public Object getParameter(String name) { - return PojoUtils.getFieldValue(this, name); - } - - @Override - public void setParameter(String name, Object value) { - PojoUtils.setField(this, name, value); - } - - @Override - public boolean isParameterSetToDefault(String name) { - Object val = getParameter(name); - Object defaultVal = getParameterDefaultValue(name); - return Objects.deepEquals(val, defaultVal); - } - - @Override - public Object getParameterDefaultValue(String name) { - return getDefaults().getParameter(name); - } - - @Override - public boolean isParameterAssignable(String name) { - return "_seed".equals(name) || isParameterSetToDefault(name); - } - - @Override - public Parameters freshCopy() { - return clone(); - } - - /** private use only to avoid this getting mutated. */ - private transient Parameters _defaults; - - /** private use only to avoid this getting mutated. */ - private Parameters getDefaults() { - if (_defaults == null) { - _defaults = ModelBuilder.makeParameters(algoName()); - } - return _defaults; - } - - /** - * callback called during grid search if it failed building a model with current parameters. - * When this is called, the failure instance is already extended with the last failure details/params. - * @param searchFailure - * @param grid - */ - public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) {} } public ModelMetrics addModelMetrics(final ModelMetrics mm) { @@ -1698,7 +1692,7 @@ public static String[] adaptTestForTrain(final Frame test, final String[] origNa String[] names, String[][] domains, final AdaptFrameParameters parms, final boolean expensive, final boolean computeMetrics, final InteractionBuilder interactionBldr, final ToEigenVec tev, - final Map toDelete, final boolean catEncoded) + final IcedHashMap toDelete, final boolean catEncoded) throws IllegalArgumentException { String[] msg = new String[0]; if (test == null) return msg; @@ -1736,7 +1730,7 @@ public static String[] adaptTestForTrain(final Frame test, final String[] origNa // As soon as the test frame contains at least one original pre-encoding predictor, // then we consider the frame as valid for predictions, and we'll later fill missing columns with NA Set required = new HashSet<>(Arrays.asList(origNames)); - required.removeAll(Arrays.asList(parms.getNonPredictors())); + required.removeAll(Arrays.asList(response, weights, fold, treatment)); for (String name : test.names()) { if (required.contains(name)) { match = true; @@ -1921,7 +1915,7 @@ public Frame result() { } public Frame transform(Frame fr) { - throw new UnsupportedOperationException("this model doesn't support frame transformation"); + throw new UnsupportedOperationException("this model doesn't support constant frame results"); } /** Bulk score the frame {@code fr}, producing a Frame result; the 1st @@ -1983,6 +1977,7 @@ public Frame score(Frame fr, String destination_key, Job j, boolean computeMetri protected Frame adaptFrameForScore(Frame fr, boolean computeMetrics) { Frame adaptFr = new Frame(fr); + applyPreprocessors(adaptFr); String[] msg = adaptTestForTrain(adaptFr,true, computeMetrics); // Adapt if (msg.length > 0) { for 
(String s : msg) { @@ -1999,7 +1994,8 @@ public Frame score(Frame fr, String destination_key, Job j, boolean computeMetri try (Scope.Safe s = Scope.safe(fr)) { // Adapt frame, clean up the previous score warning messages _warningsP = new String[0]; - computeMetrics = computeMetrics && canComputeMetricsForFrame(fr); + computeMetrics = computeMetrics && + (!_output.hasResponse() || (fr.vec(_output.responseName()) != null && !fr.vec(_output.responseName()).isBad())); Frame adaptFr = adaptFrameForScore(fr, computeMetrics); // Predict & Score @@ -2023,6 +2019,21 @@ public Frame score(Frame fr, String destination_key, Job j, boolean computeMetri } } + private void applyPreprocessors(Frame fr) { + if (_parms._preprocessors == null) return; + + for (Key key : _parms._preprocessors) { + DKV.prefetch(key); + } + Frame result = fr; + for (Key key : _parms._preprocessors) { + ModelPreprocessor preprocessor = key.get(); + result = preprocessor.processScoring(result, this); + Scope.track(result); + } + fr.restructure(result.names(), result.vecs()); //inplace + } + /** * Compute the deviances for each observation * @param valid Validation Frame (must contain the response) @@ -2073,34 +2084,14 @@ public void map(Chunk[] cs, NewChunk[] nc) { }.doAll(Vec.T_NUM, predictions).outputFrame(Key.make(outputName), new String[]{"deviance"}, null); } - protected boolean canComputeMetricsForFrame(Frame fr) { - return !_output.hasResponse() || (fr.vec(_output.responseName()) != null && !fr.vec(_output.responseName()).isBad()); - } - protected String[] makeScoringNames(){ return makeScoringNames(_output); } - - /** - * ???: something fishy here! - * I suspect a bug not discovered yet, as there is a surprising similarity with the implementation taking a `names` parameter, - * but usages are only very slightly different and only one version is overridden by algos, which means that one may be incorrect in some cases. - */ - protected String[][] makeScoringDomains(Frame adaptFrm, boolean computeMetrics) { - String[][] domains = new String[1][]; - Vec response = adaptFrm.lastVec(); - domains[0] = _output.nclasses() == 1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : response.domain(); - if (_parms._distribution == DistributionFamily.quasibinomial) { - domains[0] = new VecUtils.CollectDoubleDomain(null,2).doAll(response).stringDomain(response.isInt()); - } - return domains; - } - protected String[][] makeScoringDomains(Frame adaptFrm, boolean computeMetrics, String[] names) { String[][] domains = new String[names.length][]; Vec response = adaptFrm.lastVec(); - domains[0] = names.length == 1 || _output.hasTreatment() ? null : !computeMetrics ? _output._domains[_output._domains.length - 1] : response.domain(); + domains[0] = names.length == 1 || _output.hasTreatment() ? null : ! computeMetrics ? _output._domains[_output._domains.length - 1] : response.domain(); if (_parms._distribution == DistributionFamily.quasibinomial) { domains[0] = new VecUtils.CollectDoubleDomain(null,2).doAll(response).stringDomain(response.isInt()); } @@ -2223,10 +2214,16 @@ protected Frame postProcessPredictions(Frame adaptFrm, Frame predictFr, Job j) { * @return MetricBuilder */ protected ModelMetrics.MetricBuilder scoreMetrics(Frame adaptFrm) { - final boolean computeMetrics = canComputeMetricsForFrame(adaptFrm); + final boolean computeMetrics = (!isSupervised() || (adaptFrm.vec(_output.responseName()) != null && !adaptFrm.vec(_output.responseName()).isBad())); // Build up the names & domains. 
-// String[] names = makeScoringNames(); - String[][] domains = makeScoringDomains(adaptFrm, computeMetrics); + //String[] names = makeScoringNames(); + String[][] domains = new String[1][]; + Vec response = adaptFrm.lastVec(); + domains[0] = _output.nclasses() == 1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : response.domain(); + if (_parms._distribution == DistributionFamily.quasibinomial) { + domains[0] = new VecUtils.CollectDoubleDomain(null,2).doAll(response).stringDomain(response.isInt()); + } + // Score the dataset, building the class distribution & predictions BigScore bs = makeBigScoreTask(domains, null, adaptFrm, computeMetrics, false, null, CFuncRef.from(_parms._custom_metric_func)).doAll(adaptFrm); return bs._mb; @@ -2250,11 +2247,7 @@ public BigScore(String[] domain, int ncols, double[] mean, boolean testHasWeight boolean computeMetrics, boolean makePreds, Job j, CFuncRef customMetricFunc) { super(customMetricFunc); _j = j; - _domain = domain; - _npredcols = ncols; - _mean = mean; - _computeMetrics = computeMetrics; - _makePreds = makePreds; + _domain = domain; _npredcols = ncols; _mean = mean; _computeMetrics = computeMetrics; _makePreds = makePreds; if(_output._hasWeights && _computeMetrics && !testHasWeights) throw new IllegalArgumentException("Missing weights when computing validation metrics."); _hasWeights = testHasWeights; @@ -2435,7 +2428,7 @@ public double score(double[] data){ deleteCrossValidationPreds(); deleteCrossValidationModels(); } - cleanUp(_toDelete == null ? null : _toDelete.keySet()); + cleanUp(_toDelete); return super.remove_impl(fs, cascade); } diff --git a/h2o-core/src/main/java/hex/ModelBuilder.java b/h2o-core/src/main/java/hex/ModelBuilder.java index 7f05a1c1fba7..cbf1e301542d 100644 --- a/h2o-core/src/main/java/hex/ModelBuilder.java +++ b/h2o-core/src/main/java/hex/ModelBuilder.java @@ -23,14 +23,6 @@ * Model builder parent class. Contains the common interfaces and fields across all model builders. */ abstract public class ModelBuilder, P extends Model.Parameters, O extends Model.Output> extends Iced { - - public static final String CV_WEIGHTS_COLUMN = "__internal_cv_weights__"; - - private ModelBuilderCallbacks _callbacks; - - public void setCallbacks(ModelBuilderCallbacks callbacks) { - this._callbacks = callbacks; - } public ToEigenVec getToEigenVec() { return null; } public boolean shouldReorder(Vec v) { return _parms._categorical_encoding.needsResponse() && isSupervised(); } @@ -157,7 +149,7 @@ public static
P makeParameters(String algo) { /** Factory method to create a ModelBuilder instance for given the algo name. * Shallow clone of both the default ModelBuilder instance and a Parameter. */ - public static > B make(String algo, Job job, K result) { + public static B make(String algo, Job job, Key result) { return getRegisteredBuilder(algo) .map(prototype -> { @SuppressWarnings("unchecked") @@ -185,8 +177,8 @@ public static B make(MP pa return make(parms, mKey); } - public static > B make(MP parms, K mKey) { - Job mJob = new Job<>(mKey, parms.javaName(), parms.algoName()); + public static B make(MP parms, Key mKey) { + Job mJob = new Job<>(mKey, parms.javaName(), parms.algoName()); B newMB = ModelBuilder.make(parms.algoName(), mJob, mKey); newMB._parms = parms.clone(); newMB._input_parms = parms.clone(); @@ -244,17 +236,17 @@ abstract protected class Driver extends H2O.H2OCountedCompleter { protected Driver(){ super(); } - @Override - public void compute2() { - if (_callbacks != null) _callbacks.wrapCompute(ModelBuilder.this, this::compute3); - else this.compute3(); + private ModelBuilderListener _callback; + + public void setCallback(ModelBuilderListener callback) { + this._callback = callback; } - + // Pull the boilerplate out of the computeImpl(), so the algo writer doesn't need to worry about the following: // 1) Scope (unless they want to keep data, then they must call Scope.untrack(Key[])) // 2) Train/Valid frame locking and unlocking // 3) calling tryComplete() - public void compute3() { + public void compute2() { try { Scope.enter(); _parms.read_lock_frames(_job); // Fetch & read-lock input frames @@ -266,10 +258,7 @@ public void compute3() { _parms.read_unlock_frames(_job); if (_parms._is_cv_model) { // CV models get completely cleaned up when the main model is fully trained. - Key[] keep = new Key[0]; - try { - keep = _workspace.getToDelete(true).keySet().toArray(new Key[0]); - } catch (Exception ignored) {} + Key[] keep = _workspace == null ? new Key[0] : _workspace.getToDelete(true).keySet().toArray(new Key[0]); Scope.exit(keep); } else { cleanUp(); @@ -282,16 +271,16 @@ public void compute3() { @Override public void onCompletion(CountedCompleter caller) { setFinalState(); - if (_callbacks != null) { - _callbacks.onModelSuccess(_result); + if (_callback != null) { + _callback.onModelSuccess(_result.get()); } } @Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { setFinalState(); - if (_callbacks != null) { - _callbacks.onModelFailure(_result, ex, _parms); + if (_callback != null) { + _callback.onModelFailure(ex, _parms); } return true; } @@ -390,11 +379,16 @@ public void run() { /** Method to launch training of a Model, based on its parameters. 
*/ final public Job trainModel() { + return trainModel(null); + } + + final public Job trainModel(final ModelBuilderListener callback) { if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this); startClock(); if (!nFoldCV()) { Driver driver = trainModelImpl(); + driver.setCallback(callback); return _job.start(driver, _parms.progressUnits(), _parms._max_runtime_secs); } else { // cross-validation needs to be forked off to allow continuous (non-blocking) progress bar @@ -405,10 +399,15 @@ public void compute2() { tryComplete(); } + @Override + public void onCompletion(CountedCompleter caller) { + if (callback != null) callback.onModelSuccess(_result.get()); + } + @Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { Log.warn("Model training job " + _job._description + " completed with exception: " + ex); - if (_callbacks != null) _callbacks.onModelFailure(_result, ex, _parms); + if (callback != null) callback.onModelFailure(ex, _parms); try { Keyed.remove(_job._result); // ensure there's no incomplete model left for manipulation after crash or cancellation } catch (Exception logged) { @@ -607,8 +606,7 @@ public void updateParameters() { cv_updateOptimalParameters(_cvModelBuilders); } } - - + /** * Default naive (serial) implementation of N-fold cross-validation * (builds N+1 models, all have train+validation metrics, the main model has N-fold cross-validated validation metrics) @@ -620,8 +618,7 @@ public void computeCrossValidation() { ModelBuilder[] cvModelBuilders = null; try { Scope.enter(); - // Step 0: custom preparation for CV - cv_init(); // ensures that this initialization is done in the current Scope to avoid key leakage. + init(false); // Step 1: Assign each row to a fold final FoldAssignment foldAssignment = cv_AssignFold(N); @@ -664,7 +661,8 @@ public void computeCrossValidation() { // Step 7: Combine cross-validation scores; compute main model x-val // scores; compute gains/lifts - cv_mainModelScores(N, mbs, cvModelBuilders); + if (!cvModelBuilders[0].getName().equals("infogram")) // infogram does not support scoring + cv_mainModelScores(N, mbs, cvModelBuilders); _job.setReadyForView(true); DKV.put(_job); @@ -693,14 +691,9 @@ public void computeCrossValidation() { Scope.exit(); } } - - // Step 0: Algos can override this if additional preparation is required before starting CV. - protected void cv_init() { - init(false); - } // Step 1: Assign each row to a fold - protected FoldAssignment cv_AssignFold(int N) { + FoldAssignment cv_AssignFold(int N) { assert(N>=2); Vec fold = train().vec(_parms._fold_column); if (fold != null) { @@ -727,7 +720,7 @@ protected FoldAssignment cv_AssignFold(int N) { } // Step 2: Make 2*N binary weight vectors - protected Vec[] cv_makeWeights(final int N, FoldAssignment foldAssignment) { + Vec[] cv_makeWeights(final int N, FoldAssignment foldAssignment) { String origWeightsName = _parms._weights_column; Vec origWeight = origWeightsName != null ? 
train().vec(origWeightsName) : train().anyVec().makeCon(1.0); Frame folds_and_weights = new Frame(foldAssignment.getAdaptedFold(), origWeight); @@ -759,11 +752,12 @@ protected Vec[] cv_makeWeights(final int N, FoldAssignment foldAssignment) { } // Step 3: Build N train & validation frames; build N ModelBuilders; error check them all - protected ModelBuilder[] cv_makeFramesAndBuilders( int N, Vec[] weights) { + private ModelBuilder[] cv_makeFramesAndBuilders( int N, Vec[] weights ) { final long old_cs = _parms.checksum(); final String origDest = _result.toString(); - if (train().find(CV_WEIGHTS_COLUMN) != -1) throw new H2OIllegalArgumentException("Frame cannot contain a Vec called '"+CV_WEIGHTS_COLUMN+"'."); + final String weightName = "__internal_cv_weights__"; + if (train().find(weightName) != -1) throw new H2OIllegalArgumentException("Frame cannot contain a Vec called '" + weightName + "'."); Frame cv_fr = new Frame(train().names(),train().vecs()); if( _parms._weights_column!=null ) cv_fr.remove( _parms._weights_column ); // The CV frames will have their own private weight column @@ -776,32 +770,32 @@ protected ModelBuilder[] cv_makeFramesAndBuilders( int N, Vec[] weights // Training/Validation share the same data, but will have exclusive weights Frame cvTrain = new Frame(Key.make(identifier + "_train"), cv_fr.names(), cv_fr.vecs()); cvTrain.write_lock(_job); - cvTrain.add(CV_WEIGHTS_COLUMN, weights[2*i]); + cvTrain.add(weightName, weights[2*i]); cvTrain.update(_job); Frame cvValid = new Frame(Key.make(identifier + "_valid"), cv_fr.names(), cv_fr.vecs()); cvValid.write_lock(_job); - cvValid.add(CV_WEIGHTS_COLUMN, weights[2*i+1]); + cvValid.add(weightName, weights[2*i+1]); cvValid.update(_job); // Shallow clone - not everything is a private copy!!! 
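// (the clone still shares structures such as _parms and _result with the main builder; the lines below give each submodel its own result key, cloned parameters, train/valid frames and private weight column)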
ModelBuilder cv_mb = (ModelBuilder)this.clone(); - cv_mb._desc = "Cross-Validation model " + (i + 1) + " / " + N; + cv_mb.setTrain(cvTrain); cv_mb._result = Key.make(identifier); // Each submodel gets its own key cv_mb._parms = (P) _parms.clone(); // Fix up some parameters of the clone cv_mb._parms._is_cv_model = true; cv_mb._parms._cv_fold = i; - cv_mb._parms._weights_column = CV_WEIGHTS_COLUMN;// All submodels have a weight column, which the main model does not + cv_mb._parms._weights_column = weightName; // All submodels have a weight column, which the main model does not cv_mb._parms.setTrain(cvTrain._key); // Each submodel trains on its own private CV training frame cv_mb._parms._valid = cvValid._key; cv_mb._parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO; cv_mb._parms._nfolds = 0; // Each submodel is not itself folded cv_mb._parms._max_runtime_secs = cv_max_runtime_secs; - + cv_mb.clearValidationErrors(); // each submodel gets its own validation messages and error_count() cv_mb._input_parms = (P) _parms.clone(); + cv_mb._desc = "Cross-Validation model " + (i + 1) + " / " + N; + + // Error-check all the cross-validation Builders before launching any - cv_mb.clearValidationErrors(); // each submodel gets its own validation messages and error_count() - cv_mb.setTrain(null); cv_mb.setValid(null); // reset before validating cv_mb.init(false); if( cv_mb.error_count() > 0 ) { // Gather all submodel error messages Log.info("Marking frame for failed cv model for removal: " + cvTrain._key); @@ -818,7 +812,7 @@ protected ModelBuilder[] cv_makeFramesAndBuilders( int N, Vec[] weights if( error_count() > 0 ) { // Found an error in one or more submodels Futures fs = new Futures(); for (Frame cvf : cvFramesForFailedModels) { - cvf.vec(CV_WEIGHTS_COLUMN).remove(fs); // delete the Vec's chunks + cvf.vec(weightName).remove(fs); // delete the Vec's chunks DKV.remove(cvf._key, fs); // delete the Frame from the DKV, leaving its vecs Log.info("Removing frame for failed cv model: " + cvf._key); } @@ -831,7 +825,7 @@ protected ModelBuilder[] cv_makeFramesAndBuilders( int N, Vec[] weights } // Step 4: Run all the CV models and launch the main model - protected void cv_buildModels(int N, ModelBuilder[] cvModelBuilders ) { + public void cv_buildModels(int N, ModelBuilder[] cvModelBuilders ) { makeCVModelBuilder(cvModelBuilders, nModelsInParallel(N)).bulkBuildModels(); cv_computeAndSetOptimalParameters(cvModelBuilders); } @@ -840,35 +834,9 @@ protected CVModelBuilder makeCVModelBuilder(ModelBuilder[] modelBuilder return new CVModelBuilder(_job, modelBuilders, parallelization); } - protected ModelMetrics.MetricBuilder makeCVMetricBuilder(ModelBuilder cvModelBuilder, Futures fs) { - Frame cvValid = cvModelBuilder.valid(); - Frame preds = null; - try (Scope.Safe s = Scope.safe(cvValid)) { - ModelMetrics.MetricBuilder mb; - Frame adaptFr = new Frame(cvValid); - M cvModel = cvModelBuilder.dest().get(); - cvModel.adaptTestForTrain(adaptFr, true, !isSupervised()); - if (nclasses() == 2 /* need holdout predictions for gains/lift table */ - || _parms._keep_cross_validation_predictions - || (cvModel.isDistributionHuber() /*need to compute quantiles on abs error of holdout predictions*/)) { - String predName = cvModelBuilder.getPredictionKey(); - Model.PredictScoreResult result = cvModel.predictScoreImpl(cvValid, adaptFr, predName, _job, true, CFuncRef.NOP); - preds = result.getPredictions(); - Scope.untrack(preds); - result.makeModelMetrics(cvValid, adaptFr); - mb = result.getMetricBuilder(); -
DKV.put(cvModel); - } else { - mb = cvModel.scoreMetrics(adaptFr); - } - return mb; - } finally { - Scope.track(preds); - } - } // Step 5: Score the CV models - protected ModelMetrics.MetricBuilder[] cv_scoreCVModels(int N, Vec[] weights, ModelBuilder[] cvModelBuilders) { + public ModelMetrics.MetricBuilder[] cv_scoreCVModels(int N, Vec[] weights, ModelBuilder[] cvModelBuilders) { if (_job.stop_requested()) { Log.info("Skipping scoring of CV models"); throw new Job.JobCancelledException(_job); } @@ -884,17 +852,44 @@ protected ModelMetrics.MetricBuilder[] cv_scoreCVModels(int N, Vec[] weights, Mo Log.info("Skipping scoring for last "+(N-i)+" out of "+N+" CV models"); throw new Job.JobCancelledException(_job); } - mbs[i] = makeCVMetricBuilder(cvModelBuilders[i], fs); - + Frame cvValid = cvModelBuilders[i].valid(); + Frame preds = null; + try (Scope.Safe s = Scope.safe(cvValid)) { + Frame adaptFr = new Frame(cvValid); + if (makeCVMetrics(cvModelBuilders[i])) { + M cvModel = cvModelBuilders[i].dest().get(); + cvModel.adaptTestForTrain(adaptFr, true, !isSupervised()); + if (nclasses() == 2 /* need holdout predictions for gains/lift table */ + || _parms._keep_cross_validation_predictions + || (cvModel.isDistributionHuber() /*need to compute quantiles on abs error of holdout predictions*/)) { + String predName = cvModelBuilders[i].getPredictionKey(); + Model.PredictScoreResult result = cvModel.predictScoreImpl(cvValid, adaptFr, predName, _job, true, CFuncRef.from(_parms._custom_metric_func)); + preds = result.getPredictions(); + Scope.untrack(preds); + result.makeModelMetrics(cvValid, adaptFr); + mbs[i] = result.getMetricBuilder(); + DKV.put(cvModel); + } else { + mbs[i] = cvModel.scoreMetrics(adaptFr); + } + } + } finally { + Scope.track(preds); + } DKV.remove(cvModelBuilders[i]._parms._train,fs); DKV.remove(cvModelBuilders[i]._parms._valid,fs); weights[2*i ].remove(fs); weights[2*i+1].remove(fs); } + fs.blockForPending(); return mbs; } + protected boolean makeCVMetrics(ModelBuilder cvModelBuilder) { + return !cvModelBuilder.getName().equals("infogram"); + } + private boolean useParallelMainModelBuilding(int nFolds) { int parallelizationLevel = nModelsInParallel(nFolds); return parallelizationLevel > 1 && _parms._parallelize_cross_validation && cv_canBuildMainModelInParallel(); @@ -913,7 +908,7 @@ protected boolean cv_initStoppingParameters() { } // Step 6: build the main model - protected void buildMainModel(long max_runtime_millis) { + private void buildMainModel(long max_runtime_millis) { if (_job.stop_requested()) { Log.info("Skipping main model"); throw new Job.JobCancelledException(_job); } @@ -927,7 +922,7 @@ protected void buildMainModel(long max_runtime_millis) { } // Step 7: Combine cross-validation scores; compute main model x-val scores; compute gains/lifts - protected void cv_mainModelScores(int N, ModelMetrics.MetricBuilder[] mbs, ModelBuilder[] cvModelBuilders) { + public void cv_mainModelScores(int N, ModelMetrics.MetricBuilder[] mbs, ModelBuilder[] cvModelBuilders) { // never skipping CV main scores: we managed to reach the last step, and this should not be an expensive one, so let's offer this model M mainModel = _result.get(); @@ -1013,7 +1008,7 @@ protected void cv_mainModelScores(int N, ModelMetrics.MetricBuilder[] mbs, Model DKV.put(mainModel); } - protected void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){ + public void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){ for (int i = 1; i < mbs.length; ++i) { mbs[0].reduceForCV(mbs[i]); } @@ -1047,7
+1042,7 @@ protected void setMaxRuntimeSecsForMainModel() { * Also allow the cv models to be modified after all of them have been built. * For example, the model might need to be told to not do early stopping. CV models might have their lambda value modified, etc. */ - protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { } + public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { } /** @return Whether n-fold cross-validation is done */ public boolean nFoldCV() { @@ -1454,7 +1449,7 @@ public void init(boolean expensive) { } else { hide("_nfolds", "nfolds is ignored when a fold column is specified."); } - if (_parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO && _parms._fold_assignment != null) { + if (_parms._fold_assignment != null && _parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO) { error("_fold_assignment", "Fold assignment is not allowed in conjunction with a fold column."); } } @@ -1630,7 +1625,9 @@ else if (_parms._weights_column != null && _weights != null && !_weights.isBinar } if (expensive) { - Frame newtrain = encodeFrameCategoricals(_train); + boolean scopeTrack = !_parms._is_cv_model; + Frame newtrain = applyPreprocessors(_train, true, scopeTrack); + newtrain = encodeFrameCategoricals(newtrain, scopeTrack); // we could turn this into a preprocessor later if (newtrain != _train) { _origTrain = _train; _origNames = _train.names(); @@ -1641,7 +1638,8 @@ else if (_parms._weights_column != null && _weights != null && !_weights.isBinar _origTrain = null; } if (_valid != null) { - Frame newvalid = encodeFrameCategoricals(_valid /* for CV, need to score one more time in outer loop */); + Frame newvalid = applyPreprocessors(_valid, false, scopeTrack); + newvalid = encodeFrameCategoricals(newvalid, scopeTrack /* for CV, need to score one more time in outer loop */); setValid(newvalid); } boolean restructured = false; @@ -1747,7 +1745,7 @@ protected void checkCustomMetricForEarlyStopping() { public Frame init_adaptFrameToTrain(Frame fr, String frDesc, String field, boolean expensive) { Frame adapted = adaptFrameToTrain(fr, frDesc, field, expensive, false); if (expensive) - adapted = encodeFrameCategoricals(adapted); + adapted = encodeFrameCategoricals(adapted, true); return adapted; } @@ -1784,7 +1782,25 @@ private Frame adaptFrameToTrain(Frame fr, String frDesc, String field, boolean e return adapted; } - private Frame encodeFrameCategoricals(Frame fr) { + private Frame applyPreprocessors(Frame fr, boolean isTraining, boolean scopeTrack) { + if (_parms._preprocessors == null) return fr; + + for (Key key : _parms._preprocessors) { + DKV.prefetch(key); + } + Frame result = fr; + Frame encoded; + for (Key key : _parms._preprocessors) { + ModelPreprocessor preprocessor = key.get(); + encoded = isTraining ? preprocessor.processTrain(result, _parms) : preprocessor.processValid(result, _parms); + if (encoded != result) trackEncoded(encoded, scopeTrack); + result = encoded; + } + if (!scopeTrack) Scope.untrack(result); // otherwise the encoded frame is fully removed on CV model completion, raising an exception when computing CV scores.
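+    // (for CV models trackEncoded stores the key in _workspace instead of the Scope; compute2 then keeps those keys on CV completion so CV scoring still works, and the main model's cleanUp() removes them at the end)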
+ return result; + } + + private Frame encodeFrameCategoricals(Frame fr, boolean scopeTrack) { Frame encoded = FrameUtils.categoricalEncoder( fr, _parms.getNonPredictors(), @@ -1792,49 +1808,17 @@ private Frame encodeFrameCategoricals(Frame fr) { getToEigenVec(), _parms._max_categorical_levels ); - if (encoded != fr) track(encoded); + if (encoded != fr) trackEncoded(encoded, scopeTrack); return encoded; } - - protected void track(Frame... frames) { - for (Frame fr : frames) track(fr, _parms._is_cv_model); - } - - protected void track(Frame fr, boolean keepUntilCompletion) { - if (fr == null || fr._key == null) return; - if (keepUntilCompletion) { - keepUntilCompletion(fr._key); - Scope.untrack(fr); - } - else { + + private void trackEncoded(Frame fr, boolean scopeTrack) { + assert fr._key != null; + if (scopeTrack) Scope.track(fr); - } - } - - protected void track(Vec... vecs) { - for (Vec vec : vecs) track(vec, _parms._is_cv_model); - } - - protected void track(Vec vec, boolean keepUntilCompletion) { - if (vec == null || vec._key == null) return; - if (keepUntilCompletion) { - keepUntilCompletion(vec._key); - Scope.untrack(vec._key); - } else { - Scope.track(vec); - } - } - - /** - * Track keys to be removed only once the model is fully trained. - * Especially useful for keys created during CV model training that may be needed after the CV model is trained (e.g. CV scoring). - * @param key - */ - protected void keepUntilCompletion(Key key) { - assert key != null; - _workspace.getToDelete(true).put(key, Arrays.toString(Thread.currentThread().getStackTrace())); + else + _workspace.getToDelete(true).put(fr._key, Arrays.toString(Thread.currentThread().getStackTrace())); } - /** * Rebalance a frame for load balancing @@ -2162,7 +2146,7 @@ public String getName() { return getClass().getSimpleName().toLowerCase(); } - protected void cleanUp() { + private void cleanUp() { _workspace.cleanUp(); } diff --git a/h2o-core/src/main/java/hex/ModelBuilderCallbacks.java b/h2o-core/src/main/java/hex/ModelBuilderCallbacks.java deleted file mode 100644 index 9bc641b77d82..000000000000 --- a/h2o-core/src/main/java/hex/ModelBuilderCallbacks.java +++ /dev/null @@ -1,54 +0,0 @@ -package hex; - -import water.Iced; -import water.Key; -import water.util.ArrayUtils; - -public abstract class ModelBuilderCallbacks extends Iced { - - private static class HandledException extends Exception { - @Override - public String toString() { - return ""; - } - - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } - } - - private static final HandledException HANDLED = new HandledException(); - - public void wrapCompute(ModelBuilder builder, Runnable compute) { compute.run(); } - - /** - * Callback for successfully finished model builds - * - * @param modelKey key of built Model. - */ - public void onModelSuccess(Key modelKey) {} - - /** - * Callback for failed model builds - * - * @param modelKey Key of the model that was attempted at being built. - * @param cause An instance of {@link Throwable} - cause of failure - * @param parameters An instance of Model.Parameters used in the attempt to build the model - */ - public void onModelFailure(Key modelKey, Throwable cause, Model.Parameters parameters) {} - - /** - * subclasses may want to call this before processing exceptions on model failure. - * @param cause - * @return true if the exception is considered as having already been handled and can be ignored. 
- *    false otherwise: in this case the exception is automatically marked as handled for future checks, - *    but current code is supposed to handle it immediately. - */ - protected boolean checkExceptionHandled(Throwable cause) { - if (ArrayUtils.contains(cause.getSuppressed(), HANDLED)) return true; - cause.addSuppressed(HANDLED); - return false; - } - -} diff --git a/h2o-core/src/main/java/hex/ModelBuilderListener.java b/h2o-core/src/main/java/hex/ModelBuilderListener.java new file mode 100644 index 000000000000..96cfc1a47a96 --- /dev/null +++ b/h2o-core/src/main/java/hex/ModelBuilderListener.java @@ -0,0 +1,22 @@ +package hex; + +import water.Iced; + +public abstract class ModelBuilderListener extends Iced { + /** + * Callback for successfully finished model builds + * + * @param model Model built + */ + abstract void onModelSuccess(Model model); + + /** + * Callback for failed model builds + * + * @param cause An instance of {@link Throwable} - cause of failure + * @param parameters An instance of Model.Parameters used in the attempt to build the model + */ + abstract void onModelFailure(Throwable cause, Model.Parameters parameters); + + +} diff --git a/h2o-core/src/main/java/hex/ModelMetrics.java b/h2o-core/src/main/java/hex/ModelMetrics.java index 70001292315c..0e0e3624da9f 100755 --- a/h2o-core/src/main/java/hex/ModelMetrics.java +++ b/h2o-core/src/main/java/hex/ModelMetrics.java @@ -475,8 +475,8 @@ public void postGlobal(CustomMetric customMetric) { * @param m Model * @param f Scored Frame * @param adaptedFrame Adapted Frame - * @param preds Predictions of m on f (optional) - * @return Filled Model Metrics object + * @param preds Predictions of m on f (optional) + * @return Filled Model Metrics object */ public abstract ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds); diff --git a/h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java b/h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java index 5aac8638b1d5..1bf6d854f94b 100644 --- a/h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java +++ b/h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java @@ -20,8 +20,6 @@ public interface ModelParametersBuilderFactory { * @return this parameters builder */ ModelParametersBuilder get(MP initialParams); - - /** * Returns mapping from input parameter specification to @@ -40,10 +38,8 @@ public interface ModelParametersBuilderFactory { * * @param type of produced model parameters object */ - interface ModelParametersBuilder { - - boolean isAssignable(String name); - + interface ModelParametersBuilder { + ModelParametersBuilder set(String name, Object value); MP build(); diff --git a/h2o-core/src/main/java/hex/ModelParametersDelegateBuilderFactory.java b/h2o-core/src/main/java/hex/ModelParametersDelegateBuilderFactory.java deleted file mode 100644 index 5f166c10c582..000000000000 --- a/h2o-core/src/main/java/hex/ModelParametersDelegateBuilderFactory.java +++ /dev/null @@ -1,60 +0,0 @@ -package hex; - -import water.util.PojoUtils.FieldNaming; - -/** - * This {@link ModelParametersBuilderFactory} delegates the hyper-parameters building logic - * to the initial {@link Model.Parameters} instance itself, using the {@link Parameterizable} methods. - * This allows better control for complex parameters objects that may this way accept nested hyper-parameters.
- */ -public class ModelParametersDelegateBuilderFactory implements ModelParametersBuilderFactory { - - protected final FieldNaming fieldNaming; - - public ModelParametersDelegateBuilderFactory() { - this(FieldNaming.CONSISTENT); - } - - public ModelParametersDelegateBuilderFactory(FieldNaming fieldNaming) { - this.fieldNaming = fieldNaming; - } - - @Override - public ModelParametersBuilder get(MP initialParams) { - return new DelegateParamsBuilder<>(initialParams, fieldNaming); - } - - @Override - public FieldNaming getFieldNamingStrategy() { - return fieldNaming; - } - - public static class DelegateParamsBuilder - implements ModelParametersBuilder { - - protected final MP params; - protected final FieldNaming fieldNaming; - - - protected DelegateParamsBuilder(MP params, FieldNaming fieldNaming) { - this.params = params; - this.fieldNaming = fieldNaming; - } - - @Override - public boolean isAssignable(String name) { - return this.params.isParameterAssignable(fieldNaming.toDest(name)); - } - - @Override - public ModelParametersBuilder set(String name, Object value) { - this.params.setParameter(fieldNaming.toDest(name), value); - return this; - } - - @Override - public MP build() { - return params; - } - } -} diff --git a/h2o-core/src/main/java/hex/ModelParametersGenericBuilderFactory.java b/h2o-core/src/main/java/hex/ModelParametersGenericBuilderFactory.java deleted file mode 100644 index 28f27ed574f9..000000000000 --- a/h2o-core/src/main/java/hex/ModelParametersGenericBuilderFactory.java +++ /dev/null @@ -1,103 +0,0 @@ -package hex; - -import water.util.Log; -import water.util.PojoUtils; -import water.util.PojoUtils.FieldNaming; - -import java.util.HashMap; -import java.util.Map; - -/** - * A {@link ModelParametersBuilderFactory} that can dynamically generate parameters for any kind of model algorithm, - * as soon as one of the hyper-parameter is named {@value #ALGO_PARAM}, - * in which case it is recommended to obtain a new builder using a {@link CommonModelParameters} instance, - * that will be used to provide the standard params for all type of algos. - * - * Otherwise, if there's no {@value #ALGO_PARAM} hyper-parameter, this factory behaves similarly to {@link ModelParametersBuilderFactory}. - * - * TODO: future improvement. When griding over multiple algos, we may want to apply different values for an hyper-parameter with the same name on algo-A and algo-B. - * In this case, we should be able to handle hyper-parameters differently based on naming convention. For example using `$` to prefix the param with the algo: - * - GBM$_max_depth = [3, 5, 7, 9, 11] - * - XGBoost$_max_depth = [5, 10, 15] - * as soon as the algo is defined, then the params are assigned this way: - * - if `_my_param` is provided, check if `Algo$_my_param` is also provided: if so then apply only the latter, otherwise apply the former. - */ -public class ModelParametersGenericBuilderFactory extends ModelParametersDelegateBuilderFactory { - - public static final String ALGO_PARAM = "algo"; - - /** - * A generic class containing only common {@link Model.Parameters} that can be used as initial common parameters - * when searching over multiple algos. 
- */ - public static class CommonModelParameters extends Model.Parameters { - @Override - public String algoName() { - return null; - } - - @Override - public String fullName() { - return null; - } - - @Override - public String javaName() { - return null; - } - - @Override - public long progressUnits() { - return 0; - } - } - - public ModelParametersGenericBuilderFactory() { - super(); - } - - public ModelParametersGenericBuilderFactory(FieldNaming fieldNaming) { - super(fieldNaming); - } - - @Override - public ModelParametersBuilder get(Model.Parameters initialParams) { - return new GenericParamsBuilder(initialParams, fieldNaming); - } - - public static class GenericParamsBuilder extends DelegateParamsBuilder { - - private final Map hyperParams = new HashMap<>(); - - public GenericParamsBuilder(Model.Parameters params, FieldNaming fieldNaming) { - super(params, fieldNaming); - } - - @Override - public ModelParametersBuilder set(String name, Object value) { - hyperParams.put(name, value); - return this; - } - - @Override - public Model.Parameters build() { - Model.Parameters result = params; - String algo = null; - if (hyperParams.containsKey(ALGO_PARAM)) { - algo = (String) hyperParams.get(ALGO_PARAM); - result = ModelBuilder.makeParameters(algo); - //add values from init params - PojoUtils.copyProperties(result, params, FieldNaming.CONSISTENT); - } - for (Map.Entry e : hyperParams.entrySet()) { - if (ALGO_PARAM.equals(e.getKey())) continue; - if (algo == null || result.hasParameter(fieldNaming.toDest(e.getKey()))) { // no check for `result.hasParameter` in case of strict algo, so that we can fail on invalid param - result.setParameter(fieldNaming.toDest(e.getKey()), e.getValue()); - } else { // algo hyper-param was provided and this hyper-param is incompatible with it - Log.debug("Ignoring hyper-parameter `"+e.getKey()+"` unsupported by `"+algo+"`."); - } - } - return result; - } - } -} diff --git a/h2o-core/src/main/java/hex/ParallelModelBuilder.java b/h2o-core/src/main/java/hex/ParallelModelBuilder.java index 34e365621e5a..98bb175a165e 100644 --- a/h2o-core/src/main/java/hex/ParallelModelBuilder.java +++ b/h2o-core/src/main/java/hex/ParallelModelBuilder.java @@ -3,10 +3,10 @@ import jsr166y.ForkJoinTask; import org.apache.log4j.Logger; import water.Iced; -import water.Key; import water.util.IcedAtomicInt; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; /** * Dispatcher for parallel model building. Starts building models every time the run method is invoked. 
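 * <p>
 * A hypothetical callback sketch (dispatcher construction and exact modifiers are omitted; illustrative only, not part of this change):
 * <pre>{@code
 *   ParallelModelBuilderCallback callback = new ParallelModelBuilderCallback() {
 *     @Override public void onBuildSuccess(Model model, ParallelModelBuilder pmb) { Log.info("built " + model._key); }
 *     @Override public void onBuildFailure(ModelBuildFailure failure, ParallelModelBuilder pmb) { Log.warn("a model failed to build"); }
 *   };
 * }</pre>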
@@ -29,12 +29,12 @@ public static abstract class ParallelModelBuilderCallback modelBuilders) { if (LOG.isTraceEnabled()) LOG.trace("run with " + modelBuilders.size() + " models"); for (final ModelBuilder modelBuilder : modelBuilders) { _modelInProgressCounter.incrementAndGet(); - modelBuilder.setCallbacks(_modelBuildersCallbacks); - modelBuilder.trainModel(); + modelBuilder.trainModel(_parallelModelBuiltListener); } } - private class EachBuilderCallbacks extends ModelBuilderCallbacks { + private class ParallelModelBuiltListener extends ModelBuilderListener { @Override - public void onModelSuccess(Key modelKey) { - Model model = modelKey.get(); - if (model._parms._is_cv_model) return; // not interested in CV models here + public void onModelSuccess(Model model) { try { _callback.onBuildSuccess(model, ParallelModelBuilder.this); } finally { @@ -67,9 +64,7 @@ public void onModelSuccess(Key modelKey) { } @Override - public void onModelFailure(Key modelKey, Throwable cause, Model.Parameters parameters) { - if (checkExceptionHandled(cause)) return; - if (parameters._is_cv_model) return; // not interested in CV models here + public void onModelFailure(Throwable cause, Model.Parameters parameters) { try { final ModelBuildFailure modelBuildFailure = new ModelBuildFailure(cause, parameters); _callback.onBuildFailure(modelBuildFailure, ParallelModelBuilder.this); diff --git a/h2o-core/src/main/java/hex/Parameterizable.java b/h2o-core/src/main/java/hex/Parameterizable.java deleted file mode 100644 index 6ed6f91d3938..000000000000 --- a/h2o-core/src/main/java/hex/Parameterizable.java +++ /dev/null @@ -1,61 +0,0 @@ -package hex; - -/** - * Must be implemented by classes that can be configured dynamically during HyperParameter-Optimization. - */ -public interface Parameterizable { - - /** - * @param name hyperparameter name. - * @return true if this hyperparameter is generally supported. - */ - boolean hasParameter(String name); - - /** - * - * @param name hyperparameter name. - * @return the current value for the given hyperparameter. - */ - Object getParameter(String name); - - /** - * - * @param name hyperparameter name. - * @param value the new value to assign for the given hyperparameter. - */ - void setParameter(String name, Object value); - - /** - * - * @param name hyperparameter name. - * @return true iff the hyperparameter is currently set to its default value. - */ - boolean isParameterSetToDefault(String name); - - /** - * - * @param name hyperparameter name. - * @return the default value for the given hyperparameter. - */ - Object getParameterDefaultValue(String name); - - /** - * - * @param name hyperparameter name. - * @return true iff the given hyperparameter is allowed to be modified/reassigned on this instance. - */ - boolean isParameterAssignable(String name); - - /** - * To be implemented by subclasses to provide a proper copy that can be then parametrized without risk of modifying the original. - * This is necessary for hyperparameter search to work properly: - * for most subclasses it can default to a basic clone, - * but some others (e.g. compound model parameters like {@link hex.pipeline.PipelineModel.PipelineParameters}) - * may also need to create fresh keys pointing to fresh objects: a simple {@code clone} or {@code deepClone} would not be enough for those. - * - * Also, simply overriding {@code clone} would have unsuitable side effects as simple cloning is necessary and used during serialization. 
- * - * @return a copy of the current instance, safe to use in hyperparameter search. - */ - SELF freshCopy(); -} diff --git a/h2o-core/src/main/java/hex/SubModelBuilder.java b/h2o-core/src/main/java/hex/SubModelBuilder.java index 7b1c7890c2a7..c5681c69510f 100644 --- a/h2o-core/src/main/java/hex/SubModelBuilder.java +++ b/h2o-core/src/main/java/hex/SubModelBuilder.java @@ -4,6 +4,7 @@ import water.H2O; import water.Job; import water.ParallelizationTask; +import water.util.Log; /** * Execute build of a collection of sub-models (CV models, main model) in parallel diff --git a/h2o-core/src/main/java/hex/faulttolerance/Recovery.java b/h2o-core/src/main/java/hex/faulttolerance/Recovery.java index 46c7fcfc21a6..b1fc8af73090 100644 --- a/h2o-core/src/main/java/hex/faulttolerance/Recovery.java +++ b/h2o-core/src/main/java/hex/faulttolerance/Recovery.java @@ -244,7 +244,7 @@ void autoRecover() { Grid grid = Grid.importBinary(recoveryFile(resultKey), true); GridSearch.resumeGridSearch( jobKey, grid, - new GridSearchHandler.SchemaModelParametersBuilderFactory(), + new GridSearchHandler.DefaultModelParametersBuilderFactory(), (Recovery) this ); } else { diff --git a/h2o-core/src/main/java/hex/grid/Grid.java b/h2o-core/src/main/java/hex/grid/Grid.java index 4edaf5b233d2..80b26ddffa18 100644 --- a/h2o-core/src/main/java/hex/grid/Grid.java +++ b/h2o-core/src/main/java/hex/grid/Grid.java @@ -16,6 +16,9 @@ import java.lang.reflect.Array; import java.net.URI; import java.util.*; +import java.util.stream.Collectors; + +import static hex.grid.GridSearch.IGNORED_FIELDS_PARAM_HASH; /** * A Grid of Models representing result of hyper-parameter space exploration. @@ -102,20 +105,52 @@ private SearchFailure(final Class paramsClass) { */ private void appendFailedModelParameters(MP params, String[] rawParams, String failureDetails, String stackTrace) { assert rawParams != null : "API has to always pass rawParams"; - _failed_params = ArrayUtils.append(_failed_params, params); - _failure_details = ArrayUtils.append(_failure_details, failureDetails); - _failed_raw_params = ArrayUtils.append(_failed_raw_params, new String[][]{rawParams}); - _failure_stack_traces = ArrayUtils.append(_failure_stack_traces, stackTrace); + // Append parameter + MP[] a = _failed_params; + MP[] na = Arrays.copyOf(a, a.length + 1); + na[a.length] = params; + _failed_params = na; + // Append message + String[] m = _failure_details; + String[] nm = Arrays.copyOf(m, m.length + 1); + nm[m.length] = failureDetails; + _failure_details = nm; + // Append raw params + String[][] rp = _failed_raw_params; + String[][] nrp = Arrays.copyOf(rp, rp.length + 1); + nrp[rp.length] = rawParams; + _failed_raw_params = nrp; + // Append stack trace + String[] st = _failure_stack_traces; + String[] nst = Arrays.copyOf(st, st.length + 1); + nst[st.length] = stackTrace; + _failure_stack_traces = nst; } - public void addWarning(String message) { - Log.warn(message); - _warning_details = ArrayUtils.append(_warning_details, message); + private void appendWarningMessage(String[] hyper_parameter, String checkField) { + if (hyper_parameter != null && Arrays.asList(hyper_parameter).contains(checkField)) { + String warningMessage = null; + if ("alpha".equals(checkField)) { + warningMessage = "Adding alpha array to hyperparameter runs slower with gridsearch. " + + "This is due to the fact that the algo has to run initialization for every alpha value. 
" + + "Setting the alpha array as a model parameter will skip the initialization and run faster overall."; + } + if (warningMessage != null) { + Log.warn(warningMessage); + // Append message + String[] m = _warning_details; + String[] nm = Arrays.copyOf(m, m.length+1); + nm[m.length] = warningMessage; + _warning_details = nm; + } + } } public void appendFailedModelParameters(final MP[] params, final String[][] rawParams, final String[] failureDetails, final String[] stackTraces) { + assert rawParams != null : "API has to always pass rawParams"; + _failed_params = ArrayUtils.append(_failed_params, params); _failed_raw_params = ArrayUtils.append(_failed_raw_params, rawParams); _failure_details = ArrayUtils.append(_failure_details, failureDetails); @@ -273,7 +308,7 @@ public Model getModel(MP params) { } public Key getModelKey(MP params) { - long checksum = params.checksum(); + long checksum = params.checksum(IGNORED_FIELDS_PARAM_HASH); return getModelKey(checksum); } @@ -282,8 +317,8 @@ Key getModelKey(long paramsChecksum) { return mKey; } - /* FIXME: should pass model parameters instead of checksum, but model - * parameters are not immutable and model builder modifies them! */ + /* FIXME: should pass model parameters instead of checksum, but model + * parameters are not imutable and model builder modifies them! */ /* package */ synchronized Key putModel(long checksum, Key modelKey) { return _models.put(IcedLong.valueOf(checksum), modelKey); @@ -311,7 +346,7 @@ private void appendFailedModelParameters(final Key modelKey, final MP par _failures.put(searchedKey, searchFailure); } searchFailure.appendFailedModelParameters(params, rawParams, failureDetails, stackTrace); - if (params != null) params.addSearchWarnings(searchFailure, this); + searchFailure.appendWarningMessage(_hyper_names, "alpha"); } static boolean isJobCanceled(final Throwable t) { @@ -404,12 +439,12 @@ public SearchFailure getFailures() { final Collection values = _failures.values(); // Original failures should be left intact. Also avoid mutability from outer space. final SearchFailure searchFailure = new SearchFailure(_params != null ? 
_params.getClass() : null); - if (_params != null) _params.addSearchWarnings(searchFailure, this); for (SearchFailure f : values) { searchFailure.appendFailedModelParameters(f._failed_params, f._failed_raw_params, f._failure_details, f._failure_stack_traces); } + searchFailure.appendWarningMessage(_hyper_names, "alpha"); return searchFailure; } @@ -433,7 +468,7 @@ protected void clearNonRelatedFailures(){ public Object[] getHyperValues(MP parms) { Object[] result = new Object[_hyper_names.length]; for (int i = 0; i < _hyper_names.length; i++) { - result[i] = parms.getParameter(_field_naming_strategy.toDest(_hyper_names[i])); + result[i] = PojoUtils.getFieldValue(parms, _hyper_names[i], _field_naming_strategy); } return result; } @@ -525,7 +560,7 @@ public TwoDimTable createSummaryTable(Key[] model_ids, String sort_by, bo Model.Parameters parms = m._parms; int j; for (j = 0; j < _hyper_names.length; ++j) { - Object paramValue = parms.getParameter(_field_naming_strategy.toDest(_hyper_names[j])); + Object paramValue = PojoUtils.getFieldValue(parms, _hyper_names[j], _field_naming_strategy); if (paramValue.getClass().isArray()) { // E.g., GLM alpha/lambda parameters can be arrays with one value if (paramValue instanceof float[] && ((float[])paramValue).length == 1) paramValue = ((float[]) paramValue)[0]; diff --git a/h2o-core/src/main/java/hex/grid/GridSearch.java b/h2o-core/src/main/java/hex/grid/GridSearch.java index 850d23e56a50..9d9c15ac12a0 100644 --- a/h2o-core/src/main/java/hex/grid/GridSearch.java +++ b/h2o-core/src/main/java/hex/grid/GridSearch.java @@ -5,12 +5,12 @@ import hex.grid.HyperSpaceWalker.BaseWalker; import jsr166y.CountedCompleter; import water.*; -import water.KeyGen.PatternKeyGen; import water.exceptions.H2OConcurrentModificationException; import water.exceptions.H2OGridException; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.util.Log; +import water.util.PojoUtils; import java.util.*; import java.util.concurrent.locks.Lock; @@ -101,10 +101,6 @@ private Builder(Key destKey, HyperSpaceWalker hyperSpaceWalker) { _gridSearch = new GridSearch<>(destKey, hyperSpaceWalker); } - public Key dest() { - return _gridSearch._result; - } - public Builder withParallelism(int parallelism) { _gridSearch._parallelism = parallelism; return this; @@ -140,9 +136,6 @@ public Job start() { private int _maxConsecutiveFailures = Integer.MAX_VALUE; // for now, disabled by default private final Key _result; - - private final KeyGen _modelKeyGen = new PatternKeyGen("{0}_model_{1}"); - /** Walks hyper space and for each point produces model parameters. It is * used only locally to fire new model builders. 
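   * <p>
   * For intuition, a hypothetical walker construction (names and hyper-parameter values are illustrative; parms and searchCriteria are assumed to exist):
   * <pre>{@code
   *   Map<String, Object[]> hyperParams = new HashMap<>();
   *   hyperParams.put("_max_depth", new Object[]{3, 5, 7});
   *   HyperSpaceWalker.RandomDiscreteValueWalker walker = new HyperSpaceWalker.RandomDiscreteValueWalker(
   *       parms, hyperParams, new SimpleParametersBuilderFactory(), searchCriteria);
   * }</pre>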
*/ private final transient HyperSpaceWalker _hyperSpaceWalker; @@ -279,7 +272,7 @@ public void onBuildSuccess(final Model finishedModel, final ParallelModelBuilder try { parallelSearchGridLock.lock(); constructScoringInfo(finishedModel); - onModel(grid, finishedModel._input_parms.checksum(), finishedModel._key); + onModel(grid, finishedModel._input_parms.checksum(IGNORED_FIELDS_PARAM_HASH), finishedModel._key); _job.update(1); grid.update(_job); @@ -351,7 +344,7 @@ private MP getNextModelParams(final HyperSpaceWalker.HyperSpaceIterator hype while (params == null) { if (hyperSpaceIterator.hasNext()) { params = hyperSpaceIterator.nextModelParameters(); - final Key modelKey = grid.getModelKey(params.checksum()); + final Key modelKey = grid.getModelKey(params.checksum(IGNORED_FIELDS_PARAM_HASH)); if (modelKey != null) { params = null; } @@ -378,7 +371,7 @@ private void parallelGridSearch(final Grid grid) { while (startModels.size() < _parallelism && iterator.hasNext()) { final MP nextModelParameters = iterator.nextModelParameters(); - final long checksum = nextModelParameters.checksum(); + final long checksum = nextModelParameters.checksum(IGNORED_FIELDS_PARAM_HASH); if (grid.getModelKey(checksum) == null) { startModels.add(ModelBuilder.make(nextModelParameters)); } @@ -405,6 +398,7 @@ private void parallelGridSearch(final Grid grid) { * @param grid grid object to save results; grid already locked */ private void gridSearch(Grid grid) { + final String protoModelKey = grid._key + "_model_"; // Get iterator to traverse hyper space HyperSpaceWalker.HyperSpaceIterator it = _hyperSpaceWalker.iterator(); // Number of traversed model parameters @@ -427,7 +421,7 @@ private void gridSearch(Grid grid) { scoringInfo.time_stamp_ms = System.currentTimeMillis(); //// build the model! - model = buildModel(params, grid, ++counter); + model = buildModel(params, grid, ++counter, protoModelKey); if (model != null) { model.fillScoringInfo(scoringInfo); grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, grid.getScoringInfos())); @@ -439,7 +433,7 @@ private void gridSearch(Grid grid) { if (Job.isCancelledException(e)) { assert model == null; - final long checksum = params.checksum(); + final long checksum = params.checksum(IGNORED_FIELDS_PARAM_HASH); final Key[] modelKeys = findModelsByChecksum(checksum); if (modelKeys.length == 1) { Keyed.removeQuietly(modelKeys[0]); @@ -532,6 +526,11 @@ private void attemptGridSave(final Grid grid) { grid.exportBinary(checkpointsDir, false); } + static final Set IGNORED_FIELDS_PARAM_HASH = new HashSet<>(Arrays.asList( + "_export_checkpoints_dir", + "_max_runtime_secs" // We are modifying ourselves in Grid Search code + )); + /** * Build a model based on specified parameters and save it to resulting Grid object. * @@ -547,15 +546,17 @@ private void attemptGridSave(final Grid grid) { * @param params parameters for a new model * @param grid grid object holding created models * @param paramsIdx index of generated model parameter + * @param protoModelKey prototype of model key * @return a new model if it does not exist */ - private Model buildModel(final MP params, Grid grid, int paramsIdx) { + private Model buildModel(final MP params, Grid grid, int paramsIdx, String protoModelKey) { // Make sure that the model is not yet built (can be the case with duplicated hyper parameters). // We first look in the grid _models cache, then we look in the DKV. // FIXME: get checksum here since model builder will modify instance of params!!!
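// (params.checksum(IGNORED_FIELDS_PARAM_HASH) skips _export_checkpoints_dir and _max_runtime_secs, which the search itself mutates, so a resumed grid still recognizes already-built models)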
// Grid search might be continued over the exact same hyperspace, but with autoexporting disabled. - final long checksum = params.checksum(); + // To prevent already-built models from being rebuilt in that case, fields the search itself modifies are excluded from the checksum. + final long checksum = params.checksum(IGNORED_FIELDS_PARAM_HASH); Key key = grid.getModelKey(checksum); if (key != null) { if (DKV.get(key) == null) { @@ -577,11 +578,11 @@ private Model buildModel(final MP params, Grid grid, int paramsIdx) { // Modify model key to have nice version with counter // Note: Cannot create it before checking the cache since checksum would differ for each model - Key result = _modelKeyGen.make(grid._key, paramsIdx); + Key result = Key.make(protoModelKey + paramsIdx); // Build a new model assert grid.getModel(params) == null; Model m = ModelBuilder.trainModelNested(_job, result, params, null); - assert checksum == m._input_parms.checksum() : + assert checksum == m._input_parms.checksum(IGNORED_FIELDS_PARAM_HASH) : "Model checksum different from original params"; onModel(grid, checksum, result); return m; @@ -598,7 +599,7 @@ public boolean filter(KeySnapshot.KeyInfo k) { if ((m == null) || (m._parms == null)) return false; try { - return m._parms.checksum() == checksum; + return m._parms.checksum(IGNORED_FIELDS_PARAM_HASH) == checksum; } catch (H2OConcurrentModificationException e) { // We are inspecting model parameters that don't belong to us - they might be modified (or deleted) while // checksum is being calculated: we skip them (see PUBDEV-5286) @@ -633,7 +634,7 @@ protected static Key gridKeyName(String modelName, Frame fr) { if (fr == null || fr._key == null) { throw new IllegalArgumentException("The frame being grid-searched over must have a Key"); } - return Key.make("Grid_" + modelName + "_" + fr._key + H2O.calcNextUniqueModelId("")); + return Key.make("Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId("")); } /** @@ -841,6 +842,58 @@ public static Job resumeGridSearch( grid.getParallelism() ); } + + /** + * The factory produces a parameters builder which uses reflection to set up field values. + * + * @param type of model parameters object + */ + public static class SimpleParametersBuilderFactory + implements ModelParametersBuilderFactory { + + @Override + public ModelParametersBuilder get(MP initialParams) { + return new SimpleParamsBuilder<>(initialParams); + } + + @Override + public PojoUtils.FieldNaming getFieldNamingStrategy() { + return PojoUtils.FieldNaming.CONSISTENT; + } + + /** + * The builder modifies initial model parameters directly by reflection. + * + * Usage:
+     * <pre>{@code
+     *   GBMModel.GBMParameters params =
+     *     new SimpleParamsBuilder<GBMModel.GBMParameters>(initialParams)
+     *      .set("_ntrees", 30).set("_learn_rate", 0.01).build()
+     * }</pre>
+ * + * @param type of model parameters object + */ + public static class SimpleParamsBuilder + implements ModelParametersBuilder { + + final private MP params; + + public SimpleParamsBuilder(MP initialParams) { + params = initialParams; + } + + @Override + public ModelParametersBuilder set(String name, Object value) { + PojoUtils.setField(params, name, value, PojoUtils.FieldNaming.CONSISTENT); + return this; + } + + @Override + public MP build() { + return params; + } + } + } /** * Constant for adaptive parallelism level - number of models built in parallel is decided by H2O. diff --git a/h2o-core/src/main/java/hex/grid/HyperSpaceWalker.java b/h2o-core/src/main/java/hex/grid/HyperSpaceWalker.java index 71aed48fa26f..85f71500d1bb 100644 --- a/h2o-core/src/main/java/hex/grid/HyperSpaceWalker.java +++ b/h2o-core/src/main/java/hex/grid/HyperSpaceWalker.java @@ -1,21 +1,23 @@ package hex.grid; -import hex.*; -import hex.ModelParametersBuilderFactory.ModelParametersBuilder; +import hex.Model; +import hex.ModelParametersBuilderFactory; +import hex.ScoreKeeper; +import hex.ScoringInfo; import hex.grid.HyperSpaceSearchCriteria.CartesianSearchCriteria; import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria; import hex.grid.HyperSpaceSearchCriteria.Strategy; import water.exceptions.H2OIllegalArgumentException; import water.util.ArrayUtils; -import water.util.RandomUtils; +import water.util.PojoUtils; import java.util.*; import java.util.function.Consumer; import java.util.stream.Stream; -public interface HyperSpaceWalker { +import static hex.grid.HyperSpaceWalker.BaseWalker.SUBSPACES; - String SUBSPACES = "subspaces"; +public interface HyperSpaceWalker { interface HyperSpaceIterator { /** @@ -146,6 +148,8 @@ abstract class BaseWalker _hyperParams; - long model_number = 0; // denote model number + long model_number = 0l; // denote model number /** * Cached names of used hyper parameters. 
*/ @@ -226,24 +232,25 @@ public BaseWalker(MP params, ModelParametersBuilderFactory paramsBuilderFactory, C search_criteria) { _params = params; - _hyperParams = processHyperParams(hyperParams); + _hyperParams = hyperParams; _paramsBuilderFactory = paramsBuilderFactory; - _hyperParamNames = _hyperParams.keySet().toArray(new String[0]); + _hyperParamNames = hyperParams.keySet().toArray(new String[0]); _hyperParamSubspaces = extractSubspaces(); _hyperParamNamesSubspace = extractSubspaceNames(); _hyperParams.remove(SUBSPACES); _search_criteria = search_criteria; _maxHyperSpaceSize = computeMaxSizeOfHyperSpace(); - - validateHyperParams(_hyperParams, false); - Arrays.stream(_hyperParamSubspaces).forEach(subspace -> validateHyperParams(subspace, true)); + + // Sanity check the hyperParams map, and check it against the params object + try { + _defaultParams = (MP) params.getClass().newInstance(); + } catch (Exception e) { + throw new H2OIllegalArgumentException("Failed to instantiate a new Model.Parameters object to get the default values."); + } + validateParams(_hyperParams, false); + Arrays.stream(_hyperParamSubspaces).forEach(subspace -> validateParams(subspace, true)); } // BaseWalker() - /** can be used by walkers to extract some specific hyper params */ - protected Map processHyperParams(Map hyperParams) { - return hyperParams; - } - @Override public String[] getHyperParamNames() { return _hyperParamNames; @@ -285,11 +292,15 @@ private String[] extractSubspaceNames() { } protected MP getModelParams(MP params, Object[] hyperParams, String[] hyperParamNames) { - ModelParametersBuilder - paramsBuilder = _paramsBuilderFactory.get((MP) params.freshCopy()); + ModelParametersBuilderFactory.ModelParametersBuilder + paramsBuilder = _paramsBuilderFactory.get(params); for (int i = 0; i < hyperParamNames.length; i++) { String paramName = hyperParamNames[i]; Object paramValue = hyperParams[i]; + if (paramName.equals("valid")) { // change paramValue to key for validation_frame + paramName = "validation_frame"; // @#$, paramsSchema is still using validation_frame and training_frame + } + paramsBuilder.set(paramName, paramValue); } return paramsBuilder.build(); @@ -341,25 +352,53 @@ protected int integerHash(Map hyperParams, String[] hyperParam return Arrays.deepHashCode(hashMe); } - private void validateHyperParams(Map hyperParams, boolean isSubspace) { + private void validateParams(Map params, boolean isSubspace) { // if a parameter is specified both as a model parameter and as a hyper-parameter, this is only allowed if the // parameter value is left at its default. Otherwise, an exception will be thrown.
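    // e.g. (hypothetical values) gridding over {"_ntrees": [50, 100]} is fine while _ntrees is left at its default,
    // but additionally setting _ntrees = 200 on the model parameters would be ambiguous and is rejected below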
- ModelParametersBuilder paramsBuilder = _paramsBuilderFactory.get((MP)_params.clone()); - - for (Map.Entry e : hyperParams.entrySet()) { + for (String key : params.keySet()) { // Throw if the user passed an empty value list: - if (e.getValue() == null || e.getValue().length == 0) - throw new H2OIllegalArgumentException("Grid search hyperparameter value list is empty for hyperparameter: " + e.getKey()); + Object[] values = params.get(key); + if (values == null || values.length == 0) + throw new H2OIllegalArgumentException("Grid search hyperparameter value list is empty for hyperparameter: " + key); - if (isSubspace && _hyperParams.containsKey(e.getKey())) { - throw new H2OIllegalArgumentException("Grid search model parameter '" + e.getKey() + "' is set in " + + if ("seed".equals(key) || "_seed".equals(key)) continue; // initialized to the wall clock + + if (isSubspace && _hyperParams.containsKey(key)) { + throw new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in " + "both the subspaces and in the hyperparameters map. This is ambiguous; set it in one place" + " or the other, not both."); } - if (!paramsBuilder.isAssignable(e.getKey())) { - throw new H2OIllegalArgumentException("Grid search model parameter '"+e.getKey()+"' is invalid or set both in the model parameters and in the hyperparameters map. "+ - "This is ambiguous; set it in one place or the other, not both."); - } + + validateParamVals(key); + // Ugh. Java callers, like the JUnits or Sparkling Water users, use a leading _. REST users don't. + } + } + + private void validateParamVals(String key) { + String prefix = (key.startsWith("_") ? "" : "_"); + + // Throw if params has a non-default value which is not in the hyperParams map + Object defaultVal = PojoUtils.getFieldValue(_defaultParams, prefix + key, PojoUtils.FieldNaming.CONSISTENT); + Object actualVal = PojoUtils.getFieldValue(_params, prefix + key, PojoUtils.FieldNaming.CONSISTENT); + + if (defaultVal != null && actualVal != null) { // both are set + if (defaultVal.getClass().isArray() && !PojoUtils.arraysEquals(defaultVal, actualVal)) { // array values differ + throw new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in both the model parameters and in the hyperparameters map. This is ambiguous; set it in one place or the other, not both."); + } + if (!defaultVal.getClass().isArray() && !defaultVal.equals(actualVal)) { // scalar values differ + throw new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in both the model parameters and in the hyperparameters map. This is ambiguous; set it in one place or the other, not both."); + } + } + + if (defaultVal == null && actualVal != null) { // only the actual value is set + throw new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in both the model parameters and in the hyperparameters map. 
This is ambiguous; set it in one place or the other, not both."); } } } @@ -403,8 +442,10 @@ public MP nextModelParameters() { if (_currentHyperparamIndices != null) { // Fill array of hyper-values Object[] hypers = hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices); + // Get clone of parameters + MP commonModelParams = (MP) _params.clone(); // Fill model parameters - MP params = getModelParams(_params, hypers, _currentHyperParamNames); + MP params = getModelParams(commonModelParams, hypers, _currentHyperParamNames); return params; } else { @@ -419,7 +460,9 @@ public boolean hasNext() { int[] hyperParamIndicesCopy = new int[_currentHyperparamIndices.length]; System.arraycopy(_currentHyperparamIndices, 0, hyperParamIndicesCopy, 0, _currentHyperparamIndices.length); if (nextModelIndices(hyperParamIndicesCopy) == null) { - return _currentSubspace != _hyperParamSubspaces.length-1; + if (_currentSubspace == _hyperParamSubspaces.length - 1) { + return false; + } } } @@ -428,7 +471,7 @@ @Override public void onModelFailure(Model failedModel, Consumer withFailedModelHyperParams) { - // FIXME: when using parallel grid search, there's no good reason to think that the current hyperparam indices were the ones used for the failed model + // FIXME: when using parallel grid search, there's no good reason to think that the current hyperparam indices were the ones used for the failed model withFailedModelHyperParams.accept(hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices)); } @@ -468,13 +511,10 @@ private int[] nextModelIndices(int[] hyperparamIndices) { class RandomDiscreteValueWalker extends BaseWalker { - public static final String WEIGHTS_SUFFIX = "$weights"; // Used by HyperSpaceIterator.nextModelIndices to ensure that the space is explored enough before giving up private static final double MIN_NUMBER_OF_SAMPLES = 1e4; private Random _random; private boolean _set_model_seed_from_search_seed; // true if model parameter seed is set to default value and false otherwise - - private Map _weights; public RandomDiscreteValueWalker(MP params, Map hyperParams, @@ -482,39 +522,15 @@ public RandomDiscreteValueWalker(MP params, RandomDiscreteValueSearchCriteria search_criteria) { super(params, hyperParams, paramsBuilderFactory, search_criteria); - // seed the models using the search seed iff it is the only one specified - long defaultSeed = (long)_params.getParameterDefaultValue("_seed"); + // seed the models using the search seed if it is the only one specified + long defaultSeed = _defaultParams._seed; long actualSeed = _params._seed; long gridSeed = search_criteria.seed(); _set_model_seed_from_search_seed = defaultSeed == actualSeed && defaultSeed != gridSeed; - _random = RandomUtils.getRNG(gridSeed != defaultSeed ? gridSeed : System.nanoTime()); - } - - @Override - protected Map processHyperParams(Map hyperParams) { -// _weights = new HashMap<>(); -// return collectWeights(hyperParams, _weights); - return hyperParams; + _random = gridSeed == defaultSeed ?
new Random() : new Random(gridSeed); } -/* - private Map collectWeights(Map hyperParams, Map weights) { - Map hp = new HashMap<>(); - for (Map.Entry e : hyperParams.entrySet()) { - String name = e.getKey(); - if (name.equals(SUBSPACES)) { - } else if (name.endsWith(WEIGHTS_SUFFIX)) { - _weights.put(name.substring(0, name.length()-WEIGHTS_SUFFIX.length()), (Integer[])e.getValue()); - } else { - _weights - } - } - - } -*/ - - - /** Based on the last model, the given array of ScoringInfo, and our stopping criteria should we stop early? */ + /** Based on the last model, the given array of ScoringInfo, and our stopping criteria should we stop early? */ @Override public boolean stopEarly(Model model, ScoringInfo[] sk) { return ScoreKeeper.stopEarly(ScoringInfo.scoreKeepers(sk), @@ -558,8 +574,10 @@ public MP nextModelParameters() { // Fill array of hyper-values Object[] hypers = hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices); + // Get clone of parameters + MP commonModelParams = (MP) _params.clone(); // Fill model parameters - MP params = getModelParams(_params, hypers, _currentHyperParamNames); + MP params = getModelParams(commonModelParams, hypers, _currentHyperParamNames); // add max_runtime_secs in search criteria into params if applicable if (_search_criteria != null && _search_criteria.strategy() == Strategy.RandomDiscrete) { diff --git a/h2o-core/src/main/java/hex/grid/SequentialWalker.java b/h2o-core/src/main/java/hex/grid/SequentialWalker.java index 74b2ad6b0c09..da16d7bcbf10 100644 --- a/h2o-core/src/main/java/hex/grid/SequentialWalker.java +++ b/h2o-core/src/main/java/hex/grid/SequentialWalker.java @@ -118,7 +118,7 @@ public boolean stopEarly(Model model, ScoringInfo[] sk) { } private MP getModelParams(MP params, Object[] hyperParams) { - ModelParametersBuilderFactory.ModelParametersBuilder paramsBuilder = _paramsBuilderFactory.get(params.freshCopy()); + ModelParametersBuilderFactory.ModelParametersBuilder paramsBuilder = _paramsBuilderFactory.get(params.clone()); for (int i = 0; i < _hyperParamNames.length; i++) { String paramName = _hyperParamNames[i]; Object paramValue = hyperParams[i]; diff --git a/h2o-core/src/main/java/hex/grid/SimpleParametersBuilderFactory.java b/h2o-core/src/main/java/hex/grid/SimpleParametersBuilderFactory.java deleted file mode 100644 index feaba387c360..000000000000 --- a/h2o-core/src/main/java/hex/grid/SimpleParametersBuilderFactory.java +++ /dev/null @@ -1,64 +0,0 @@ -package hex.grid; - -import hex.Model; -import hex.ModelParametersBuilderFactory; -import water.util.PojoUtils; - -/** - * The factory is producing a parameters builder which uses reflection to setup field values. - * - * @param type of model parameters object - */ -public class SimpleParametersBuilderFactory - implements ModelParametersBuilderFactory { - - @Override - public ModelParametersBuilder get(MP initialParams) { - return new SimpleParamsBuilder<>(initialParams); - } - - @Override - public PojoUtils.FieldNaming getFieldNamingStrategy() { - return PojoUtils.FieldNaming.CONSISTENT; - } - - /** - * The builder modifies initial model parameters directly by reflection. - *
<p>
- * Usage:
- * <pre>{@code
-   *   GBMModel.GBMParameters params =
-   *     new SimpleParamsBuilder<GBMModel.GBMParameters>(initialParams)
-   *      .set("_ntrees", 30)
-   *      .set("_learn_rate", 0.01)
-   *      .build()
-   * }</pre>
- * - * @param type of model parameters object - */ - public static class SimpleParamsBuilder - implements ModelParametersBuilder { - - final private MP params; - - public SimpleParamsBuilder(MP initialParams) { - params = initialParams; - } - - @Override - public boolean isAssignable(String name) { - return params.isParameterAssignable(name); - } - - @Override - public ModelParametersBuilder set(String name, Object value) { - PojoUtils.setField(params, name, value, PojoUtils.FieldNaming.CONSISTENT); - return this; - } - - @Override - public MP build() { - return params; - } - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/DataTransformer.java b/h2o-core/src/main/java/hex/pipeline/DataTransformer.java deleted file mode 100644 index f0473fe87a38..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/DataTransformer.java +++ /dev/null @@ -1,271 +0,0 @@ -package hex.pipeline; - -import hex.Parameterizable; -import hex.pipeline.TransformerChain.Completer; -import water.*; -import water.fvec.Frame; -import water.util.ArrayUtils; -import water.util.Checksum; -import water.util.IcedLong; -import water.util.PojoUtils; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; - -public abstract class DataTransformer> extends Lockable implements Parameterizable { - - public enum FrameType { - Training, - Validation, - Test - } - - private static final Set IGNORED_FIELDS_FOR_CHECKSUM = new HashSet(Arrays.asList( - "_key" - )); - - public boolean _enabled = true; // flag allowing to enable/disable transformers dynamically esp. in pipelines (can be used as a pipeline hyperparam in grids). - private String _name; - private String _description; - private Key _refCountKey; - private KeyGen _keyGen; - - protected DataTransformer() { - this(null); - } - - public DataTransformer(String name) { - this(name, null); - } - - public DataTransformer(String name, String description) { - super(null); - _name = name == null ? getClass().getSimpleName().toLowerCase()+Key.rand() : name; - _description = description == null ? 
getClass().getSimpleName().toLowerCase() : description; - } - - @SuppressWarnings("unchecked") - public SELF name(String name) { - _name = name; - return (SELF) this; - } - - public String name() { - return _name; - } - - @SuppressWarnings("unchecked") - public SELF description(String description) { - _description = description; - return (SELF) this; - } - - public String description() { - return _description; - } - - @SuppressWarnings("unchecked") - public SELF enable(boolean enabled) { - _enabled = enabled; - return (SELF) this; - } - - public SELF init() { - assert _name != null; - if (_keyGen == null) { - _keyGen = new KeyGen.PatternKeyGen("{0}_{n}"); - } - if (_refCountKey == null) { - _refCountKey = Key.make(_name+ "_refCount"); - DKV.put(_refCountKey, new IcedLong(0)); - } - if (_key == null || _key.get() == null) { - _key = _keyGen.make(_name); - DKV.put(this); - } - return (SELF) this; - } - - public boolean enabled() { return _enabled; } - - @Override - public boolean hasParameter(String name) { - try { - getParameter(name); - return true; - } catch (Exception e) { - return false; - } - } - - @Override - public Object getParameter(String name) { - return PojoUtils.getFieldValue(this, name); - } - - @Override - public void setParameter(String name, Object value) { - PojoUtils.setField(this, name, value); - } - - @Override - public boolean isParameterSetToDefault(String name) { - Object val = getParameter(name); - Object defaultVal = getParameterDefaultValue(name); - return Objects.deepEquals(val, defaultVal); - } - - @Override - public Object getParameterDefaultValue(String name) { - return getDefaults().getParameter(name); - } - - @Override - public boolean isParameterAssignable(String name) { - return isParameterSetToDefault(name); - } - - /** private use only to avoid this getting mutated. */ - private transient DataTransformer _defaults; - - /** private use only to avoid this getting mutated. */ - private DataTransformer getDefaults() { - if (_defaults == null) { - _defaults = makeDefaults(); - } - return _defaults; - } - - protected DataTransformer makeDefaults() { - try { - return getClass().newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * @return true iff the transformer needs to be applied in a specific way to training/validation frames during cross-validation. - */ - public boolean isCVSensitive() { - return false; - } - - public final void prepare(PipelineContext context) { - if (_enabled) { - if (_key != null) { - Key jobKey = context == null ? null : context._jobKey; - write_lock(jobKey); - doPrepare(context); - update(jobKey); - unlock(jobKey); - } else { - doPrepare(context); - } - } - } - - /** - * Transformers can implement this method if it needs to preparation before being able to do the transformations. 
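- * <p>
- * e.g. (a sketch of the intent) a transformer wrapping a model can train that model here, using
- * {@code context.getTrain()} as training data, so that {@code doTransform} can later score with it.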
- * @param context - */ - protected void doPrepare(PipelineContext context) {} - - protected void prepare(PipelineContext context, TransformerChain chain) { - assert chain != null; - prepare(context); - chain.nextPrepare(context); - } - - public final void cleanup() { - cleanup(new Futures()); - } - - public final void cleanup(Futures futures) { - if (_refCountKey == null || IcedLong.decrementAndGet(_refCountKey) <= 0) doCleanup(futures); - if (_key != null) DKV.remove(_key); - } - protected void doCleanup(Futures futures) { - remove(futures); - } - - public final Frame transform(Frame fr) { - return transform(fr, FrameType.Test, null); - } - - public final Frame transform(Frame fr, FrameType type, PipelineContext context) { - if (!_enabled || fr == null) return fr; - Frame trfr = doTransform(fr, type, context); - if (context != null && context._tracker != null) { - context._tracker.apply(trfr, fr, type, context, this); - } - return trfr; - } - - protected final Frame[] transform(Frame[] frames, FrameType[] types, PipelineContext context) { - assert frames != null; - assert types != null; - assert frames.length == types.length; - Frame[] transformed = new Frame[frames.length]; - for (int i=0; i R transform(Frame[] frames, FrameType[] types, PipelineContext context, Completer completer, TransformerChain chain) { - assert chain != null; - Frame[] transformed = transform(frames, types, context); - return chain.nextTransform(transformed, types, context, completer); - } - - /** - * Transformers must implement this method. They can ignore type and context if they're not context-sensitive. - * @param fr - * @param type - * @param context - * @return - */ - protected abstract Frame doTransform(Frame fr, FrameType type, PipelineContext context); - - @Override - public SELF freshCopy() { - SELF copy = clone(); - copy._key = _keyGen.make(_name); - DKV.put(copy); - IcedLong.incrementAndGet(_refCountKey); - return copy; - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - if (_refCountKey != null) { - DKV.remove(_refCountKey); - _refCountKey = null; - } - return super.remove_impl(fs, cascade); - } - - @Override - public long checksum_impl() { - return Checksum.checksum(this, ignoredFieldsForChecksum()); - } - - protected Set ignoredFieldsForChecksum() { - return IGNORED_FIELDS_FOR_CHECKSUM; - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/FrameTracker.java b/h2o-core/src/main/java/hex/pipeline/FrameTracker.java deleted file mode 100644 index 8cb0bf04157a..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/FrameTracker.java +++ /dev/null @@ -1,15 +0,0 @@ -package hex.pipeline; - -import hex.pipeline.DataTransformer.FrameType; -import water.fvec.Frame; - -import java.io.Serializable; - -/** - * {@link FrameTracker}s are called after each transformation and can be used - * for consistent logging, debugging, renaming,... and various other tracking logic. 
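- * <p>
- * As a functional interface, a minimal logging tracker can be written as a lambda, e.g. (a sketch):
- * <pre>{@code
- *   FrameTracker logger = (transformed, original, type, context, transformer) ->
- *       Log.info("`" + transformer.name() + "` transformed " + original.getKey()
- *           + " (" + type + ") into " + transformed.getKey());
- * }</pre>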
- */ -@FunctionalInterface -public interface FrameTracker extends Serializable { - void apply(Frame transformed, Frame original, FrameType type, PipelineContext context, DataTransformer transformer); -} diff --git a/h2o-core/src/main/java/hex/pipeline/ModelParametersGenericPipelineBuilderFactory.java b/h2o-core/src/main/java/hex/pipeline/ModelParametersGenericPipelineBuilderFactory.java deleted file mode 100644 index a68241fe4114..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/ModelParametersGenericPipelineBuilderFactory.java +++ /dev/null @@ -1,71 +0,0 @@ -package hex.pipeline; - -import hex.Model; -import hex.ModelBuilder; -import hex.ModelParametersDelegateBuilderFactory; -import hex.pipeline.PipelineModel.PipelineParameters; -import water.util.Log; -import water.util.PojoUtils; -import water.util.PojoUtils.FieldNaming; - -import java.util.HashMap; -import java.util.Map; - -import static hex.ModelParametersGenericBuilderFactory.ALGO_PARAM; - -/** - * Similar to {@link hex.ModelParametersGenericBuilderFactory} but for pipelines: - * - pipeline estimator params can be created dynamically based on {@value hex.ModelParametersGenericBuilderFactory#ALGO_PARAM} hyper-param. - * - then other hyper-parameters can be set. - */ -public class ModelParametersGenericPipelineBuilderFactory extends ModelParametersDelegateBuilderFactory { - - public ModelParametersGenericPipelineBuilderFactory() { - super(); - } - - @Override - public ModelParametersBuilder get(PipelineParameters initialParams) { - return new GenericPipelineParamsBuilder(initialParams, fieldNaming); - } - - public static class GenericPipelineParamsBuilder extends DelegateParamsBuilder { - - private final Map hyperParams = new HashMap<>(); - - public GenericPipelineParamsBuilder(PipelineParameters params, FieldNaming fieldNaming) { - super(params, fieldNaming); - } - - @Override - public ModelParametersBuilder set(String name, Object value) { - hyperParams.put(name, value); - return this; - } - - @Override - public PipelineParameters build() { - PipelineParameters result = params; - Model.Parameters initEstimatorParams = result._estimatorParams; - String algo = null; - - if (hyperParams.containsKey(ALGO_PARAM)) { - algo = (String) hyperParams.get(ALGO_PARAM); - result._estimatorParams = ModelBuilder.makeParameters(algo); - if (initEstimatorParams != null) { - //add values from init estimator params - PojoUtils.copyProperties(result._estimatorParams, initEstimatorParams, FieldNaming.CONSISTENT); - } - } - for (Map.Entry e : hyperParams.entrySet()) { - if (ALGO_PARAM.equals(e.getKey())) continue; - if (algo == null || result.hasParameter(fieldNaming.toDest(e.getKey()))) { // no check for `result.hasParameter` in case of strict algo, so that we can fail on invalid param - result.setParameter(fieldNaming.toDest(e.getKey()), e.getValue()); - } else { // algo hyper-param was provided and this hyper-param is incompatible with it - Log.debug("Ignoring hyper-parameter `"+e.getKey()+"` unsupported by `"+algo+"`."); - } - } - return result; - } - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/Pipeline.java b/h2o-core/src/main/java/hex/pipeline/Pipeline.java deleted file mode 100644 index bb88101549f1..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/Pipeline.java +++ /dev/null @@ -1,282 +0,0 @@ -package hex.pipeline; - -import hex.Model; -import hex.ModelBuilder; -import hex.ModelBuilderCallbacks; -import hex.ModelCategory; -import hex.pipeline.DataTransformer.FrameType; -import hex.pipeline.trackers.*; -import 
hex.pipeline.PipelineModel.PipelineOutput; -import hex.pipeline.PipelineModel.PipelineParameters; -import water.*; -import water.exceptions.H2OModelBuilderIllegalArgumentException; -import water.fvec.Frame; - -import java.util.Arrays; - -import static hex.pipeline.PipelineHelper.reassign; - - -/** - * The {@link ModelBuilder} for {@link PipelineModel}s. - */ -public class Pipeline extends ModelBuilder { - - public Pipeline(PipelineParameters parms) { - super(parms); - init(false); - } - - public Pipeline(PipelineParameters parms, Key key) { - super(parms, key); - } - - public Pipeline(boolean startup_once) { -// super(new PipelineParameters(), startup_once, null); // no schema directory to completely disable schema lookup for now. - super(new PipelineParameters(), startup_once); - } - - @Override - public void init(boolean expensive) { - if (expensive) { - earlyValidateParams(); - if (_parms._transformers == null) _parms._transformers = new Key[0]; - DataTransformer[] transformers = _parms.getTransformers(); - _parms._transformers = Arrays.stream(transformers) - .filter(DataTransformer::enabled) - .map(DataTransformer::init) - .map(DataTransformer::getKey) - .toArray(Key[]::new); - Arrays.stream(transformers).filter(t -> !t.enabled()).forEach(DataTransformer::cleanup); - } - super.init(expensive); - } - - protected void earlyValidateParams() { - if (_parms._categorical_encoding != Model.Parameters.CategoricalEncodingScheme.AUTO) { - // we need to ensure that no transformation occurs before the transformers in the pipeline - hide("_categorical_encoding", - "Pipeline supports only AUTO categorical encoding: custom categorical encoding should be applied either as a transformer or directly to the final estimator of the pipeline."); - _parms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.AUTO; - } - if (_parms._estimatorParams == null && nFoldCV()) { - error("_estimator", "Pipeline can use cross validation only if provided with an estimator."); - } - } - - @Override - protected PipelineDriver trainModelImpl() { - return new PipelineDriver(); - } - - @Override - public ModelCategory[] can_build() { - ModelBuilder finalBuilder = getFinalBuilder(); -// return finalBuilder == null ? new ModelCategory[] {ModelCategory.Unknown} : finalBuilder.can_build(); - return finalBuilder == null ? ModelCategory.values() : finalBuilder.can_build(); - } - - @Override - public boolean isSupervised() { - ModelBuilder finalBuilder = getFinalBuilder(); - return finalBuilder != null && finalBuilder.isSupervised(); - } - - private ModelBuilder getFinalBuilder() { - return _parms._estimatorParams == null ? null : ModelBuilder.make(_parms._estimatorParams.algoName(), null, null); - } - - - public class PipelineDriver extends Driver { - @Override - public void computeImpl() { - init(true); //also protects original train+valid frames - PipelineOutput output = new PipelineOutput(Pipeline.this); - PipelineModel model = new PipelineModel(dest(), _parms, output); - model.delete_and_lock(_job); - - try { - PipelineContext context = newContext(); - TransformerChain chain = newChain(context); - setTrain(context.getTrain()); - setValid(context.getValid()); - Scope.track(train(), valid()); //chain preparation may have provided extended/modified train/valid frames, so better track the current ones. 
- output._transformers = _parms._transformers.clone(); - if (_parms._estimatorParams == null) return; - try (Scope.Safe inner = Scope.safe(train(), valid())) { - output._estimator = chain.transform( - new Frame[]{train(), valid()}, - new FrameType[]{FrameType.Training, FrameType.Validation}, - context, - (frames, ctxt) -> { - // use params from the context as they may have been modified during chain preparation - ModelBuilder mb = makeEstimatorBuilder(_parms._estimatorKeyGen.make(_result), ctxt._params._estimatorParams, ctxt._params, frames[0], frames[1]); - Keyed res = mb.trainModelNested(null); - return res == null ? null : res.getKey(); - } - ); - } - } finally { - model.syncOutput(); - model.update(_job); - model.unlock(_job); - } - } - } - - @Override - public void computeCrossValidation() { - assert _parms._estimatorParams != null; // no CV if pipeline used as a pure transformer (see params validation) - PipelineModel model = null; - try { - Scope.enter(); - init(true); //also protects original train+valid frames - PipelineOutput output = new PipelineOutput(Pipeline.this); - model = new PipelineModel(dest(), _parms, output); - model.delete_and_lock(_job); - - PipelineContext context = newContext(); - TransformerChain chain = newChain(context); - setTrain(context.getTrain()); - setValid(context.getValid()); - Scope.track(train(), valid()); //chain preparation may have provided extended/modified train/valid frames. - try (Scope.Safe mainModelScope = Scope.safe(train(), valid())) { - output._transformers = _parms._transformers.clone(); - output._estimator = chain.transform( - new Frame[]{train(), valid()}, - new FrameType[]{FrameType.Training, FrameType.Validation}, - context, - (frames, ctxt) -> { - ModelBuilder mb = makeEstimatorBuilder(_parms._estimatorKeyGen.make(_result), ctxt._params._estimatorParams, ctxt._params, frames[0], frames[1]); - mb.setCallbacks(new ModelBuilderCallbacks() { - /** - * Using this callback, the transformations are applied at the time the CV model training is triggered, - * we don't have to stack up all transformed frames in memory BEFORE starting the CV-training. 
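- * (i.e. each CV fold's train/valid pair is transformed lazily, inside that fold's own compute call)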
- */ - @Override - public void wrapCompute(ModelBuilder builder, Runnable compute) { - Model.Parameters params = builder._parms; - if (!params._is_cv_model || !chain.isCVSensitive()) { - compute.run(); - return; - } - - try (Scope.Safe cvModelComputeScope = Scope.safe(train(), params.train(), params.valid())) { - PipelineContext cvContext = newCVContext(context, params); - Scope.track(cvContext.getTrain(), cvContext.getValid()); - TransformerChain cvChain = chain.clone(); // as cv models can be trained in parallel - cvChain.transform( - new Frame[]{cvContext.getTrain(), cvContext.getValid()}, - new FrameType[]{FrameType.Training, FrameType.Validation}, - cvContext, - (cvFrames, ctxt) -> { - // ensure that generated vecs, that will be used to train+score this CV model, get deleted at the end of the pipeline training - track(cvFrames[0], true); - track(cvFrames[1], true); - reassign(cvFrames[0], params._train, _job.getKey()); - reassign(cvFrames[1], params._valid, _job.getKey()); - // re-init & re-validate the builder in case we produced a bad frame - // (although this should have been detected earlier as a similar transformation was already applied to main training frame) - builder._input_parms = params.clone(); - builder.setTrain(null); - builder.setValid(null); - builder.init(false); - if (builder.error_count() > 0) - throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder); - return null; - } - ); - compute.run(); - } - } - }); - mb.trainModelNested(null); - return mb.dest(); - } - ); - } - model.setInputParms(_input_parms); - } finally { - if (model != null) { - model.syncOutput(); - model.update(_job); - model.unlock(_job); - } - cleanUp(); - Scope.exit(); - } - } - - private PipelineContext newCVContext(PipelineContext context, Model.Parameters cvParams) { - PipelineContext cvContext = new PipelineContext(context._params, context._tracker, _job); - PipelineParameters pparams = cvContext._params; - pparams._is_cv_model = cvParams._is_cv_model; - pparams._cv_fold = cvParams._cv_fold; - Frame baseFrame = new Frame(Key.make(_result.toString()+"_cv_"+(pparams._cv_fold+1)), train().names(), train().vecs()); - if ( pparams._weights_column != null ) baseFrame.remove( pparams._weights_column ); - Frame cvTrainOld = cvParams.train(); - Frame cvValidOld = cvParams.valid(); - String cvWeights = cvParams._weights_column; - Frame cvTrain = new Frame(baseFrame); - cvTrain.add(cvWeights, cvTrainOld.vec(cvWeights)); - DKV.put(cvTrain); - Frame cvValid = new Frame(baseFrame); - cvValid.add(cvWeights, cvValidOld.vec(cvWeights)); - DKV.put(cvValid); - cvContext.setTrain(cvTrain); - cvContext.setValid(cvValid); - return cvContext; - } - - private PipelineContext newContext() { - return new PipelineContext( - _parms, - new CompositeFrameTracker( - new CancellationTracker(this), - new ConsistentKeyTracker(), - new ScopeTracker() - ), - _job - ); - } - - static class CancellationTracker extends AbstractFrameTracker { - - private final Pipeline _pipeline; - - public CancellationTracker() { this(null); } // for (de)serialization - - public CancellationTracker(Pipeline pipeline) { - _pipeline = pipeline; - } - - @Override - public void apply(Frame transformed, Frame original, DataTransformer.FrameType type, PipelineContext context, DataTransformer transformer) { - if (_pipeline.stop_requested()) throw new Job.JobCancelledException(_pipeline._job); - } - } - - private TransformerChain newChain(PipelineContext context) { - TransformerChain chain = new 
TransformerChain(_parms._transformers);//.init(); - chain.prepare(context); - return chain; - } - - private ModelBuilder makeEstimatorBuilder(Key eKey, Model.Parameters eParams, PipelineParameters pParams, Frame train, Frame valid) { - eParams._train = train == null ? null : train.getKey(); - eParams._valid = valid == null ? null : valid.getKey(); - eParams._response_column = pParams._response_column; - eParams._weights_column = pParams._weights_column; - eParams._offset_column = pParams._offset_column; - eParams._ignored_columns = pParams._ignored_columns; - eParams._fold_column = pParams._fold_column; - eParams._fold_assignment = pParams._fold_assignment; - eParams._nfolds= pParams._nfolds; - eParams._max_runtime_secs = pParams._max_runtime_secs > 0 ? remainingTimeSecs() : pParams._max_runtime_secs; - - ModelBuilder mb = ModelBuilder.make(eParams, eKey); - mb._job = _job; - return mb; - } - -} diff --git a/h2o-core/src/main/java/hex/pipeline/PipelineAlgoRegistration.java b/h2o-core/src/main/java/hex/pipeline/PipelineAlgoRegistration.java deleted file mode 100644 index 6a709fe28008..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/PipelineAlgoRegistration.java +++ /dev/null @@ -1,19 +0,0 @@ -package hex.pipeline; - -import water.api.AlgoAbstractRegister; -import water.api.PipelineHandler; -import water.api.RestApiContext; -import water.api.SchemaServer; - -public class PipelineAlgoRegistration extends AlgoAbstractRegister { - - @Override - public void registerEndPoints(RestApiContext context) { - Pipeline builder = new Pipeline(true); - registerModelBuilder(context, builder, SchemaServer.getStableVersion()); - - context.registerEndpoint("pipeline_datatransformer", "GET /3/Pipeline/DataTransformer/{key}", PipelineHandler.class, "fetchTransformer", - "Fetch a DataTransformer by its key"); - } - -} diff --git a/h2o-core/src/main/java/hex/pipeline/PipelineContext.java b/h2o-core/src/main/java/hex/pipeline/PipelineContext.java deleted file mode 100644 index 1f090791f4d0..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/PipelineContext.java +++ /dev/null @@ -1,61 +0,0 @@ -package hex.pipeline; - -import hex.pipeline.PipelineModel.PipelineParameters; -import water.Iced; -import water.Job; -import water.Key; -import water.fvec.Frame; - -/** - * A context object passed to the {@link DataTransformer}s (usually through a {@link TransformerChain} - * and providing useful information, especially to help some transformers - * to configure themselves/initialize during the {@link DataTransformer#prepare(PipelineContext)} phase. - */ -public class PipelineContext extends Iced { - - public final Key _jobKey; - - public final PipelineParameters _params; - - public final FrameTracker _tracker; - - private Frame _train; - private Frame _valid; - - - public PipelineContext(PipelineParameters params) { - this(params, null, null); - } - - public PipelineContext(PipelineParameters params, Job job) { - this(params, null, job); - } - - public PipelineContext(PipelineParameters params, FrameTracker tracker, Job job) { - assert params != null; - _jobKey = job == null ? null : job._key; - _params = (PipelineParameters) params.clone(); // cloning this here as those _params can be mutated during transformers' preparation. - _tracker = tracker; - } - - public Frame getTrain() { - return _train != null ? _train - : _params != null ? _params.train() - : null; - } - - public void setTrain(Frame train) { - _train = train; - } - - public Frame getValid() { - return _valid != null ? _valid - : _params != null ? 
_params.valid() - : null; - } - - public void setValid(Frame valid) { - _valid = valid; - } - -} diff --git a/h2o-core/src/main/java/hex/pipeline/PipelineHelper.java b/h2o-core/src/main/java/hex/pipeline/PipelineHelper.java deleted file mode 100644 index 6b956a591ea6..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/PipelineHelper.java +++ /dev/null @@ -1,36 +0,0 @@ -package hex.pipeline; - -import water.DKV; -import water.Job; -import water.Key; -import water.Scope; -import water.fvec.Frame; - -public final class PipelineHelper { - - private PipelineHelper() {} - - public static Frame reassign(Frame fr, Key key) { - return reassign(fr, key, null); - } - - public static Frame reassign(Frame fr, Key key, Key job) { - Frame copy = new Frame(fr); - DKV.remove(key); - copy._key = key; - copy.write_lock(job); - copy.update(job); - return copy; - } - - public static void reassignInplace(Frame fr, Key key) { - reassignInplace(fr, key, null); - } - - public static void reassignInplace(Frame fr, Key key, Key job) { - assert DKV.get(key) == null || DKV.getGet(key) == null; // inplace reassignment only to fresh/unassigned keys - if (fr.getKey() != null) DKV.remove(fr.getKey()); - fr._key = key; - DKV.put(fr); - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/PipelineModel.java b/h2o-core/src/main/java/hex/pipeline/PipelineModel.java deleted file mode 100644 index 84acf674e809..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/PipelineModel.java +++ /dev/null @@ -1,379 +0,0 @@ -package hex.pipeline; - -import hex.Model; -import hex.ModelBuilder; -import hex.ModelCategory; -import hex.ModelMetrics; -import hex.pipeline.DataTransformer.FrameType; -import hex.pipeline.trackers.CompositeFrameTracker; -import hex.pipeline.TransformerChain.UnaryCompleter; -import hex.pipeline.trackers.ConsistentKeyTracker; -import hex.pipeline.trackers.ScopeTracker; -import org.apache.commons.lang.StringUtils; -import water.*; -import water.KeyGen.PatternKeyGen; -import water.fvec.Frame; -import water.udf.CFuncRef; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Stream; - -/** - * A {@link PipelineModel} encapsulates in a single model a collection of transformations applied to train/validation - * before training a final delegate model (here called `estimator` to avoid confusion). - * For scoring, the same transformations are first applied to the test data and the result is used by the estimator model to provide the final score. 
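- * <p>
- * Scoring therefore behaves roughly like this sketch (not the actual implementation):
- * <pre>{@code
- *   Frame transformed = testFrame;
- *   for (DataTransformer dt : output.getTransformers())
- *     transformed = dt.transform(transformed, FrameType.Test, context);
- *   Frame preds = output.getEstimatorModel().score(transformed);
- * }</pre>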
- */ -public class PipelineModel extends Model { - - public static final String ESTIMATOR_PARAM = "estimator"; - private static final KeyGen TRANSFORM_KEY_GEN = new PatternKeyGen("{0}_trf_by_{1}"); - - public PipelineModel(Key selfKey, PipelineParameters parms, PipelineOutput output) { - super(selfKey, parms, output); - } - - @Override - public boolean havePojo() { - return false; - } - - @Override - public boolean haveMojo() { - return false; - } - - @Override - public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) { - throw new UnsupportedOperationException("PipelineModel.makeMetricBuilder should never be called!"); - } - - @Override - protected double[] score0(double[] data, double[] preds) { - throw H2O.unimpl("Pipeline can not score on raw data"); - } - - /* - @Override - protected PipelinePredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) { - Frame preds = doScore(adaptFrm, destination_key, j, computeMetrics, customMetricFunc); - ModelMetrics mm = null; - Model finalModel = _output.getFinalModel(); - if (computeMetrics && finalModel != null) { - // obtaining the model metrics from the final model - Key[] mms = finalModel._output.getModelMetrics(); - ModelMetrics lastComputedMetric = mms[mms.length - 1].get(); - mm = lastComputedMetric.deepCloneWithDifferentModelAndFrame(this, adaptFrm); - this.addModelMetrics(mm); - //now that we have the metric set on the pipeline model, removing the one we just computed on the delegate model (otherwise it leaks in client mode) - for (Key kmm : finalModel._output.clearModelMetrics(true)) { - DKV.remove(kmm); - } - } - String[] names = makeScoringNames(); - String[][] domains = makeScoringDomains(adaptFrm, computeMetrics, names); - ModelMetrics.MetricBuilder mb = makeMetricBuilder(domains[0]); - return new PipelinePredictScoreResult(mb, preds, mm); - } - */ - - @Override - public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) throws IllegalArgumentException { - return doScore(fr, destination_key, j, computeMetrics, customMetricFunc); - } - - private Frame doScore(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) throws IllegalArgumentException { - if (fr == null) return null; - try (Scope.Safe s = Scope.safe(fr)) { - PipelineContext context = newContext(fr, j); - Frame result = newChain().transform(fr, FrameType.Test, context, new UnaryCompleter() { - @Override - public Frame apply(Frame frame, PipelineContext context) { - if (_output._estimator == null) { - return new Frame(Key.make(destination_key), frame.names(), frame.vecs()); - } - Frame result = _output._estimator.get().score(frame, destination_key, j, computeMetrics, customMetricFunc); - if (computeMetrics) { - ModelMetrics mm = ModelMetrics.getFromDKV(_output._estimator.get(), frame); - if (mm != null) addModelMetrics(mm.deepCloneWithDifferentModelAndFrame(PipelineModel.this, fr)); - } - return result; - } - }); - Scope.untrack(result); - DKV.put(result); - return result; - } - } - - /** - * applies all the pipeline transformers to the input frame. - * @param fr - * @return the transformed frame, as it would be passed to the estimator model if fr was used for predict/scoring. 
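- * e.g. (usage sketch): {@code Frame asSeenByEstimator = pipelineModel.transform(testFrame);}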
- */ - @Override - public Frame transform(Frame fr) { - if (fr == null) return null; - try (Scope.Safe s = Scope.safe(fr)) { - PipelineContext context = newContext(fr, null); - Frame result = newChain().transform(fr, FrameType.Test, context, new UnaryCompleter() { - @Override - public Frame apply(Frame frame, PipelineContext context) { - return new Frame(TRANSFORM_KEY_GEN.make(fr.getKey(), getKey()), frame.names(), frame.vecs()); - } - }); - Scope.untrack(result); - DKV.put(result); - return result; - } - } - - private TransformerChain newChain() { - //no need to call `prepare` on this chain as we're using the output transformers, which have been prepared during training. - return new TransformerChain(_output._transformers); - } - - private PipelineContext newContext(Frame fr, Job job) { - return new PipelineContext( - _parms, - new CompositeFrameTracker( - new ConsistentKeyTracker(fr), - new ScopeTracker() - ), - job); - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - if (cascade) { - if (_output._transformers != null) { - for (DataTransformer dt : _output.getTransformers()) { - if (dt != null) dt.cleanup(fs); - } - } - if (_output._estimator != null) { - Keyed.remove(_output._estimator, fs, cascade); - } - } - return super.remove_impl(fs, cascade); - } - - void syncOutput() { - PipelineModel.PipelineOutput pmo = this._output; - if (pmo.getEstimatorModel() == null) return; - Model.Output mo = pmo.getEstimatorModel()._output; - if (mo._training_metrics != null) pmo._training_metrics = addModelMetrics(mo._training_metrics.deepCloneWithDifferentModelAndFrame(this, this._parms.train())); - if (mo._validation_metrics != null) pmo._validation_metrics = addModelMetrics(mo._validation_metrics.deepCloneWithDifferentModelAndFrame(this, this._parms.valid())); - if (mo._cross_validation_metrics != null) pmo._cross_validation_metrics = addModelMetrics(mo._cross_validation_metrics.deepCloneWithDifferentModelAndFrame(this, this._parms.train())); - pmo._cross_validation_metrics_summary = mo._cross_validation_metrics_summary; - pmo._cross_validation_fold_assignment_frame_id = mo._cross_validation_fold_assignment_frame_id; - pmo._cross_validation_holdout_predictions_frame_id = mo._cross_validation_holdout_predictions_frame_id; - pmo._cross_validation_predictions = mo._cross_validation_predictions; - pmo._cross_validation_models = mo._cross_validation_models; // FIXME: ideally, should be PipelineModels (build pipeline output pointing at CV model, use it for new pipeline model, etc.) - //...??? - } - - - /* - public class PipelinePredictScoreResult extends PredictScoreResult { - - private final ModelMetrics _modelMetrics; - public PipelinePredictScoreResult(ModelMetrics.MetricBuilder metricBuilder, Frame preds, ModelMetrics modelMetrics) { - super(metricBuilder, preds, preds); - _modelMetrics = modelMetrics; - } - - @Override - public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFr) { - return _modelMetrics; - } - } - */ - - - public static class PipelineParameters extends Model.Parameters { - - static String ALGO = "Pipeline"; - - // think about Grids: we should be able to slightly modify grids to set nested hyperparams, for example "_transformers[1]._my_param", "_estimator._my_param" - // this doesn't have to work for all type of transformers, but for example for those wrapping a model (see ModelAsFeatureTransformer) and for the final estimator. - // as soon as we can do this, then we will be able to train pipelines in grids like any other model. 
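-    // e.g. (a sketch of the intended addressing; parameter values are hypothetical):
-    //   pipelineParams.setParameter("_estimator._ntrees", 50);          // forwarded to _estimatorParams
-    //   pipelineParams.setParameter("_transformers[0]._my_param", 42);  // forwarded to the first transformer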
- - public Key[] _transformers; - public Model.Parameters _estimatorParams; - public KeyGen _estimatorKeyGen = new PatternKeyGen("{0}_estimator"); - - @Override - public String algoName() { - return ALGO; - } - - @Override - public String fullName() { - return ALGO; - } - - @Override - public String javaName() { - return PipelineModel.class.getName(); - } - - @Override - public long progressUnits() { - return 0; - } - - @Override - public Object getParameter(String name) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - if (ESTIMATOR_PARAM.equals(tok0)) return _estimatorParams == null ? null : _estimatorParams.getParameter(tokens[1]); - DataTransformer dt = getTransformer(tok0); - return dt == null ? null : dt.getParameter(tokens[1]); - } - return super.getParameter(name); - } - - @Override - public void setParameter(String name, Object value) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - if (ESTIMATOR_PARAM.equals(tok0)) { - _estimatorParams.setParameter(tokens[1], value); - return; - } - DataTransformer dt = getTransformer(tok0); - if (dt != null) dt.setParameter(tokens[1], value); - return; - } - super.setParameter(name, value); - } - - @Override - public boolean isParameterAssignable(String name) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - if (ESTIMATOR_PARAM.equals(tok0)) return _estimatorParams == null ? null : _estimatorParams.isParameterAssignable(tokens[1]); - DataTransformer dt = getTransformer(tok0); - // for now allow transformers hyper params on non-defaults - return dt != null && dt.hasParameter(tokens[1]); -// return dt != null && dt.isValidHyperParameter(tokens[1]); - } - return super.isParameterAssignable(name); - } - - private static final Pattern TRANSFORMER_PAT = Pattern.compile("transformers\\[(\\w+)]"); - private String[] parseParameterName(String name) { - String[] tokens = name.split("\\.", 2); - if (tokens.length == 1) return tokens; - String tok0 = StringUtils.stripStart(tokens[0], "_"); - if (ESTIMATOR_PARAM.equals(tok0) || getTransformer(tok0) != null) { - return new String[]{tok0, tokens[1]} ; - } else { - Matcher m = TRANSFORMER_PAT.matcher(tok0); - if (m.matches()) { - String id = m.group(1); - try { - int idx = Integer.parseInt(id); - assert idx >=0 && idx < _transformers.length; - assert _transformers[idx].get() != null; - return new String[]{_transformers[idx].get().name(), tokens[1]}; - } catch(NumberFormatException nfe) { - if (getTransformer(id) != null) return new String[] {id, tokens[1]}; - throw new IllegalArgumentException("Unknown pipeline transformer: "+tok0); - } - } else { - throw new IllegalArgumentException("Unknown pipeline parameter: "+name); - } - } - } - - public DataTransformer[] getTransformers() { - if (_transformers == null) return null; - for (Key key : _transformers) { - DKV.prefetch(key); - } - return Arrays.stream(_transformers).map(Key::get).toArray(DataTransformer[]::new); - } - - public void setTransformers(DataTransformer... 
transformers) { - if (transformers == null) { - _transformers = null; - return; - } - _transformers = Arrays.stream(transformers) - .map(DataTransformer::init) - .map(DataTransformer::getKey) - .toArray(Key[]::new); - } - - private DataTransformer getTransformer(String id) { - if (_transformers == null) return null; - return Stream.of(getTransformers()).filter(t -> t.name().equals(id)).findFirst().orElse(null); - } - - @Override - public PipelineParameters freshCopy() { - PipelineParameters copy = (PipelineParameters) super.freshCopy(); - DataTransformer[] dts = getTransformers(); - copy._transformers = dts == null - ? null - : Arrays.stream(dts) - .map(DataTransformer::freshCopy) - .map(DataTransformer::getKey) - .toArray(Key[]::new); - return copy; - } - - @Override - protected Parameters cloneImpl() throws CloneNotSupportedException { - PipelineParameters clone = (PipelineParameters) super.cloneImpl(); - clone._transformers = _transformers == null ? null : _transformers.clone(); - clone._estimatorParams = _estimatorParams == null ? null : _estimatorParams.clone(); - return clone; - } - - @Override - public long checksum(Set ignoredFields) { - Set ignored = ignoredFields == null ? new HashSet<>() : new HashSet<>(ignoredFields); - ignored.add("_estimatorKeyGen"); - long xs = super.checksum(ignored); - return xs; - } - } - - public static class PipelineOutput extends Model.Output { - - public Key[] _transformers; - public Key _estimator; - - public PipelineOutput(ModelBuilder b) { - super(b); - } - - public DataTransformer[] getTransformers() { - if (_transformers == null) return null; - for (Key key : _transformers) { - DKV.prefetch(key); - } - return Arrays.stream(_transformers).map(Key::get).toArray(DataTransformer[]::new); - } - - public Model getEstimatorModel() { - return _estimator == null ? null : _estimator.get(); - } - - @Override - public ModelCategory getModelCategory() { - Model em = getEstimatorModel(); - return em == null ? super.getModelCategory() : em._output.getModelCategory(); - } - } - -} diff --git a/h2o-core/src/main/java/hex/pipeline/TransformerChain.java b/h2o-core/src/main/java/hex/pipeline/TransformerChain.java deleted file mode 100644 index 18238708118b..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/TransformerChain.java +++ /dev/null @@ -1,169 +0,0 @@ -package hex.pipeline; - -import water.DKV; -import water.Futures; -import water.Key; -import water.fvec.Frame; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.stream.Stream; - -/** - * A {@link DataTransformer} that calls multiple transformers as a chain of transformations, - * that can be optionally completed by a {@link Completer} whose result will be the result of the chain. - *
- * The chain accepts one or multiple {@link Frame}s as input, all going through the transformations at the same time,
- * and all transformed frames can feed the final {@link Completer}.
- *
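- * e.g. (a sketch; {@code trainEstimator} is a placeholder):
- * <pre>{@code
- *   TransformerChain chain = new TransformerChain(transformers);
- *   Key<Model> result = chain.transform(
- *       new Frame[]{train, valid},
- *       new FrameType[]{FrameType.Training, FrameType.Validation},
- *       context,
- *       (frames, ctxt) -> trainEstimator(frames[0], frames[1]));
- * }</pre>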
- * The chain logic also allows transformers to create temporary resources and close/clean them only after the {@link Completer} has been evaluated. - */ -public class TransformerChain extends DataTransformer { - - /** - * The last operation applied to all the transformed frames. - * @param type of the final result - */ - @FunctionalInterface - interface Completer extends Serializable { - R apply(Frame[] frames, PipelineContext context); - } - - /** - * A {@link Completer} accepting a single {@link Frame}. - * @param type of the final result - */ - public static abstract class UnaryCompleter implements Completer { - @Override - public R apply(Frame[] frames, PipelineContext context) { - assert frames.length == 1; - return apply(frames[0], context); - } - - public abstract R apply(Frame frame, PipelineContext context); - } - - /** - * A default {@link Completer} simply returning all transformed frames. - */ - public static class AsFramesCompleter implements Completer { - public static final AsFramesCompleter INSTANCE = new AsFramesCompleter(); - - @Override - public Frame[] apply(Frame[] frames, PipelineContext context) { - return frames; - } - } - - /** - * A default {@link Completer} simply returning the transformed frame (when used in context of a single transformed frame). - */ - public static class AsSingleFrameCompleter implements Completer { - public static final AsSingleFrameCompleter INSTANCE = new AsSingleFrameCompleter(); - - public Frame apply(Frame[] frames, PipelineContext context) { - assert frames.length == 1; - return frames[0]; - } - } - - private final Key[] _transformers; - - private int _index; - - public TransformerChain(DataTransformer[] transformers) { - assert transformers != null; - _transformers = Stream.of(transformers) - .map(DataTransformer::init) - .map(DataTransformer::getKey) - .toArray(Key[]::new); - } - - public TransformerChain(Key[] transformers) { - assert transformers!= null; - _transformers = transformers.clone(); - } - - @Override - public TransformerChain init() { - super.init(); - Arrays.stream(getTransformers()).forEach(DataTransformer::init); - return this; - } - - private DataTransformer[] getTransformers() { - for (Key key : _transformers) { - DKV.prefetch(key); - } - return Arrays.stream(_transformers).map(Key::get).toArray(DataTransformer[]::new); - } - - @Override - public boolean isCVSensitive() { - return Arrays.stream(getTransformers()).anyMatch(DataTransformer::isCVSensitive); - } - - @Override - protected DataTransformer makeDefaults() { - return new TransformerChain(new Key[0]); - } - - @Override - protected void doPrepare(PipelineContext context) { - resetIteration(); - nextPrepare(context); - } - - final void nextPrepare(PipelineContext context) { - DataTransformer dt = next(); - if (dt != null) { - dt.prepare(context, this); - } - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - return transform(fr, type, context, AsSingleFrameCompleter.INSTANCE); - } - - final R transform(Frame fr, FrameType type, PipelineContext context, Completer completer) { - return transform(new Frame[] { fr }, new FrameType[] { type }, context, completer); - } - - final R transform(Frame[] frames, FrameType[] types, PipelineContext context, Completer completer) { - resetIteration(); - return nextTransform(frames, types, context, completer); - } - - final R nextTransform(Frame[] frames, FrameType[] types, PipelineContext context, Completer completer) { - DataTransformer dt = next(); - if (dt != null) { - 
return dt.transform(frames, types, context, completer, this); - } else if (completer != null) { - return completer.apply(frames, context); - } else { - return null; - } - } - - private DataTransformer next() { - if (_index >= _transformers.length) return null; - return _transformers[_index++].get(); - } - - private void resetIteration() { - _index = 0; - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - if (cascade) { - if (_transformers != null) { - for (DataTransformer dt : getTransformers()) { - if (dt != null) dt.cleanup(fs); - } - } - } - return super.remove_impl(fs, cascade); - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/trackers/AbstractFrameTracker.java b/h2o-core/src/main/java/hex/pipeline/trackers/AbstractFrameTracker.java deleted file mode 100644 index d25f3290303f..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/trackers/AbstractFrameTracker.java +++ /dev/null @@ -1,15 +0,0 @@ -package hex.pipeline.trackers; - -import hex.pipeline.FrameTracker; -import water.Iced; - -/** - * This abstract class just makes it easier to provide good serialization support for its subclasses. - * The no-arg public constructor is implemented as a reminder that subclasses need to override it. - */ -public abstract class AbstractFrameTracker extends Iced implements FrameTracker { - - public AbstractFrameTracker() { - super(); - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/trackers/CompositeFrameTracker.java b/h2o-core/src/main/java/hex/pipeline/trackers/CompositeFrameTracker.java deleted file mode 100644 index c3bcbef8c414..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/trackers/CompositeFrameTracker.java +++ /dev/null @@ -1,29 +0,0 @@ -package hex.pipeline.trackers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.FrameTracker; -import hex.pipeline.PipelineContext; -import water.fvec.Frame; - -/** - * A {@link FrameTracker} applying multiple trackers sequentially. - */ -public class CompositeFrameTracker extends AbstractFrameTracker { - - private final FrameTracker[] _trackers; - - public CompositeFrameTracker() { // for (de)serialization - _trackers = new FrameTracker[0]; - } - - public CompositeFrameTracker(FrameTracker... trackers) { - _trackers = trackers; - } - - @Override - public void apply(Frame transformed, Frame original, DataTransformer.FrameType type, PipelineContext context, DataTransformer transformer) { - for (FrameTracker tracker : _trackers) { - tracker.apply(transformed, original, type, context, transformer); - } - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/trackers/ConsistentKeyTracker.java b/h2o-core/src/main/java/hex/pipeline/trackers/ConsistentKeyTracker.java deleted file mode 100644 index c5b79e0627cf..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/trackers/ConsistentKeyTracker.java +++ /dev/null @@ -1,64 +0,0 @@ -package hex.pipeline.trackers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.FrameTracker; -import hex.pipeline.PipelineContext; -import water.KeyGen; -import water.fvec.Frame; - -import static hex.pipeline.PipelineHelper.reassignInplace; - -/** - * A {@link FrameTracker} ensuring that all transformed frames in the pipeline are named consistently}, - * facilitating debugging and obtaining the origin of frames in the DKV. 
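- * <p>
- * e.g. with the default key pattern, a training frame {@code train} transformed by a transformer named
- * {@code scaler} (hypothetical) is renamed to something like {@code train@@Training_trf_by_scaler_<rand>}.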
- */ -public class ConsistentKeyTracker extends AbstractFrameTracker { - - private static final String SEP = "@@"; // anything that doesn't contain Key.MAGIC_CHAR - private static final KeyGen DEFAULT_FRAME_KEY_GEN = new KeyGen.PatternKeyGen("{0}"+SEP+"{1}_trf_by_{2}_{rstr}"); - private final KeyGen _frameKeyGen; - - private final Frame _refFrame; - - public ConsistentKeyTracker() { - this(null, DEFAULT_FRAME_KEY_GEN); - } - - public ConsistentKeyTracker(Frame origin) { - this(origin, DEFAULT_FRAME_KEY_GEN); - } - - public ConsistentKeyTracker(Frame origin, KeyGen frameKeyGen) { - _refFrame = origin; - _frameKeyGen = frameKeyGen; - } - - private Frame getReference(DataTransformer.FrameType type, PipelineContext context) { - if (_refFrame != null) return _refFrame; - switch (type) { - case Training: - return context.getTrain(); - case Validation: - return context.getValid(); - case Test: - default: - return null; - } - } - - @Override - public void apply(Frame transformed, Frame original, DataTransformer.FrameType type, PipelineContext context, DataTransformer transformer) { - if (transformed == null) return; - Frame ref = getReference(type, context); - if (ref == null) return; - String refName = ref.getKey().toString(); - String frName = original.getKey().toString(); - if (!frName.startsWith(refName)) - return; // all successive frames must have the same naming pattern when using this tracker -> doesn't apply to this frame. - - if (transformed != original) { - String baseName = frName.contains(SEP) ? frName.substring(0, frName.lastIndexOf(SEP)) : frName; - reassignInplace(transformed, _frameKeyGen.make(baseName, type, transformer.name())); - } - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/trackers/ScopeTracker.java b/h2o-core/src/main/java/hex/pipeline/trackers/ScopeTracker.java deleted file mode 100644 index 8a70b448168b..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/trackers/ScopeTracker.java +++ /dev/null @@ -1,18 +0,0 @@ -package hex.pipeline.trackers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.FrameTracker; -import hex.pipeline.PipelineContext; -import water.Scope; -import water.fvec.Frame; - -/** - * a {@link FrameTracker} ensuring that all transformed framed are added to current {@link Scope}. - */ -public class ScopeTracker extends AbstractFrameTracker { - @Override - public void apply(Frame transformed, Frame original, DataTransformer.FrameType type, PipelineContext context, DataTransformer transformer) { - if (transformed == null) return; - Scope.track(transformed); - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/CachingTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/CachingTransformer.java deleted file mode 100644 index e2af25a5b56c..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/CachingTransformer.java +++ /dev/null @@ -1,83 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import water.DKV; -import water.Key; -import water.fvec.Frame; -import water.nbhm.NonBlockingHashMap; -import water.util.FrameUtils; - -import java.util.Collection; - -/** - * WIP: not ready for production usage for now due to memory + frame lifecycle issues. - * If a Frame is cached, then returning a shallow copy is not enough as the individual Vecs could then be removed from DKV. - * Deep copy would however increase the memory cost of caching. 
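- * <p>
- * Intended usage is simply to wrap another transformer, e.g. (a sketch): {@code new CachingTransformer<>(myTransformer)},
- * so that frames with the same checksum are transformed only once.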
- */ -public class CachingTransformer, T extends DataTransformer> extends DelegateTransformer { - - boolean _cacheEnabled = true; - private final NonBlockingHashMap> _cache = new NonBlockingHashMap<>(); - - protected CachingTransformer() {} - - public CachingTransformer(T transformer) { - super(transformer); - } - - void enableCache(boolean enabled) { - _cacheEnabled = enabled; - } - - boolean isCacheEnabled() { - return _cacheEnabled; - } - - private Object makeCachingKey(Frame fr, FrameType type, PipelineContext context) { - // this way, works only for simple transformations, not if it is type/context-sensitive. - // The most important is to have something that would prevent transformation again and again: - // - for each model using the main training frame in AutoML. - // - for each model scoring the validation or leaderboard frame in AutoML. - // - for each cv-training frame based on the same main training frame for each model in AutoML. - // this makes about max 3+nfolds frame to cache, which is not much. - // We can probably rely on the frame checksum instead of the frame key as a caching key. - // -// return fr.getKey(); - return fr.checksum(); - } - - private Frame copy(Frame fr, Key newKey) { - Frame cp = new Frame(fr); - cp._key = newKey; - return cp; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - if (!isCacheEnabled()) return super.doTransform(fr, type, context); - - final Object cachingKey = makeCachingKey(fr, type, context); - if (_cache.containsKey(cachingKey)) { - Frame cached = _cache.get(cachingKey).get(); - if (cached == null) { - _cache.remove(cachingKey); - } else { - return copy(cached, Key.make(cached.getKey()+".copy")); - } - } - Frame transformed = super.doTransform(fr, type, context); - Frame cached = copy(transformed, Key.make(fr.getKey()+".cached")); - DKV.put(cached); //??? how can we guarantee that cached Frame(s)/Vec(s) won't be deleted at the end of a single model training? - // we want those to remain accessible as long as the cache is "alive", - // they should only be removed from DKV when cache (or DKV) is cleared explicitly. - // Could we flag some objects in DKV as "cached"/"persistent" and can be removed from DKV only through a special call to `DKV.remove`? Use special key? 
- _cache.put(cachingKey, cached.getKey()); - return transformed; - } - - public void clearCache() { - FrameUtils.cleanUp((Collection)_cache.values()); - _cache.clear(); - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/DelegateTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/DelegateTransformer.java deleted file mode 100644 index 443210451db1..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/DelegateTransformer.java +++ /dev/null @@ -1,69 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import water.Key; -import water.fvec.Frame; - -public abstract class DelegateTransformer, T extends DataTransformer> extends DataTransformer { - - Key _transformer; - - protected DelegateTransformer() {} - - public DelegateTransformer(Key transformer) { - this._transformer = transformer; - } - - public DelegateTransformer(T transformer) { - _transformer = transformer._key; - } - - private DataTransformer getDelegate() { - return _transformer.get(); - } - - @Override - public Object getParameter(String name) { - try { - return getDelegate().getParameter(name); - } catch (IllegalArgumentException iae) { - return super.getParameter(name); - } - } - - @Override - public void setParameter(String name, Object value) { - try { - getDelegate().setParameter(name, value); - } catch (IllegalArgumentException iae) { - super.setParameter(name, value); - } - } - - @Override - public Object getParameterDefaultValue(String name) { - try { - return getDelegate().getParameterDefaultValue(name); - } catch (IllegalArgumentException iae) { - return super.getParameterDefaultValue(name); - } - } - - @Override - protected void doPrepare(PipelineContext context) { - getDelegate().prepare(context); - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - return getDelegate().transform(fr, type, context); - } - - @Override - protected S cloneImpl() throws CloneNotSupportedException { - S clone = super.cloneImpl(); - clone._transformer = _transformer == null ? null : ((T)getDelegate().clone()).getKey(); - return clone; - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/FeatureTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/FeatureTransformer.java deleted file mode 100644 index aae7d6a4d823..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/FeatureTransformer.java +++ /dev/null @@ -1,15 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; - -/** - * a DataTransformer that never modifies the response column - */ -public abstract class FeatureTransformer> extends DataTransformer { - - private String[] _excluded_columns; - - public void excludeColumns(String[] columns) { - _excluded_columns = columns; - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/FilteringTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/FilteringTransformer.java deleted file mode 100644 index 7688a63f7041..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/FilteringTransformer.java +++ /dev/null @@ -1,24 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import water.fvec.Frame; - -/** - * WiP: not used for now. - * An abstract transformer to sample/filter the input the frame. 
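- * <p>
- * A subclass only needs to implement {@link #filterRows(Frame)}, e.g. (a sketch; the class is hypothetical):
- * <pre>{@code
- *   public class EvenRowSampler extends FilteringTransformer<EvenRowSampler> {
- *     @Override public Frame filterRows(Frame fr) {
- *       long[] keep = java.util.stream.LongStream.range(0, fr.numRows())
- *           .filter(i -> i % 2 == 0).toArray();
- *       return fr.deepSlice(keep, null); // keep every second row
- *     }
- *   }
- * }</pre>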
- */ -public abstract class FilteringTransformer> extends DataTransformer { - - boolean _filterEnabled = true; - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - if (_filterEnabled) { - return filterRows(fr); - } - return fr; - } - - public abstract Frame filterRows(Frame fr); -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/KFoldColumnGenerator.java b/h2o-core/src/main/java/hex/pipeline/transformers/KFoldColumnGenerator.java deleted file mode 100644 index a94b15a00829..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/KFoldColumnGenerator.java +++ /dev/null @@ -1,115 +0,0 @@ -package hex.pipeline.transformers; - -import hex.Model.Parameters.FoldAssignmentScheme; -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import water.DKV; -import water.KeyGen; -import water.Scope; -import water.fvec.Frame; -import water.fvec.Vec; -import water.rapids.ast.prims.advmath.AstKFold; - -public class KFoldColumnGenerator extends DataTransformer { - - private static final int DEFAULT_NFOLDS = 5; - - static String FOLD_COLUMN_PREFIX = "__fold__"; - - private String _fold_column; - private FoldAssignmentScheme _scheme; - - private int _nfolds; - private long _seed; - - private String _response_column; - - private final KeyGen _trainWFoldKeyGen = new KeyGen.PatternKeyGen("{0}_wfoldc"); - - public KFoldColumnGenerator() { - this(null); - } - - public KFoldColumnGenerator(String foldColumn) { - this(foldColumn, null, -1, -1); - } - - public KFoldColumnGenerator(String foldColumn, FoldAssignmentScheme scheme, int nfolds, long seed) { - super(); - _fold_column = foldColumn; - _scheme = scheme; - _nfolds = nfolds; - _seed = seed; - } - - @Override - protected void doPrepare(PipelineContext context) { - assert context != null; - assert context._params != null; - if (_fold_column == null) _fold_column = context._params._fold_column; - if (_fold_column == null) _fold_column = FOLD_COLUMN_PREFIX+context._params._response_column; - - if (_scheme == null) _scheme = context._params._fold_assignment; - if (_scheme == null) _scheme = FoldAssignmentScheme.AUTO; - - if (_nfolds <= 0) _nfolds = context._params._nfolds; - if (_nfolds <= 0) _nfolds = DEFAULT_NFOLDS; - - if (_seed < 0) _seed = context._params.getOrMakeRealSeed(); - - if (_response_column == null) _response_column = context._params._response_column; - assert !(_response_column == null && _scheme == FoldAssignmentScheme.Stratified); - - if (context.getTrain() != null && context.getTrain().find(_fold_column) < 0) { - Frame withFoldC = doTransform(context.getTrain(), FrameType.Training, context); - withFoldC._key = _trainWFoldKeyGen.make(context.getTrain()._key); - DKV.put(withFoldC); - Scope.track(withFoldC); - context.setTrain(withFoldC); - } - // now that we have a fold column, reassign cv params to avoid confusion - context._params._fold_column = _fold_column; - context._params._nfolds = 0; - context._params._fold_assignment = FoldAssignmentScheme.AUTO; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - if (type == FrameType.Training && fr.find(_fold_column) < 0) { - Vec foldColumn = createFoldColumn( - fr, - _scheme, - _nfolds, - _response_column, - _seed - ); - Frame withFoldc = new Frame(fr); - withFoldc.add(_fold_column, foldColumn); - return withFoldc; - } - return fr; - } - - static Vec createFoldColumn(Frame fr, - FoldAssignmentScheme fold_assignment, - int nfolds, - String responseColumn, - long 
seed) { - Vec foldColumn; - switch (fold_assignment) { - default: - case AUTO: - case Random: - foldColumn = AstKFold.kfoldColumn(fr.anyVec().makeZero(), nfolds, seed); - break; - case Modulo: - foldColumn = AstKFold.moduloKfoldColumn(fr.anyVec().makeZero(), nfolds); - break; - case Stratified: - foldColumn = AstKFold.stratifiedKFoldColumn(fr.vec(responseColumn), nfolds, seed); - break; - } - return foldColumn; - } - -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/ModelAsFeatureTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/ModelAsFeatureTransformer.java deleted file mode 100644 index 15118aa5abb9..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/ModelAsFeatureTransformer.java +++ /dev/null @@ -1,229 +0,0 @@ -package hex.pipeline.transformers; - -import hex.Model; -import hex.ModelBuilder; -import hex.pipeline.PipelineContext; -import water.*; -import water.KeyGen.PatternKeyGen; -import water.fvec.Frame; -import water.util.IcedHashMap; -import water.util.Log; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - -public class ModelAsFeatureTransformer, M extends Model, MP extends Model.Parameters> extends FeatureTransformer { - - private static final Set IGNORED_FIELDS_FOR_CHECKSUM = new HashSet(Arrays.asList( - "_modelKey", "_modelsCacheKey" - )); - - protected MP _params; - private Key _modelKey; - - private boolean _cacheEnabled; - private Key> _modelsCacheKey; - private final KeyGen _modelKeyGen; - private final int _model_type; - - - - protected ModelAsFeatureTransformer() { - this(null); - } - - public ModelAsFeatureTransformer(MP params) { - this(params, null); - } - - public ModelAsFeatureTransformer(MP params, Key modelKey) { - _params = params; - _modelKey = modelKey; - _modelKeyGen = modelKey == null - ? new PatternKeyGen("{0}_{n}_model") // if no modelKey provided, then a new key and its corresponding model is trained for each - : new KeyGen.ConstantKeyGen(modelKey); // if modelKey provided, only use that one - _model_type = params == null ? TypeMap.NULL : TypeMap.getIcedId(params.javaName()); - } - - public S enableCache() { - _cacheEnabled = true; - return (S) this; - } - - @Override - public S init() { - S self = super.init(); - if (_cacheEnabled && _modelsCacheKey == null) { - _modelsCacheKey = Key.make(name()+"_cache"); - DKV.put(new ModelsCache(_modelsCacheKey)); - } - return self; - } - - public M getModel() { - return _modelKey == null ? null : _modelKey.get(); - } - - private ModelsCache getCache() { - return _modelsCacheKey == null ? null : _modelsCacheKey.get(); - } - - @SuppressWarnings("unchecked") - protected Key lookupModel(MP params) { - long cs = params.checksum(); - ModelsCache cache = getCache(); - Key k = cache == null ? null : cache.get(cs); - if (k != null && k.get() != null) return k; - return KeySnapshot.globalSnapshot().findFirst(ki -> { - if (ki._type != _model_type) return false; - M m = (M) ki._key.get(); - return m != null && m._parms != null && m._output != null && m._output._end_time > 0 && cs == m._parms.checksum(); - }); - } - - - @Override - public Object getParameter(String name) { - try { - return _params.getParameter(name); - } catch (IllegalArgumentException iae) { - return super.getParameter(name); - } - } - - @Override - public void setParameter(String name, Object value) { - try { - _modelKey = null; //consider this as a completely new transformer as soon as we're trying new hyper-parameters. 
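- // (with _modelKey dropped, the next doPrepare(context) call will look a model up
- // again by the new parameters' checksum via lookupModel(), or train a fresh one)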
- _params.setParameter(name, value); - } catch (IllegalArgumentException iae) { - super.setParameter(name, value); - } - } - - @Override - public Object getParameterDefaultValue(String name) { - try { - return _params.getParameterDefaultValue(name); - } catch (IllegalArgumentException iae) { - return super.getParameterDefaultValue(name); - } - } - - protected void doPrepare(PipelineContext context) { - if (getModel() != null) return; // if modelKey was provided, use it immediately. - prepareModelParams(context); - excludeColumns(_params.getNonPredictors()); - Key km = lookupModel(_params); - if (km == null || (_modelKey != null && !km.equals(_modelKey))) { - if (_modelKey == null) _modelKey = _modelKeyGen.make(name()); - ModelBuilder mb = ModelBuilder.make(_params, _modelKey); - mb.trainModel().get(); - ModelsCache cache = getCache(); - if (cache != null) cache.put(_params.checksum(), _modelKey); - } else { - _modelKey = km; - } - } - - protected void prepareModelParams(PipelineContext context) { - assert context != null; - // to train the model, we use the train frame from context by default - - if (_params._train == null) { // do not propagate params from context otherwise as they were defined for the default training frame - assert context.getTrain() != null; - _params._train = context.getTrain().getKey(); - if (_params._valid == null && context.getValid() != null) _params._valid = context.getValid().getKey(); - if (_params._response_column == null) _params._response_column = context._params._response_column; - if (_params._weights_column == null) _params._weights_column = context._params._weights_column; - if (_params._offset_column == null) _params._offset_column = context._params._offset_column; - if (_params._ignored_columns == null) _params._ignored_columns = context._params._ignored_columns; - if (isCVSensitive()) { - MP defaults = ModelBuilder.makeParameters(_params.algoName()); - if (_params._fold_column == null) _params._fold_column = context._params._fold_column; - if (_params._fold_assignment == defaults._fold_assignment) - _params._fold_assignment = context._params._fold_assignment; - if (_params._nfolds == defaults._nfolds) _params._nfolds = context._params._nfolds; - } - } - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - validateTransform(); - return fr == null ? null : getModel().transform(fr); - } - - protected void validateTransform() { - assert getModel() != null; - } - - @Override - protected S cloneImpl() throws CloneNotSupportedException { - ModelAsFeatureTransformer clone = super.cloneImpl(); - clone._params = _params == null ? 
null : (MP) _params.clone(); - return (S) clone; - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - if (cascade) { - Keyed.removeQuietly(_modelKey); _modelKey = null; - Keyed.removeQuietly(_modelsCacheKey); _modelsCacheKey = null; - } - return super.remove_impl(fs, cascade); - } - - @Override - protected Set ignoredFieldsForChecksum() { - Set ignored = new HashSet<>(super.ignoredFieldsForChecksum()); - ignored.addAll(IGNORED_FIELDS_FOR_CHECKSUM); - return ignored; - } - - private static class ModelsCache extends Keyed> { - private IcedHashMap> _modelsCache; - - public ModelsCache(Key> key) { - super(key); - _modelsCache = new IcedHashMap<>(); - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - if (cascade) { - Log.info("Clearing models cache "+_key+": "+_modelsCache.values()); - for (Key k : _modelsCache.values()) { - Keyed.removeQuietly(k); - } - _modelsCache.clear(); - } - return super.remove_impl(fs, cascade); - } - - private Key get(Long cacheKey) { - return _modelsCache.get(cacheKey); - } - - private void put(Long cacheKey, Key key) { - new PutInCache<>(cacheKey, key).invoke(getKey()); - } - - private static class PutInCache extends TAtomic> { - - Long cacheKey; - Key modelKey; - public PutInCache(Long cacheKey, Key key) { - this.cacheKey = cacheKey; - this.modelKey = key; - } - - @Override - protected ModelsCache atomic(ModelsCache old) { - old._modelsCache.put(cacheKey, modelKey); - return old; - } - } - } -} diff --git a/h2o-core/src/main/java/hex/pipeline/transformers/UnionTransformer.java b/h2o-core/src/main/java/hex/pipeline/transformers/UnionTransformer.java deleted file mode 100644 index 6ce378e9c76d..000000000000 --- a/h2o-core/src/main/java/hex/pipeline/transformers/UnionTransformer.java +++ /dev/null @@ -1,117 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import org.apache.commons.lang.StringUtils; -import water.fvec.Frame; - -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Stream; - -/** - * This transformer applies several independent transformers (possibly in parallel) to the same input frame. - * The results of those transformations to the input frame are then concatenated to produce the result frame, - * or possibly appended to the input frame. - */ -public class UnionTransformer extends DataTransformer { - - public enum UnionStrategy { - append, - replace - } - - private DataTransformer[] _transformers; - private UnionStrategy _strategy; - - - protected UnionTransformer() {} - - public UnionTransformer(UnionStrategy strategy, DataTransformer... transformers) { - _strategy = strategy; - _transformers = transformers; - } - - @Override - public Object getParameter(String name) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - DataTransformer dt = getTransformer(tok0); - return dt == null ? 
null : dt.getParameter(tokens[1]); - } - return super.getParameter(name); - } - - @Override - public void setParameter(String name, Object value) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - DataTransformer dt = getTransformer(tok0); - if (dt != null) dt.setParameter(tokens[1], value); - return; - } - super.setParameter(name, value); - } - - @Override - public boolean isParameterAssignable(String name) { - String[] tokens = parseParameterName(name); - if (tokens.length > 1) { - String tok0 = tokens[0]; - DataTransformer dt = getTransformer(tok0); - return dt != null && dt.hasParameter(tokens[1]); - } - return super.isParameterAssignable(name); - } - - //TODO similar logic as in PipelineParameters: delegate to some kind of ModelParametersAccessor? - private static final Pattern TRANSFORMER_PAT = Pattern.compile("transformers\\[(\\w+)]"); - private String[] parseParameterName(String name) { - String[] tokens = name.split("\\.", 2); - if (tokens.length == 1) return tokens; - String tok0 = StringUtils.stripStart(tokens[0], "_"); - if (getTransformer(tok0) != null) { - return new String[]{tok0, tokens[1]} ; - } else { - Matcher m = TRANSFORMER_PAT.matcher(tok0); - if (m.matches()) { - String id = m.group(1); - try { - int idx = Integer.parseInt(id); - assert idx >=0 && idx < _transformers.length; - return new String[]{_transformers[idx].name(), tokens[1]}; - } catch(NumberFormatException nfe) { - if (getTransformer(id) != null) return new String[] {id, tokens[1]}; - throw new IllegalArgumentException("Unknown transformer: "+tok0); - } - } else { - throw new IllegalArgumentException("Unknown parameter: "+name); - } - } - } - - private DataTransformer getTransformer(String id) { - if (_transformers == null) return null; - return Stream.of(_transformers).filter(t -> t.name().equals(id)).findFirst().orElse(null); - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - Frame result = null; - switch (_strategy) { - case append: - result = new Frame(fr); - break; - case replace: - result = new Frame(); - break; - } - for (DataTransformer dt : _transformers) { - result.add(dt.transform(fr, type, context)); - } - return result; - } - -} diff --git a/h2o-core/src/main/java/hex/schemas/ClientDataTransformer.java b/h2o-core/src/main/java/hex/schemas/ClientDataTransformer.java deleted file mode 100644 index eaf3e1e6b9d1..000000000000 --- a/h2o-core/src/main/java/hex/schemas/ClientDataTransformer.java +++ /dev/null @@ -1,14 +0,0 @@ -package hex.schemas; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import water.Key; -import water.fvec.Frame; - -public class ClientDataTransformer extends DataTransformer { - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - throw new UnsupportedOperationException("this transformer is for client rendering only and does not support `transform`"); - } -} diff --git a/h2o-core/src/main/java/hex/schemas/DataTransformerV3.java b/h2o-core/src/main/java/hex/schemas/DataTransformerV3.java deleted file mode 100644 index f7df5d79e158..000000000000 --- a/h2o-core/src/main/java/hex/schemas/DataTransformerV3.java +++ /dev/null @@ -1,24 +0,0 @@ -package hex.schemas; - -import hex.pipeline.DataTransformer; -import water.api.API; -import water.api.schemas3.KeyV3; -import water.api.schemas3.SchemaV3; - -public class DataTransformerV3> extends SchemaV3 { - - @API(help="Transformer key", 
direction=API.Direction.INOUT) - public KeyV3.DataTransformerKeyV3 key; - - @API(help="Transformer name (must be unique in the pipeline)", direction=API.Direction.OUTPUT) - public String name; - - @API(help="A short description of this transformer", direction=API.Direction.OUTPUT) - public String description; - - @Override - public D createImpl() { - // later we can create more specific transformers if we need to expose the internals. - return (D) new ClientDataTransformer(); - } -} diff --git a/h2o-core/src/main/java/hex/schemas/GridSearchSchema.java b/h2o-core/src/main/java/hex/schemas/GridSearchSchema.java index 812a6da0180f..21dcf6facfb3 100644 --- a/h2o-core/src/main/java/hex/schemas/GridSearchSchema.java +++ b/h2o-core/src/main/java/hex/schemas/GridSearchSchema.java @@ -15,7 +15,7 @@ import java.util.*; -import static hex.grid.HyperSpaceWalker.SUBSPACES; +import static hex.grid.HyperSpaceWalker.BaseWalker.SUBSPACES; import static water.api.API.Direction.INOUT; import static water.api.API.Direction.INPUT; diff --git a/h2o-core/src/main/java/hex/schemas/PipelineModelV3.java b/h2o-core/src/main/java/hex/schemas/PipelineModelV3.java deleted file mode 100644 index 1661b79f5ba9..000000000000 --- a/h2o-core/src/main/java/hex/schemas/PipelineModelV3.java +++ /dev/null @@ -1,40 +0,0 @@ -package hex.schemas; - -import hex.pipeline.PipelineModel; -import hex.pipeline.PipelineModel.PipelineOutput; -import hex.pipeline.PipelineModel.PipelineParameters; -import water.api.API; -import water.api.schemas3.KeyV3; -import water.api.schemas3.ModelOutputSchemaV3; -import water.api.schemas3.ModelSchemaV3; - -public class PipelineModelV3 extends ModelSchemaV3< - PipelineModel, PipelineModelV3, - PipelineParameters, PipelineV3.PipelineParametersV3, - PipelineOutput, PipelineModelV3.PipelineModelOutputV3 - > { - - @Override - public PipelineV3.PipelineParametersV3 createParametersSchema() { - return new PipelineV3.PipelineParametersV3(); - } - - @Override - public PipelineModelOutputV3 createOutputSchema() { - return new PipelineModelOutputV3(); - } - - public static final class PipelineModelOutputV3 extends ModelOutputSchemaV3 { - - @API(help="Sequence of transformers applied to input data.", direction = API.Direction.OUTPUT) - public KeyV3.DataTransformerKeyV3[] transformers; - - @API(help="Estimator model trained and/or applied after transformations.", direction = API.Direction.OUTPUT) - public KeyV3.ModelKeyV3 estimator; - - @Override - public PipelineModelOutputV3 fillFromImpl(PipelineOutput impl) { - return super.fillFromImpl(impl); - } - } -} diff --git a/h2o-core/src/main/java/hex/schemas/PipelineV3.java b/h2o-core/src/main/java/hex/schemas/PipelineV3.java deleted file mode 100644 index ed3796894086..000000000000 --- a/h2o-core/src/main/java/hex/schemas/PipelineV3.java +++ /dev/null @@ -1,20 +0,0 @@ -package hex.schemas; - -import hex.pipeline.Pipeline; -import hex.pipeline.PipelineModel; -import water.api.schemas3.ModelParametersSchemaV3; - -public class PipelineV3 extends ModelBuilderSchema { - - public static final class PipelineParametersV3 extends ModelParametersSchemaV3 { - static public String[] fields = new String[] { - "model_id", -// "training_frame", -// "validation_frame", -// "response_column", -// "nfolds", -// "fold_column", - }; - - } -} diff --git a/h2o-core/src/main/java/water/Checksumable.java b/h2o-core/src/main/java/water/Checksumable.java deleted file mode 100644 index 179e4de03dc6..000000000000 --- a/h2o-core/src/main/java/water/Checksumable.java +++ /dev/null @@ -1,6 +0,0 @@ 
-package water; - -public interface Checksumable { - - long checksum(); -} diff --git a/h2o-core/src/main/java/water/Freezable.java b/h2o-core/src/main/java/water/Freezable.java index 7eb4b522249a..b9f0b86a9c9c 100644 --- a/h2o-core/src/main/java/water/Freezable.java +++ b/h2o-core/src/main/java/water/Freezable.java @@ -111,5 +111,5 @@ public interface Freezable extends Cloneable { T reloadFromBytes(byte [] ary); /** Make clone public, but without the annoying exception. * @return Returns this object cloned. */ - T clone(); + public T clone(); } diff --git a/h2o-core/src/main/java/water/Iced.java b/h2o-core/src/main/java/water/Iced.java index f36f6d87ed89..0c47556460f0 100644 --- a/h2o-core/src/main/java/water/Iced.java +++ b/h2o-core/src/main/java/water/Iced.java @@ -87,14 +87,9 @@ final public byte[] toJsonBytes() { @Override final public int frozenType() { return icer().frozenType(); } /** Clone, without the annoying exception */ @Override public final D clone() { - try { return cloneImpl(); } + try { return (D)super.clone(); } catch( CloneNotSupportedException e ) { throw water.util.Log.throwErr(e); } } - - /** override this if the class requires custom cloning logic */ - protected D cloneImpl() throws CloneNotSupportedException { - return (D)super.clone(); - } /** Copy over cloned instance 'src' over 'this', field by field. */ protected void copyOver( D src ) { icer().copyOver((D)this,src); } diff --git a/h2o-core/src/main/java/water/KeyGen.java b/h2o-core/src/main/java/water/KeyGen.java deleted file mode 100644 index a9fad77e4c6a..000000000000 --- a/h2o-core/src/main/java/water/KeyGen.java +++ /dev/null @@ -1,126 +0,0 @@ -package water; - -import water.util.ArrayUtils; - -import java.util.Arrays; -import java.util.Objects; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -public abstract class KeyGen extends Iced { - - public abstract > Key make(Object... params); - - public static final class ConstantKeyGen extends KeyGen { - private final Key key; - - public ConstantKeyGen(Key key) { - this.key = key; - } - - @Override - public > Key make(Object... params) { - return key; - } - } - - public static final class RandomKeyGen extends KeyGen { - public RandomKeyGen() {} - - @Override - public > Key make(Object... params) { - return Key.make(Key.rand()); - } - } - - public static final class PatternKeyGen extends KeyGen { - - private static final String PIPE = "\\s*\\|\\s*"; - - - private enum Command { - SUBSTITUTE() { - private final Pattern CMD = Pattern.compile("s/(.*?)/(.*?)/?"); - - @Override - boolean matches(String cmd) { - return cmd.startsWith("s/"); - } - - @Override - String apply(String cmd, String str) { - Matcher m = CMD.matcher(cmd); - if (m.matches()) { - return str.replace(m.group(1), m.group(2)); - } - throw new IllegalArgumentException("invalid command `"+cmd+"` for "+name()); - } - }, - ; - - abstract boolean matches(String cmd); - - abstract String apply(String cmd, String str); - } - - private static final String RANDOM_STR = "{rstr}"; //this will be replaced by random string on each new key; - private static final String COUNTER = "{n}"; //this will be replaced by incremented integer on each new key; - - private final String pattern; - - private final String[] commands; - private final AtomicInteger count = new AtomicInteger(0); - - /** - * creates a key generator using the given pattern supporting the following placeholders: - *
- * <ul>
- *   <li>{rstr} -> will be replaced by a randomly generated string</li>
- *   <li>{n} -> will be replaced by an integer incremented at each call to {@link #make(Object...)}</li>
- *   <li>{0}, {1}, {2}, ... -> dynamically replaced by positional parameters of {@link #make(Object...)}</li>
- *   <li>piped commands applied after the pattern:
- *     <ul>
- *       <li>{0}_suffix | s/foo/bar/ -> adds a suffix to the first {@link #make(Object...)} param and then substitutes occurrences of "foo" with "bar"</li>
- *     </ul>
- *   </li>
- * </ul>
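- * <p>e.g., a hypothetical illustration (names invented) of the placeholders and piped command above:</p>
- * <pre>{@code
- *   KeyGen kg = new KeyGen.PatternKeyGen("{0}_fold_{n} | s/train/cv/");
- *   Key<Frame> k1 = kg.make("train_frame");  // -> Key "cv_frame_fold_1"
- *   Key<Frame> k2 = kg.make("train_frame");  // -> Key "cv_frame_fold_2" (counter is per-generator)
- * }</pre>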
- * - * @param pattern - */ - public PatternKeyGen(String pattern) { - String[] tokens = pattern.split(PIPE); - this.pattern = tokens[0]; - this.commands = ArrayUtils.remove(tokens, 0); - } - - public > Key make(Object... params) { - String keyStr = pattern; - for (int i = 0; i < params.length; i++) { - keyStr = keyStr.replace("{"+i+"}", Objects.toString(params[i])); - } - keyStr = keyStr - .replace(RANDOM_STR, Key.rand()) - .replace(COUNTER, Integer.toString(count.incrementAndGet())); - for (String cmd : commands) { - keyStr = applyCommand(cmd, keyStr); - } - return Key.make(keyStr); - } - - private String applyCommand(String cmd, String str) { - for (Command c : Command.values()) { - if (c.matches(cmd)) - return c.apply(cmd, str); - } - throw new IllegalArgumentException("Invalid command: "+cmd); - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("PatternKeyGen{"); - sb.append(", pattern='").append(pattern).append('\''); - sb.append(", commands=").append(Arrays.toString(commands)); - sb.append(", count=").append(count); - sb.append('}'); - return sb.toString(); - } - } -} diff --git a/h2o-core/src/main/java/water/KeySnapshot.java b/h2o-core/src/main/java/water/KeySnapshot.java index d4979f0e0883..a03725fedb75 100644 --- a/h2o-core/src/main/java/water/KeySnapshot.java +++ b/h2o-core/src/main/java/water/KeySnapshot.java @@ -19,11 +19,10 @@ */ public class KeySnapshot { /** Class to filter keys from the snapshot. */ - @FunctionalInterface - public interface KVFilter { + public abstract static class KVFilter { /** @param k KeyInfo to be filtered * @return true if the key should be included in the new (filtered) set. */ - boolean filter(KeyInfo k); + public abstract boolean filter(KeyInfo k); } /** Class containing information about user keys. @@ -66,12 +65,6 @@ public KeySnapshot filter(KVFilter kvf){ if(kvf.filter(kinfo))res.add(kinfo); return new KeySnapshot(res.toArray(new KeyInfo[res.size()])); } - - public Key findFirst(KVFilter kvf) { - for (KeyInfo ki : _keyInfos) - if (kvf.filter(ki)) return ki._key; - return null; - } KeySnapshot(KeyInfo[] snapshot){ _keyInfos = snapshot; diff --git a/h2o-core/src/main/java/water/Keyed.java b/h2o-core/src/main/java/water/Keyed.java index 12899a68725e..ec68c96226a9 100644 --- a/h2o-core/src/main/java/water/Keyed.java +++ b/h2o-core/src/main/java/water/Keyed.java @@ -5,7 +5,7 @@ import water.util.Log; /** Iced, with a Key. Support for DKV removal. */ -public abstract class Keyed extends Iced implements Checksumable { +public abstract class Keyed extends Iced { /** Key mapping a Value which holds this object; may be null */ public Key _key; public Keyed() { _key = null; } // NOTE: every Keyed that can come out of the REST API has to have a no-arg constructor. diff --git a/h2o-core/src/main/java/water/Scope.java b/h2o-core/src/main/java/water/Scope.java index 3bfc16c78275..32a4b398c53f 100644 --- a/h2o-core/src/main/java/water/Scope.java +++ b/h2o-core/src/main/java/water/Scope.java @@ -199,7 +199,6 @@ private static void track_impl(Level level, Key key) { * Use {@link #untrack(Frame...)} is you need a behaviour symmetrical to {@link #track(Frame...)}. * @param keys */ - @SafeVarargs public static void untrack(K... keys) { if (keys.length == 0) return; Level level = lget(); // Pay the price of T.L.S. 
lookup diff --git a/h2o-core/src/main/java/water/ScopeInspect.java b/h2o-core/src/main/java/water/ScopeInspect.java index 9404b2411f96..76552d7797eb 100644 --- a/h2o-core/src/main/java/water/ScopeInspect.java +++ b/h2o-core/src/main/java/water/ScopeInspect.java @@ -106,8 +106,7 @@ private static StringBuilder appendKey(StringBuilder sb, Key key, int numIndent, if (fr != null) { for (int i=0; i vk = fr.keys()[i]; - String name = fr.name(i); - appendKey(sb, vk, numIndent+1, i+":'"+name+"'", true, keyFilter); + appendKey(sb, vk, numIndent+1, "vec_"+i, true, keyFilter); } } } diff --git a/h2o-core/src/main/java/water/TypeMap.java b/h2o-core/src/main/java/water/TypeMap.java index a7c7ff5e493d..64d755dcc33c 100644 --- a/h2o-core/src/main/java/water/TypeMap.java +++ b/h2o-core/src/main/java/water/TypeMap.java @@ -71,7 +71,7 @@ public class TypeMap { }; // Class name -> ID mapping - private static final NonBlockingHashMap MAP = new NonBlockingHashMap<>(); + static private final NonBlockingHashMap MAP = new NonBlockingHashMap<>(); // ID -> Class name mapping static String[] CLAZZES; // ID -> pre-allocated Golden Instance of Icer @@ -139,7 +139,7 @@ public static String[] bootstrapClasses() { // new code: leader sets string->ID mapping // printing: id -> string // deserial: id -> string -> Icer -> Iced (slow path) - // deserial: id -> Icer -> Iced (fast path) + // deserial: id -> Icer -> Iced (fath path) // lookup : id -> string (on leader) // diff --git a/h2o-core/src/main/java/water/Weaver.java b/h2o-core/src/main/java/water/Weaver.java index 31e9524d6209..ec8b1c7eb252 100644 --- a/h2o-core/src/main/java/water/Weaver.java +++ b/h2o-core/src/main/java/water/Weaver.java @@ -61,7 +61,6 @@ public static > Icer genDelegate( int id, Class cla } catch( InvocationTargetException | InstantiationException | IllegalAccessException | NotFoundException | CannotCompileException | NoSuchFieldException | ClassNotFoundException e) { - H2O.fail("Fatal error on serialization", e); throw new RuntimeException(e); } } diff --git a/h2o-core/src/main/java/water/api/GridSearchHandler.java b/h2o-core/src/main/java/water/api/GridSearchHandler.java index 7b143a213b8c..26dc6b40f7f9 100644 --- a/h2o-core/src/main/java/water/api/GridSearchHandler.java +++ b/h2o-core/src/main/java/water/api/GridSearchHandler.java @@ -7,13 +7,14 @@ import hex.grid.Grid; import hex.grid.GridSearch; import hex.grid.HyperSpaceSearchCriteria; -import static hex.grid.HyperSpaceWalker.SUBSPACES; +import static hex.grid.HyperSpaceWalker.BaseWalker.SUBSPACES; import hex.schemas.*; import water.H2O; import water.Job; import water.Key; import water.TypeMap; import water.api.schemas3.JobV3; +import water.api.schemas3.KeyV3; import water.api.schemas3.ModelParametersSchemaV3; import water.exceptions.H2OIllegalArgumentException; import water.util.IcedHashMap; @@ -67,7 +68,7 @@ private S resumeGrid(String algoURLName, Properties parms) { Recovery recovery = getRecovery(gss); Job gsJob = GridSearch.resumeGridSearch( jobKey, grid, - new SchemaModelParametersBuilderFactory(), + new DefaultModelParametersBuilderFactory(), recovery ); gss.hyper_parameters = null; @@ -138,7 +139,7 @@ private S trainGrid(String algoURLName, Properties parms) { destKey, params, sortedMap, - new SchemaModelParametersBuilderFactory(), + new DefaultModelParametersBuilderFactory(), (HyperSpaceSearchCriteria) gss.search_criteria.createAndFillImpl(), recovery, GridSearch.getParallelismLevel(gss.parallelism) @@ -204,7 +205,7 @@ private Recovery getRecovery(GridSearchSchema gss) { } } - public 
static class SchemaModelParametersBuilderFactory + public static class DefaultModelParametersBuilderFactory implements ModelParametersBuilderFactory { @Override @@ -214,7 +215,7 @@ public ModelParametersBuilder get(MP initialParams) { @Override public PojoUtils.FieldNaming getFieldNamingStrategy() { - return ModelParametersFromSchemaBuilder.NAMING; + return PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES; } } @@ -229,8 +230,6 @@ public PojoUtils.FieldNaming getFieldNamingStrategy() { */ public static class ModelParametersFromSchemaBuilder implements ModelParametersBuilderFactory.ModelParametersBuilder { - - private final static PojoUtils.FieldNaming NAMING = PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES; final private MP params; final private PS paramsSchema; @@ -242,11 +241,6 @@ public ModelParametersFromSchemaBuilder(MP initialParams) { fields = new ArrayList<>(7); } - @Override - public boolean isAssignable(String name) { - return params.isParameterAssignable(NAMING.toDest(name)); - } - public ModelParametersFromSchemaBuilder set(String name, Object value) { try { Field f = paramsSchema.getClass().getField(name); @@ -264,7 +258,7 @@ public ModelParametersFromSchemaBuilder set(String name, Object value) { public MP build() { PojoUtils - .copyProperties(params, paramsSchema, NAMING, null, + .copyProperties(params, paramsSchema, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null, fields.toArray(new String[fields.size()])); // FIXME: handle these train/valid fields in different way // See: ModelParametersSchemaV3#fillImpl diff --git a/h2o-core/src/main/java/water/api/ModelBuilderHandlerUtils.java b/h2o-core/src/main/java/water/api/ModelBuilderHandlerUtils.java index 89a47fee34d5..19155f083999 100644 --- a/h2o-core/src/main/java/water/api/ModelBuilderHandlerUtils.java +++ b/h2o-core/src/main/java/water/api/ModelBuilderHandlerUtils.java @@ -15,7 +15,6 @@ static , P extends M ) { String algoName = ModelBuilder.algoName(algoURLName); // gbm -> GBM; deeplearning -> DeepLearning String schemaDir = ModelBuilder.schemaDirectory(algoURLName); - if (schemaDir == null) return null; // this builder doesn't have any schema // Build a Model Schema and a ModelParameters Schema String schemaName = schemaDir + algoName + "V" + version; diff --git a/h2o-core/src/main/java/water/api/ModelBuildersHandler.java b/h2o-core/src/main/java/water/api/ModelBuildersHandler.java index eeb65c0e349c..876cb5498768 100644 --- a/h2o-core/src/main/java/water/api/ModelBuildersHandler.java +++ b/h2o-core/src/main/java/water/api/ModelBuildersHandler.java @@ -22,10 +22,8 @@ class ModelBuildersHandler extends Handler { public ModelBuildersV3 list(int version, ModelBuildersV3 m) { m.model_builders = new ModelBuilderSchema.IcedHashMapStringModelBuilderSchema(); for( String algo : ModelBuilder.algos() ) { - ModelBuilderSchema schema = makeSchema(algo, version); - if (schema != null) { - m.model_builders.put(algo.toLowerCase(), schema); - } + ModelBuilder builder = ModelBuilder.make(algo, null, null); + m.model_builders.put(algo.toLowerCase(), (ModelBuilderSchema)SchemaServer.schema(version, builder).fillFromImpl(builder)); } return m; } @@ -34,9 +32,8 @@ public ModelBuildersV3 list(int version, ModelBuildersV3 m) { @SuppressWarnings("unused") // called through reflection by RequestServer public ModelBuildersV3 fetch(int version, ModelBuildersV3 m) { m.model_builders = new ModelBuilderSchema.IcedHashMapStringModelBuilderSchema(); - ModelBuilderSchema schema = makeSchema(m.algo, version); - if (schema != null) - 
m.model_builders.put(m.algo.toLowerCase(), schema); + ModelBuilder builder = ModelBuilder.make(m.algo, null, null); + m.model_builders.put(m.algo.toLowerCase(), (ModelBuilderSchema)SchemaServer.schema(version, builder).fillFromImpl(builder)); return m; } @@ -74,11 +71,6 @@ public ModelsInfoV4 modelsInfo(int version, ListRequestV4 m) { return res; } - private ModelBuilderSchema makeSchema(String algo, int version) { - ModelBuilder builder = ModelBuilder.make(algo, null, null); - if (ModelBuilder.schemaDirectory(builder.getName()) == null) return null; // this builder disabled schema - return (ModelBuilderSchema)SchemaServer.schema(version, builder).fillFromImpl(builder); // use ModelBuilderHandlerUtils.makeBuilderSchema instead? - } private String detectMojoVersion(ModelBuilder builder) { Class modelClass = ReflectionUtils.findActualClassParameter(builder.getClass(), 0); diff --git a/h2o-core/src/main/java/water/api/PipelineHandler.java b/h2o-core/src/main/java/water/api/PipelineHandler.java deleted file mode 100644 index 8f5020759a59..000000000000 --- a/h2o-core/src/main/java/water/api/PipelineHandler.java +++ /dev/null @@ -1,11 +0,0 @@ -package water.api; - -import hex.pipeline.DataTransformer; -import hex.schemas.DataTransformerV3; - -public class PipelineHandler extends Handler { - - public DataTransformerV3 fetchTransformer(int version, DataTransformerV3 schema) { - return (DataTransformerV3) schema.fillFromImpl(getFromDKV("datatransformer_id", schema.key.key(), DataTransformer.class)); - } -} diff --git a/h2o-core/src/main/java/water/api/schemas3/KeyV3.java b/h2o-core/src/main/java/water/api/schemas3/KeyV3.java index c2b738ed0aef..7d3624cfa695 100644 --- a/h2o-core/src/main/java/water/api/schemas3/KeyV3.java +++ b/h2o-core/src/main/java/water/api/schemas3/KeyV3.java @@ -2,7 +2,6 @@ import hex.Model; import hex.PartialDependence; -import hex.pipeline.DataTransformer; import hex.segments.SegmentModels; import hex.grid.Grid; import water.*; @@ -125,11 +124,6 @@ public DecryptionToolKeyV3() {} public DecryptionToolKeyV3(Key key) { super(key); } } - public static class DataTransformerKeyV3 extends KeyV3, T> { - public DataTransformerKeyV3() {} - public DataTransformerKeyV3(Key key) { super(key); } - } - @Override public S fillFromImpl(Iced i) { if (! (i instanceof Key)) throw new H2OIllegalArgumentException("fillFromImpl", "key", i); diff --git a/h2o-core/src/main/java/water/util/Checksum.java b/h2o-core/src/main/java/water/util/Checksum.java deleted file mode 100644 index 2715e40b1e61..000000000000 --- a/h2o-core/src/main/java/water/util/Checksum.java +++ /dev/null @@ -1,107 +0,0 @@ -package water.util; - -import water.Checksumable; -import water.H2O; -import water.Weaver; - -import java.lang.reflect.Field; -import java.util.Arrays; -import java.util.Comparator; -import java.util.Set; - -public final class Checksum { - - private Checksum() {} - - public static long checksum(final T obj) { - return checksum(obj, null); - } - - public static long checksum(final T obj, final Set ignoredFields) { - return checksum(obj, ignoredFields, 0x600DL); - } - - /** - * Compute a checksum based on all non-transient non-static ice-able assignable fields (incl. inherited ones). - * Sort the fields first, since reflection gives us the fields in random order, and we don't want the checksum to be affected by the field order. - * - * NOTE: if a field is added to a class the checksum will differ even when all the previous parameters have the same value. 
If - * a client wants backward compatibility they will need to compare values explicitly. - * - * The method is motivated by standard hash implementation `hash = hash * P + value` but we use high prime numbers in random order. - * - * @param ignoredFields A {@link Set} of fields to ignore. Can be empty or null. - * @return checksum A 64-bit long representing the checksum of the object - */ - public static long checksum(final T obj, final Set ignoredFields, final long initVal) { - assert obj != null; - long xs = initVal; - int count = 0; - Field[] fields = Weaver.getWovenFields(obj.getClass()); - Arrays.sort(fields, Comparator.comparing(Field::getName)); - for (Field f : fields) { - if (ignoredFields != null && ignoredFields.contains(f.getName())) { - // Do not include ignored fields in the final hash - continue; - } - final long P = MathUtils.PRIMES[count % MathUtils.PRIMES.length]; - Class c = f.getType(); - Object fvalue; - try { - f.setAccessible(true); - fvalue = f.get(obj); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); - } - if (c.isArray()) { - try { - if (fvalue != null) { - if (c.getComponentType() == Integer.TYPE){ - int[] arr = (int[]) fvalue; - xs = xs * P + (long) Arrays.hashCode(arr); - } else if (c.getComponentType() == Float.TYPE) { - float[] arr = (float[]) fvalue; - xs = xs * P + (long) Arrays.hashCode(arr); - } else if (c.getComponentType() == Double.TYPE) { - double[] arr = (double[]) fvalue; - xs = xs * P + (long) Arrays.hashCode(arr); - } else if (c.getComponentType() == Long.TYPE){ - long[] arr = (long[]) fvalue; - xs = xs * P + (long) Arrays.hashCode(arr); - } else if (c.getComponentType() == Boolean.TYPE){ - boolean[] arr = (boolean[]) fvalue; - xs = xs * P + (long) Arrays.hashCode(arr); - } else { - Object[] arr = (Object[]) fvalue; - if (Checksumable.class.isAssignableFrom(arr.getClass().getComponentType())) { - for (Checksumable cs : (Checksumable[])arr) { - xs = xs * P + cs.checksum(); - } - } else { - xs = xs * P + (long) Arrays.deepHashCode(arr); - } - } //else lead to ClassCastException - } else { - xs = xs * P; - } - } catch (ClassCastException t) { - throw H2O.fail("Failed to calculate checksum for the parameter object", t); //no support yet for int[][] etc. 
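- // The scalar branch below applies the same `xs = xs * P + h(value)` recurrence
- // (unrolled over fields f1, f2: xs = (0x600D * P1 + h(f1)) * P2 + h(f2)):
- // enums hash via their name string (stable across JVM runs), Checksumable fields
- // recurse into their own checksum(), and a null field folds in the prime itself.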
- } - } else { - if (fvalue instanceof Enum) { - // use string hashcode for enums, otherwise the checksum would be different each run - xs = xs * P + (long) (fvalue.toString().hashCode()); - } else if (fvalue instanceof Checksumable) { - xs = xs * P + ((Checksumable) fvalue).checksum(); - } else if (fvalue != null) { - xs = xs * P + (long)(fvalue.hashCode()); - } else { - xs = xs * P + P; - } - } - count++; - } - return xs; - } - -} diff --git a/h2o-core/src/main/java/water/util/FrameUtils.java b/h2o-core/src/main/java/water/util/FrameUtils.java index 82d9583a3c0a..5924b31b7f66 100644 --- a/h2o-core/src/main/java/water/util/FrameUtils.java +++ b/h2o-core/src/main/java/water/util/FrameUtils.java @@ -1091,12 +1091,12 @@ public Job exec() { } } - public static void cleanUp(Collection toDelete) { + static public void cleanUp(IcedHashMap toDelete) { if (toDelete == null) { return; } Futures fs = new Futures(); - for (Key k : toDelete) { + for (Key k : toDelete.keySet()) { Keyed.remove(k, fs, true); } fs.blockForPending(); diff --git a/h2o-core/src/main/java/water/util/IcedHashMapBase.java b/h2o-core/src/main/java/water/util/IcedHashMapBase.java index e0b0220addcf..767f2a177e1e 100644 --- a/h2o-core/src/main/java/water/util/IcedHashMapBase.java +++ b/h2o-core/src/main/java/water/util/IcedHashMapBase.java @@ -1,10 +1,16 @@ package water.util; -import water.*; +import water.AutoBuffer; +import water.Freezable; +import water.H2O; +import water.Iced; import java.io.Serializable; import java.lang.reflect.Array; -import java.util.*; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.Set; import java.util.stream.Stream; import static org.apache.commons.lang.ArrayUtils.toObject; @@ -16,20 +22,13 @@ public abstract class IcedHashMapBase extends Iced implements Map, Cloneable, Serializable { public enum KeyType { - Long(Long.class, java.lang.Long.MIN_VALUE), String(String.class), Freezable(Freezable.class), ; Class _clazz; - Object _endValue; KeyType(Class clazz) { - this(clazz, null); - } - - KeyType(Class clazz, Object endValue) { _clazz = clazz; - _endValue = endValue; } } @@ -144,12 +143,9 @@ public final AutoBuffer write_impl( AutoBuffer ab ) { byte mode = getMode(getKeyType(key), getValueType(val), getArrayType(val)); ab.put1(mode); // Type of hashmap being serialized writeMap(ab, mode); // Do the hard work of writing the map - KeyType kt = keyType(mode); - switch (kt) { // finally write null to indicate termination - case Long: - return ab.put8((long)kt._endValue); + switch (keyType(mode)) { case String: - return ab.putStr((String)kt._endValue); + return ab.putStr(null); case Freezable: default: return ab.put(null); @@ -191,7 +187,6 @@ protected void writeMap(AutoBuffer ab, byte mode) { protected void writeKey(AutoBuffer ab, KeyType keyType, K key) { switch (keyType) { - case Long: ab.put8((Long)key); break; case String: ab.putStr((String)key); break; case Freezable: ab.put((Freezable)key); break; } @@ -236,8 +231,7 @@ protected void writeValue(AutoBuffer ab, ValueType valueType, ArrayType arrayTyp @SuppressWarnings("unchecked") protected K readKey(AutoBuffer ab, KeyType keyType) { switch (keyType) { - case Long: return (K)(Long) ab.get8(); - case String: return (K) ab.getStr(); + case String: return (K) ab.getStr(); case Freezable: return ab.get(); default: return null; } @@ -298,7 +292,7 @@ public final IcedHashMapBase read_impl(AutoBuffer ab) { while (true) { K key = readKey(ab, keyType); - if (Objects.equals(key, keyType._endValue)) break; + if (key 
== null) break; V val = readValue(ab, valueType, arrayType); map.put(key, val); } @@ -323,12 +317,13 @@ public final AutoBuffer writeJSON_impl( AutoBuffer ab ) { K key = entry.getKey(); V value = entry.getValue(); - assert key != null; + KeyType keyType = getKeyType(key); + assert keyType == KeyType.String: "JSON format supports only String keys"; ValueType valueType = getValueType(value); ArrayType arrayType = getArrayType(value); if (first) { first = false; } else {ab.put1(',').put1(' '); } - String name = key.toString(); + String name = (String) key; switch (arrayType) { case None: switch (valueType) { diff --git a/h2o-core/src/main/java/water/util/IcedLong.java b/h2o-core/src/main/java/water/util/IcedLong.java index 80fe2813c31b..c0fece57df63 100644 --- a/h2o-core/src/main/java/water/util/IcedLong.java +++ b/h2o-core/src/main/java/water/util/IcedLong.java @@ -1,6 +1,9 @@ package water.util; -import water.*; +import water.H2O; +import water.Iced; +import water.Key; +import water.TAtomic; public class IcedLong extends Iced { public long _val; @@ -16,18 +19,10 @@ public class IcedLong extends Iced { public static IcedLong valueOf(long value) { return new IcedLong(value); } - - public static long get(Key key) { - return ((IcedLong) DKV.getGet(key))._val; - } public static long incrementAndGet(Key key) { return ((AtomicIncrementAndGet) new AtomicIncrementAndGet().invoke(key))._val; } - - public static long decrementAndGet(Key key) { - return ((AtomicDecrementAndGet) new AtomicDecrementAndGet().invoke(key))._val; - } public static class AtomicIncrementAndGet extends TAtomic { public AtomicIncrementAndGet() { @@ -46,22 +41,4 @@ protected IcedLong atomic(IcedLong old) { } } - public static class AtomicDecrementAndGet extends TAtomic { - public AtomicDecrementAndGet() { - this(null); - } - public AtomicDecrementAndGet(H2O.H2OCountedCompleter cc) { - super(cc); - } - - // OUT - public long _val; - - @Override - protected IcedLong atomic(IcedLong old) { - return new IcedLong(_val = old._val - 1); - } - } - - } diff --git a/h2o-core/src/main/java/water/util/PojoUtils.java b/h2o-core/src/main/java/water/util/PojoUtils.java index 1f250dcee47a..b06b8dd41aa0 100644 --- a/h2o-core/src/main/java/water/util/PojoUtils.java +++ b/h2o-core/src/main/java/water/util/PojoUtils.java @@ -21,16 +21,16 @@ public class PojoUtils { public enum FieldNaming { CONSISTENT { - @Override public String toDest(String origin) { return origin; } - @Override public String toOrigin(String dest) { return dest; } + @Override String toDest(String origin) { return origin; } + @Override String toOrigin(String dest) { return dest; } }, DEST_HAS_UNDERSCORES { - @Override public String toDest(String origin) { return "_" + origin; } - @Override public String toOrigin(String dest) { return dest.substring(1); } + @Override String toDest(String origin) { return "_" + origin; } + @Override String toOrigin(String dest) { return dest.substring(1); } }, ORIGIN_HAS_UNDERSCORES { - @Override public String toDest(String origin) { return origin.substring(1); } - @Override public String toOrigin(String dest) { return "_" + dest; } + @Override String toDest(String origin) { return origin.substring(1); } + @Override String toOrigin(String dest) { return "_" + dest; } }; /** @@ -38,14 +38,14 @@ public enum FieldNaming { * @param origin name of origin argument * @return return a name of destination argument. 
*/ - public abstract String toDest(String origin); + abstract String toDest(String origin); /** * Return name of origin parameter derived from name of origin parameter. * @param dest name of destination argument. * @return return a name of origin argument. */ - public abstract String toOrigin(String dest); + abstract String toOrigin(String dest); } @@ -625,28 +625,16 @@ public static Field getFieldEvenInherited(Object o, String name) throws NoSuchFi * @throws java.lang.IllegalArgumentException when o is null, or field is not found, * or field cannot be read. */ - public static Object getFieldValue(Object o, String name) { - return getFieldValue(o, name, false); - } - - public static Object getFieldValue(Object o, String name, boolean anyVisibility) { + public static Object getFieldValue(Object o, String name, FieldNaming fieldNaming) { if (o == null) throw new IllegalArgumentException("Cannot get the field from null object!"); + String destName = fieldNaming.toDest(name); try { - Field f = PojoUtils.getFieldEvenInherited(o, name); // failing with fields declared in superclasses - if (anyVisibility) f.setAccessible(true); + Field f = PojoUtils.getFieldEvenInherited(o, destName); // failing with fields declared in superclasses return f.get(o); } catch (NoSuchFieldException e) { - throw new IllegalArgumentException("Field not found: '"+name+"' on object " + o); + throw new IllegalArgumentException("Field not found: '" + name + "/" + destName + "' on object " + o); } catch (IllegalAccessException e) { - throw new IllegalArgumentException("Cannot get value of the field: '"+name+"' on object " + o); - } - } - public static Object getFieldValue(Object o, String name, FieldNaming fieldNaming) { - String dest = fieldNaming.toDest(name); - try { - return getFieldValue(o, dest); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException(e.getMessage().replace("'"+dest+"'", "'"+name+"/"+dest+"'")); + throw new IllegalArgumentException("Cannot get value of the field: '" + name + "/" + destName + "' on object " + o); } } @@ -728,6 +716,32 @@ private static T fillFromMap(T o, Map setFields) { return o; } + /** + * Helper for Arrays.equals(). + */ + public static boolean arraysEquals(Object a, Object b) { + if (a == null || ! a.getClass().isArray()) + throw new H2OIllegalArgumentException("a", "arraysEquals", a); + if (b == null || ! 
b.getClass().isArray()) + throw new H2OIllegalArgumentException("b", "arraysEquals", b); + if (a.getClass().getComponentType() != b.getClass().getComponentType()) + throw new H2OIllegalArgumentException("Can't compare arrays of different types: " + a.getClass().getComponentType() + " and: " + b.getClass().getComponentType()); + + if (a.getClass().getComponentType() == boolean.class) return Arrays.equals((boolean[])a, (boolean[])b); + if (a.getClass().getComponentType() == Boolean.class) return Arrays.equals((Boolean[])a, (Boolean[])b); + + if (a.getClass().getComponentType() == char.class) return Arrays.equals((char[])a, (char[])b); + if (a.getClass().getComponentType() == short.class) return Arrays.equals((short[])a, (short[])b); + if (a.getClass().getComponentType() == Short.class) return Arrays.equals((Short[])a, (Short[])b); + if (a.getClass().getComponentType() == int.class) return Arrays.equals((int[])a, (int[])b); + if (a.getClass().getComponentType() == Integer.class) return Arrays.equals((Integer[])a, (Integer[])b); + if (a.getClass().getComponentType() == float.class) return Arrays.equals((float[])a, (float[])b); + if (a.getClass().getComponentType() == Float.class) return Arrays.equals((Float[])a, (Float[])b); + if (a.getClass().getComponentType() == double.class) return Arrays.equals((double[])a, (double[])b); + if (a.getClass().getComponentType() == Double.class) return Arrays.equals((Double[])a, (Double[])b); + return Arrays.deepEquals((Object[])a, (Object[])b); + } + public static String toJavaDoubleArray(double[] array) { if (array == null) { return "null"; diff --git a/h2o-core/src/main/java/water/util/TwoDimTable.java b/h2o-core/src/main/java/water/util/TwoDimTable.java index 7c751998817a..26de71836210 100644 --- a/h2o-core/src/main/java/water/util/TwoDimTable.java +++ b/h2o-core/src/main/java/water/util/TwoDimTable.java @@ -373,20 +373,11 @@ public String toString(final int pad, boolean full) { return sb.toString(); } public Frame asFrame(Key frameKey) { - return asFrame(frameKey, false); - } - - public Frame asFrame(Key frameKey, boolean withRowHeaders) { String[] colNames = new String[getColDim()]; System.arraycopy(getColHeaders(), 0, colNames, 0, getColDim()); - int colOffset = 0; - if (withRowHeaders) { - colNames = ArrayUtils.append(new String[]{"_"}, colNames); - colOffset++; - } Vec[] vecs = new Vec[colNames.length]; - if (withRowHeaders) vecs[0] = Vec.makeVec(getRowHeaders(), Vec.newKey()); + vecs[0] = Vec.makeVec(getRowHeaders(), Vec.newKey()); for (int j = 0; j < this.getColDim(); j++) { switch (getColTypes()[j]){ @@ -395,7 +386,7 @@ public Frame asFrame(Key frameKey, boolean withRowHeaders) { for (int i = 0; i < getRowDim(); i++) { strRow[i] = (String) get(i, j); } - vecs[j+colOffset] = Vec.makeVec(strRow, Vec.newKey()); + vecs[j] = Vec.makeVec(strRow, Vec.newKey()); break; case "int": case "long": @@ -403,7 +394,7 @@ public Frame asFrame(Key frameKey, boolean withRowHeaders) { for (int i = 0; i < getRowDim(); i++) { longRow[i] = ((Number) get(i, j)).longValue(); } - vecs[j+colOffset] = Vec.makeVec(longRow, Vec.newKey()); + vecs[j] = Vec.makeVec(longRow, Vec.newKey()); break; case "float": case "double": @@ -411,7 +402,7 @@ public Frame asFrame(Key frameKey, boolean withRowHeaders) { for (int i = 0; i < getRowDim(); i++) { dblRow[i] = (double) get(i, j); } - vecs[j+colOffset] = Vec.makeVec(dblRow, Vec.newKey()); + vecs[j] = Vec.makeVec(dblRow, Vec.newKey()); break; } } diff --git a/h2o-core/src/main/resources/META-INF/services/water.api.RestApiExtension 
b/h2o-core/src/main/resources/META-INF/services/water.api.RestApiExtension index e50525c5d638..f5e9d032750b 100644 --- a/h2o-core/src/main/resources/META-INF/services/water.api.RestApiExtension +++ b/h2o-core/src/main/resources/META-INF/services/water.api.RestApiExtension @@ -1,3 +1,2 @@ water.api.RegisterV3Api -water.api.RegisterV4Api -hex.pipeline.PipelineAlgoRegistration +water.api.RegisterV4Api \ No newline at end of file diff --git a/h2o-core/src/main/resources/META-INF/services/water.api.Schema b/h2o-core/src/main/resources/META-INF/services/water.api.Schema index cae1d52a70af..75286a3d4080 100644 --- a/h2o-core/src/main/resources/META-INF/services/water.api.Schema +++ b/h2o-core/src/main/resources/META-INF/services/water.api.Schema @@ -1,5 +1,4 @@ hex.schemas.ClusteringModelBuilderSchema -hex.schemas.DataTransformerV3 hex.schemas.GridSchemaV99 hex.schemas.GridSearchSchema hex.schemas.HyperSpaceSearchCriteriaV99 @@ -7,10 +6,6 @@ hex.schemas.HyperSpaceSearchCriteriaV99$CartesianSearchCriteriaV99 hex.schemas.HyperSpaceSearchCriteriaV99$RandomDiscreteValueSearchCriteriaV99 hex.schemas.HyperSpaceSearchCriteriaV99$SequentialSearchCriteriaV99 hex.schemas.ModelBuilderSchema -hex.schemas.PipelineV3 -hex.schemas.PipelineModelV3 -hex.schemas.PipelineV3$PipelineParametersV3 -hex.schemas.PipelineModelV3$PipelineModelOutputV3 hex.schemas.QuantileV3 hex.schemas.QuantileV3$QuantileParametersV3 water.api.ModelBuildersHandler$ModelIdV3 @@ -61,7 +56,6 @@ water.api.schemas3.KeyV3$JobKeyV3 water.api.schemas3.KeyV3$ModelKeyV3 water.api.schemas3.KeyV3$PartialDependenceKeyV3 water.api.schemas3.KeyV3$DecryptionToolKeyV3 -water.api.schemas3.KeyV3$DataTransformerKeyV3 water.api.schemas3.KillMinus3V3 water.api.schemas3.LogAndEchoV3 water.api.schemas3.LogsV3 diff --git a/h2o-core/src/test/java/hex/HyperSpaceWalkerTest.java b/h2o-core/src/test/java/hex/HyperSpaceWalkerTest.java index d6878b0d3287..83aa3234542a 100644 --- a/h2o-core/src/test/java/hex/HyperSpaceWalkerTest.java +++ b/h2o-core/src/test/java/hex/HyperSpaceWalkerTest.java @@ -1,13 +1,12 @@ package hex; +import hex.grid.GridSearch; import hex.grid.HyperSpaceSearchCriteria; import hex.grid.HyperSpaceWalker; -import hex.grid.SimpleParametersBuilderFactory; import org.junit.BeforeClass; import org.junit.Test; import water.TestUtil; import water.test.dummy.DummyModelParameters; -import water.util.ReflectionUtils; import java.util.HashMap; import java.util.Map; @@ -17,17 +16,6 @@ public class HyperSpaceWalkerTest extends TestUtil { @BeforeClass public static void stall() { stall_till_cloudsize(1); } static public class DummyXGBoostModelParameters extends DummyModelParameters { - - private static final DummyXGBoostModelParameters DEFAULTS; - - static { - try { - DEFAULTS = DummyXGBoostModelParameters.class.newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - public int _max_depth; public double _min_rows; public double _sample_rate; @@ -38,12 +26,6 @@ static public class DummyXGBoostModelParameters extends DummyModelParameters { public float _reg_alpha; public float _scale_pos_weight; public float _max_delta_step; - - @Override - public Object getParameterDefaultValue(String name) { - // tricking the default logic here as this parameters class is not properly registered, so we can't obtain the defaults the usual way. 
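- // (DEFAULTS is instantiated once in the static block above, so reflection can
- // read each field's compiled-in default value straight off that instance)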
- return ReflectionUtils.getFieldValue(DEFAULTS, name); - } } @@ -66,7 +48,7 @@ public void testRandomDiscreteValueWalkerFinishes() { searchParams.put("_max_delta_step", new Float[]{0f, 5f, 10f}); HyperSpaceWalker.RandomDiscreteValueWalker rdvw = new HyperSpaceWalker.RandomDiscreteValueWalker<>(new DummyXGBoostModelParameters(), - searchParams, new SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria()); + searchParams, new GridSearch.SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria()); HyperSpaceWalker.HyperSpaceIterator hsi = rdvw.iterator(); try { while (hsi.hasNext()) { diff --git a/h2o-core/src/test/java/hex/ModelAdaptTest.java b/h2o-core/src/test/java/hex/ModelAdaptTest.java index 5eb454a36948..0461fe95f989 100644 --- a/h2o-core/src/test/java/hex/ModelAdaptTest.java +++ b/h2o-core/src/test/java/hex/ModelAdaptTest.java @@ -148,7 +148,7 @@ static class AOutput extends Model.Output { } Frame.deleteTempFrameAndItsNonSharedVecs(adapt, tst); tst.remove(); - FrameUtils.cleanUp(am._toDelete.keySet()); + FrameUtils.cleanUp(am._toDelete); } } diff --git a/h2o-core/src/test/java/hex/pipeline/DataTransformerTest.java b/h2o-core/src/test/java/hex/pipeline/DataTransformerTest.java deleted file mode 100644 index a85f758927b9..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/DataTransformerTest.java +++ /dev/null @@ -1,277 +0,0 @@ -package hex.pipeline; - -import org.junit.Test; -import org.junit.runner.RunWith; -import water.*; -import water.fvec.*; -import water.logging.Logger; -import water.logging.LoggerFactory; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.util.ArrayUtils; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.Assert.*; -import static water.TestUtil.*; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class DataTransformerTest { - - @Test - public void test_transform() { - try { - Scope.enter(); - DataTransformer dt = new AddRandomColumnTransformer("foo"); - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3)) - .withDataForCol(1, ard(1, 2, 3)) - .withDataForCol(2, ar("yes", "no", "yes")) - .build()); - - Frame transformed = Scope.track(dt.transform(fr)); - assertEquals(4, transformed.names().length); - assertArrayEquals(new String[]{"one", "two", "target", "foo"}, transformed.names()); - } finally { - Scope.exit(); - } - } - - public static class MultiplyNumericColumnTransformer extends DataTransformer { - - private final String colName; - - private final int multiplier; - - public MultiplyNumericColumnTransformer(String colName, int multiplier) { - this.colName = colName; - this.multiplier = multiplier; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - Frame tr = new Frame(fr); - final int colIdx = tr.find(colName); - Vec col = tr.vec(colIdx); - assert col.isNumeric(); - Vec multCol = new MRTask() { - @Override - public void map(Chunk[] cs, NewChunk[] ncs) { - for (int i = 0; i < cs[0]._len; i++) - ncs[0].addNum(multiplier * (cs[0].atd(i))); - } - }.doAll(Vec.T_NUM, new Frame(col)).outputFrame().vec(0); - tr.replace(colIdx, multCol); - return tr; - } - } - - - public static class AddRandomColumnTransformer extends DataTransformer { - - 
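- // Appends a seeded pseudo-random numeric column: the shallow Frame copy made in
- // doTransform shares the input's Vecs, and only the new column is materialized.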
private final String colName; - private final long seed; - - public AddRandomColumnTransformer(String colName) { - this(colName, 0); - } - - public AddRandomColumnTransformer(String colName, long seed) { - this.colName = colName; - this.seed = seed; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - Frame tr = new Frame(fr); - tr.add(colName, tr.anyVec().makeRand(seed)); - return tr; - } - } - - public static class AddDummyCVColumnTransformer extends DataTransformer { - - private final String colName; - private final int vecType; - - public AddDummyCVColumnTransformer(String colName) { - this(colName, Vec.T_NUM); - } - - public AddDummyCVColumnTransformer(String colName, int vecType) { - this.colName = colName; - this.vecType = vecType; - } - - @Override - public boolean isCVSensitive() { - return true; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - if (type == FrameType.Training && context._params._is_cv_model) { - Frame tr = new Frame(fr); - Vec v = Vec.makeRepSeq(tr.anyVec().length(), context._params._cv_fold+2); - Vec tmp; - switch (vecType) { - case Vec.T_CAT: - tmp = v.toCategoricalVec(); v.remove(); v = tmp; break; - case Vec.T_STR: - tmp = v.toStringVec(); v.remove(); v = tmp; break; - case Vec.T_NUM: - default: - //already numeric by construct - break; - } - tr.add(colName, v); - return tr; - } - return fr; - } - } - - public static class RenameFrameTransformer extends DataTransformer { - - private final String frameName; - - public RenameFrameTransformer(String frameName) { - this.frameName = frameName; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - return new Frame(Key.make(frameName), fr.names(), fr.vecs()); - } - } - - public static class FrameTrackerAsTransformer extends DataTransformer { - - static final Logger logger = LoggerFactory.getLogger(FrameTrackerAsTransformer.class); - - public static class Transformation extends Iced { - final String frameId; - final FrameType type; - final boolean is_cv; - - public Transformation(String frameId, FrameType type, boolean is_cv) { - this.frameId = frameId; - this.type = type; - this.is_cv = is_cv; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("Transformation{"); - sb.append("frameId='").append(frameId).append('\''); - sb.append(", type=").append(type); - sb.append(", is_cv=").append(is_cv); - sb.append('}'); - return sb.toString(); - } - } - - public static class Transformations extends Keyed { - - Transformation[] transformations; - - public Transformations(Key key) { - super(key); - transformations = new Transformation[0]; - } - - public void add(Transformation t) { - new AddTransformation(t).invoke(getKey()); - } - - public int size() { - return transformations.length; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("Transformations{"); - sb.append("transformations=").append(Arrays.toString(transformations)); - sb.append('}'); - return sb.toString(); - } - - private static class AddTransformation extends TAtomic { - - private final Transformation transformation; - - public AddTransformation(Transformation transformation) { - this.transformation = transformation; - } - - @Override - protected Transformations atomic(Transformations old) { - old.transformations = ArrayUtils.append(old.transformations, transformation); - return old; - } - } - } - - private final Key 
transformationsKey; - - public FrameTrackerAsTransformer() { - transformationsKey = Key.make(); - DKV.put(new Transformations(transformationsKey)); - } - - public Transformations getTransformations() { - return transformationsKey.get(); - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - if (fr == null) return null; - logger.info(fr.getKey()+": columns="+Arrays.toString(fr.names())); - boolean is_cv = context != null && context._params._is_cv_model; - transformationsKey.get().add(new Transformation(fr.getKey().toString(), type, is_cv)); - return fr; - } - - @Override - protected Futures remove_impl(Futures fs, boolean cascade) { - Keyed.remove(transformationsKey, fs, cascade); - return super.remove_impl(fs, cascade); - } - } - - @FunctionalInterface - interface FrameChecker extends Serializable { - void check(Frame fr); - } - - public static class FrameCheckerAsTransformer extends DataTransformer { - - FrameChecker checker; - - public FrameCheckerAsTransformer(FrameChecker checker) { - this.checker = checker; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - checker.check(fr); - return fr; - } - } - - public static class FailingAssertionTransformer extends DataTransformer { - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - assert false: "expected"; - return fr; - } - } - -} diff --git a/h2o-core/src/test/java/hex/pipeline/PipelineHelperTest.java b/h2o-core/src/test/java/hex/pipeline/PipelineHelperTest.java deleted file mode 100644 index b0a7327b3ce7..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/PipelineHelperTest.java +++ /dev/null @@ -1,178 +0,0 @@ -package hex.pipeline; - -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.RunWith; -import water.DKV; -import water.Key; -import water.Scope; -import water.fvec.Frame; -import water.fvec.TestFrameBuilder; -import water.fvec.Vec; -import water.junit.rules.DKVIsolation; -import water.junit.rules.ScopeTracker; -import water.runner.CloudSize; -import water.runner.H2ORunner; - -import java.io.Serializable; - -import static hex.pipeline.PipelineHelper.reassign; -import static hex.pipeline.PipelineHelper.reassignInplace; -import static org.junit.Assert.*; -import static water.TestUtil.ard; -import static water.TestUtil.assertFrameEquals; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class PipelineHelperTest { - - private static class ReferenceFrameProvider extends ExternalResource implements Serializable { - private static final Key refKey = Key.make("refFrame"); - - private Frame refFrame; - - @Override - protected void before() throws Throwable { - refFrame = new TestFrameBuilder() - .withName(refKey.toString()) - .withColNames("one", "two") - .withVecTypes(Vec.T_NUM, Vec.T_NUM) - .withDataForCol(0, ard(1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1)) - .build(); - } - - @Override - protected void after() { - assertNotNull(DKV.get(refKey)); - assertFrameEquals(DKV.getGet(refKey), refFrame, 0); // only comparing frame content as in multinode, can't guarantee that these will be the same objects - for (Key k : refFrame.keys()) assertNotNull(DKV.get(k)); - } - - Frame get() { - return refFrame; - } - } - - @Rule - public ScopeTracker scope = new ScopeTracker(); - - @Rule - public ReferenceFrameProvider refFrame = new ReferenceFrameProvider(); - - @Rule - public DKVIsolation isolation = new DKVIsolation(); - - private void 
addRndVec(Frame fr) { - int lockedStatus = fr._lockers == null ? 0 : fr._lockers.length; - if (lockedStatus == 0) fr.write_lock(); - fr.add("rndvec", fr.anyVec().makeRand(0)); - fr.update(); - if (lockedStatus == 0) fr.unlock(); - } - - @Test - public void test_reassign_frame_null_key_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(null, ref.names(), ref.vecs()); - Key copyKey = copy.getKey(); - assertNull(copyKey); - - Key reassigned = Key.make("reassigned"); - Frame cc = reassign(copy, reassigned); - addRndVec(cc); - Scope.track_generic(cc); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNotSame(copy, cc); - assertNull(copy.getKey()); //copy was not assigned any key - assertNotNull(DKV.get(reassigned)); - assertSame(cc, DKV.getGet(reassigned)); - } - - @Test - public void test_reassign_frame_not_in_DKV_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(ref); - Key copyKey = copy.getKey(); - assertNotNull(copyKey); - assertNull(DKV.get(copyKey)); - - Key reassigned = Key.make("reassigned"); - Frame cc = reassign(copy, reassigned); - addRndVec(cc); - Scope.track_generic(cc); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNotSame(copy, cc); - assertNull(DKV.get(copyKey)); - assertEquals(copyKey, copy.getKey()); //copy key was not modified - assertNotNull(DKV.get(reassigned)); - assertSame(cc, DKV.getGet(reassigned)); - } - - @Test - public void test_reassign_frame_in_DKV_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(ref); - Key copyKey = copy.getKey(); - DKV.put(Scope.track_generic(copy)); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNotNull(copyKey); - assertNotNull(DKV.get(copyKey)); - Key reassigned = Key.make("reassigned"); - Frame cc = reassign(copy, reassigned); - addRndVec(cc); - Scope.track_generic(cc); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNotNull(DKV.get(copyKey)); - assertSame(copy, DKV.getGet(copyKey)); // copy still assigned to previous key - assertNotNull(DKV.get(reassigned)); - assertSame(cc, DKV.getGet(reassigned)); - } - - @Test - public void test_reassign_inplace_frame_null_key_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(null, ref.names(), ref.vecs()); - Key copyKey = copy.getKey(); - assertNull(copyKey); - - Key reassigned = Key.make("reassigned"); - reassignInplace(copy, reassigned); - addRndVec(copy); - Scope.track_generic(copy); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNotNull(DKV.get(reassigned)); - assertSame(copy, DKV.getGet(reassigned)); - } - - @Test - public void test_reassign_inplace_frame_not_in_DKV_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(ref); - Key copyKey = copy.getKey(); - assertNotNull(copyKey); - assertNull(DKV.get(copyKey)); - - Key reassigned = Key.make("reassigned"); - reassignInplace(copy, reassigned); - addRndVec(copy); - Scope.track_generic(copy); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNull(DKV.get(copyKey)); - assertNotNull(DKV.get(reassigned)); - assertSame(copy, DKV.getGet(reassigned)); - } - - @Test - public void test_reassign_inplace_frame_in_DKV_with_fresh_key() { - Frame ref = refFrame.get(); - Frame copy = new Frame(ref); - Key copyKey = copy.getKey(); - DKV.put(copy); - 
assertNotNull(copyKey); - assertNotNull(DKV.get(copyKey)); - Key reassigned = Key.make("reassigned"); - reassignInplace(copy, reassigned); - addRndVec(copy); - Scope.track_generic(copy); // tracking the key only (instead of the frame) to better see what can potentially leak - assertNull(DKV.get(copyKey)); - assertNotNull(DKV.get(reassigned)); - assertSame(copy, DKV.getGet(reassigned)); - } - -} diff --git a/h2o-core/src/test/java/hex/pipeline/PipelineTest.java b/h2o-core/src/test/java/hex/pipeline/PipelineTest.java deleted file mode 100644 index ca7723131d51..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/PipelineTest.java +++ /dev/null @@ -1,510 +0,0 @@ -package hex.pipeline; - -import hex.Model; -import hex.ModelBuilder; -import hex.pipeline.DataTransformerTest.*; -import hex.pipeline.DataTransformerTest.FrameTrackerAsTransformer.Transformation; -import hex.pipeline.DataTransformerTest.FrameTrackerAsTransformer.Transformations; -import hex.pipeline.PipelineModel.PipelineOutput; -import hex.pipeline.PipelineModel.PipelineParameters; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.*; -import water.fvec.Frame; -import water.fvec.TestFrameBuilder; -import water.fvec.Vec; -import water.junit.rules.DKVIsolation; -import water.junit.rules.ScopeTracker; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.test.dummy.DummyModel; -import water.test.dummy.DummyModelParameters; -import water.util.ArrayUtils; - -import java.util.Arrays; -import java.util.stream.IntStream; -import java.util.stream.Stream; - -import static org.junit.Assert.*; -import static water.TestUtil.*; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class PipelineTest { - - @Rule - public ScopeTracker scope = new ScopeTracker(); - - @Rule - public DKVIsolation isolation = new DKVIsolation(); - - private void checkFrameState(Frame fr) { - assertNotNull(fr.getKey()); - assertNotNull(DKV.get(fr.getKey())); - assertFrameEquals(fr, DKV.getGet(fr.getKey()), 1e-10); - for (int i=0; i k = fr.keys()[i]; - assertNotNull(k); - assertNotNull(k.get()); - assertVecEquals(fr.vec(i), k.get(), 1e-10); - } - } - - @Test - public void test_simple_transformation_pipeline() { - PipelineParameters pparams = new PipelineParameters(); - FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer(); - pparams.setTransformers( - new MultiplyNumericColumnTransformer("two", 5).name("mult_5"), - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - tracker.name("tracker") - ); - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1)) - .withDataForCol(2, ar("yes", "no", "yes")) - .build()); - - Vec notMult = fr.vec(1).makeCopy(); - Vec mult = fr.vec(1).makeZero(); mult.set(0, 3*5); mult.set(1, 2*5); mult.set(2, 1*5); // a 5-mult vec of column "two" hand-made for test assertions. 
- - pparams._train = fr._key; - - Pipeline pipeline = ModelBuilder.make(pparams); - PipelineModel pmodel = Scope.track_generic(pipeline.trainModel().get()); - - assertNotNull(pmodel); - PipelineOutput output = pmodel._output; - assertNotNull(output); - assertNull(output._estimator); - assertNotNull(output._transformers); - assertEquals(4, output._transformers.length); - assertEquals(0, tracker.getTransformations().size()); - checkFrameState(fr); - assertVecEquals(notMult, fr.vec(1), 0); - - Frame scored = Scope.track(pmodel.score(fr)); - assertNotNull(scored); - TestUtil.printOutFrameAsTable(scored); - assertEquals(1, tracker.getTransformations().size()); - assertArrayEquals(new String[] {"one", "two", "target", "foo", "bar"}, scored.names()); - checkFrameState(fr); - checkFrameState(scored); - assertVecEquals(notMult, fr.vec(1), 0); - assertVecEquals(mult, scored.vec(1), 0); - - Frame rescored = Scope.track(pmodel.score(fr)); - TestUtil.printOutFrameAsTable(rescored); - assertEquals(2, tracker.getTransformations().size()); - assertNotSame(scored, rescored); - assertFrameEquals(scored, rescored, 1.6); - checkFrameState(fr); - checkFrameState(rescored); - assertVecEquals(notMult, fr.vec(1), 0); - assertVecEquals(mult, rescored.vec(1), 0); - - Frame transformed = Scope.track(pmodel.transform(fr)); - TestUtil.printOutFrameAsTable(transformed); - assertEquals(3, tracker.getTransformations().size()); - assertNotSame(scored, transformed); - assertFrameEquals(scored, transformed, 1.6); - checkFrameState(fr); - checkFrameState(transformed); - assertVecEquals(notMult, fr.vec(1), 0); - assertVecEquals(mult, transformed.vec(1), 0); - } - - @Test - public void test_simple_classification_pipeline() { - PipelineParameters pparams = new PipelineParameters(); - FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer(); - pparams.setTransformers( - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - tracker.name("tracker") - ); - DummyModelParameters eparams = new DummyModelParameters(); - eparams._makeModel = true; - pparams._estimatorParams = eparams; - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1)) - .withDataForCol(2, ar("yes", "no", "yes")) - .build()); - - pparams._train = fr._key; - pparams._response_column = "target"; - - Pipeline pipeline = ModelBuilder.make(pparams); - PipelineModel pmodel = Scope.track_generic(pipeline.trainModel().get()); - assertNotNull(pmodel); - PipelineOutput output = pmodel._output; - assertNotNull(output); - assertNotNull(output._estimator); - Model emodel = output._estimator.get(); - assertNotNull(emodel); - assertTrue(emodel instanceof DummyModel); - assertArrayEquals(new String[] {"one", "two", "foo", "bar", "target"}, emodel._output._names); - - assertEquals(1, tracker.getTransformations().size()); - checkFrameState(fr); - - Frame predictions = Scope.track(pmodel.score(fr)); - assertEquals(2, tracker.getTransformations().size()); - assertNotNull(predictions); - TestUtil.printOutFrameAsTable(predictions); - checkFrameState(fr); - checkFrameState(predictions); - - Frame transformed = Scope.track(pmodel.transform(fr)); - assertEquals(3, tracker.getTransformations().size()); - assertNotNull(transformed); - TestUtil.printOutFrameAsTable(transformed); - assertArrayEquals( - Arrays.stream(emodel._output._names).sorted().toArray(), //model 
reorders input columns to obtain this output - Arrays.stream(transformed.names()).sorted().toArray() - ); - checkFrameState(fr); - checkFrameState(transformed); - } - - @Test - public void test_simple_classification_pipeline_with_sensitive_cv() { - int nfolds = 3; - PipelineParameters pparams = new PipelineParameters(); - pparams._nfolds = nfolds; - FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer(); - pparams.setTransformers( - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new AddDummyCVColumnTransformer("cv_fold", Vec.T_CAT).name("add_cv_fold"), - tracker.name("track") - ); - DummyModelParameters eparams = new DummyModelParameters(); - eparams._makeModel = true; - eparams._keep_cross_validation_models = true; - - pparams._estimatorParams = eparams; - final Frame fr = Scope.track(new TestFrameBuilder() - .withName("train") - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1)) - .withDataForCol(1, ard(3, 2, 1, 6, 5, 4, 9, 8, 7, 2, 1, 0, 5, 4, 3, 8, 7, 6, 1)) - .withDataForCol(2, ar("y", "n", "y", "y", "y", "y", "n", "n", "n", "n", "y", "y", "n", "n", "y", "y", "y", "n", "n")) - .build()); - - pparams._train = fr._key; - pparams._response_column = "target"; - - Pipeline pipeline = ModelBuilder.make(pparams); - PipelineModel pmodel = Scope.track_generic(pipeline.trainModel().get()); - assertNotNull(pmodel); - PipelineOutput output = pmodel._output; - assertNotNull(output); - assertNotNull(output._estimator); - Model emodel = output._estimator.get(); - assertNotNull(emodel); - assertTrue(emodel instanceof DummyModel); - assertArrayEquals(new String[] {"one", "two", "foo", "bar", "target"}, emodel._output._names); - - Transformations transformations = tracker.getTransformations(); - System.out.println(transformations); - assertEquals(2*nfolds+1, tracker.getTransformations().size()); // nfolds * 2 [train+valid] + 1 [final model, train only] - assertNotEquals(fr.getKey().toString(), transformations.transformations[0].frameId); // training frame for final model transformed first - assertTrue(transformations.transformations[0].frameId.startsWith(fr.getKey().toString()+"@@Training_trf_by_add_bar")); - assertEquals(DataTransformer.FrameType.Training, transformations.transformations[0].type); - assertFalse(transformations.transformations[0].is_cv); - assertEquals(nfolds*2, Stream.of(transformations.transformations).filter(t -> t.is_cv).count()); - assertEquals(nfolds, Stream.of(transformations.transformations).filter(t -> t.is_cv && t.type == DataTransformer.FrameType.Training).count()); - assertEquals(nfolds, Stream.of(transformations.transformations).filter(t -> t.is_cv && t.type == DataTransformer.FrameType.Validation).count()); - assertEquals(nfolds, emodel._output._cross_validation_models.length); - for (int i=0; i { -// TestUtil.printOutFrameAsTable(fr); - String[][] unencodedDomains = {{"a", "b", "c", "d"}, {"a", "b", "c", "d"}, {"n", "y"}}; - if (ArrayUtils.contains(fr.names(), ModelBuilder.CV_WEIGHTS_COLUMN)) { - assertEquals(ModelBuilder.CV_WEIGHTS_COLUMN, fr.name(fr.names().length-1)); - // weights col is numerical - assertArrayEquals(ArrayUtils.append(unencodedDomains, new String[][] { null }), fr.domains()); - } else { - assertArrayEquals(unencodedDomains, fr.domains()); - } - }); - FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer(); - pparams.setTransformers( - 
checker.name("check_frame_not_encoded"), - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new AddDummyCVColumnTransformer("cv_fold").name("add_cv_fold"), - tracker.name("tracker") - ); - DummyModelParameters eparams = new DummyModelParameters(); - eparams._makeModel = true; - eparams._keep_cross_validation_models = true; - eparams._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.LabelEncoder; - - pparams._estimatorParams = eparams; - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_CAT, Vec.T_CAT, Vec.T_CAT) - .withDataForCol(0, ar("a", "b", "c", "a", "b", "c", "a", "b", "c", "d", "c", "b", "a", "c", "b", "a", "c", "b", "a")) - .withDataForCol(1, ar("c", "b", "a", "c", "b", "a", "c", "b", "a", "c", "b", "d", "b", "a", "c", "b", "a", "c", "b")) - .withDataForCol(2, ar("y", "n", "y", "y", "y", "y", "n", "n", "n", "n", "y", "y", "n", "n", "y", "y", "y", "n", "n")) - .build()); - - pparams._train = fr._key; - pparams._response_column = "target"; - - Pipeline pipeline = ModelBuilder.make(pparams); - PipelineModel pmodel = Scope.track_generic(pipeline.trainModel().get()); - assertNotNull(pmodel); - PipelineOutput output = pmodel._output; - assertNotNull(output); - assertNotNull(output._estimator); - Model emodel = output._estimator.get(); - assertNotNull(emodel); - assertTrue(emodel instanceof DummyModel); - assertArrayEquals(new String[][] {{"a", "b", "c", "d"}, {"a", "b", "c", "d"}, {"n", "y"}}, pmodel._output._domains); - assertArrayEquals(new String[][] {{"a", "b", "c", "d"}, {"a", "b", "c", "d"}, null/*foo*/, null/*bar*/, {"n", "y"}}, emodel._output._origDomains); - assertArrayEquals(new String[][] {null/*encoded one*/, null/*encoded two*/, null/*foo*/, null/*bar*/, {"n", "y"}}, emodel._output._domains); - - assertEquals(nfolds, emodel._output._cross_validation_models.length); - for (Key km : emodel._output._cross_validation_models) { - DummyModel cvModel = km.get(); - assertNotNull(cvModel); - assertArrayEquals(new String[][] {{"a", "b", "c", "d"}, {"a", "b", "c", "d"}, null/*foo*/, null/*bar*/, null/*cv_fold*/, null/*cv_weights*/, {"n", "y"}}, cvModel._output._origDomains); - assertArrayEquals(new String[][] {null/*encoded one*/, null/*encoded two*/, null/*foo*/, null/*bar*/ , null/*cv_fold*/, null/*cv_weights*/, {"n", "y"}}, cvModel._output._domains); - } - checkFrameState(fr); - } -} diff --git a/h2o-core/src/test/java/hex/pipeline/TransformerChainTest.java b/h2o-core/src/test/java/hex/pipeline/TransformerChainTest.java deleted file mode 100644 index daad6764bd65..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/TransformerChainTest.java +++ /dev/null @@ -1,265 +0,0 @@ -package hex.pipeline; - -import hex.pipeline.DataTransformer.FrameType; -import hex.pipeline.DataTransformerTest.AddRandomColumnTransformer; -import hex.pipeline.DataTransformerTest.FailingAssertionTransformer; -import hex.pipeline.DataTransformerTest.FrameTrackerAsTransformer; -import hex.pipeline.PipelineModel.PipelineParameters; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.Scope; -import water.fvec.Frame; -import water.fvec.TestFrameBuilder; -import water.fvec.Vec; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.util.ArrayUtils; -import water.util.Pair; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; 
-import java.util.stream.Collectors; - -import static hex.pipeline.DataTransformerTest.*; -import static org.junit.Assert.*; -import static water.TestUtil.*; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class TransformerChainTest { - - @Test - public void test_transformers_are_applied_in_sequential_order() { - try { - Scope.enter(); - final Frame fr = Scope.track(dummyFrame()); - FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer().name("tracker"); - DataTransformer[] dts = new DataTransformer[] { - tracker, - new FrameCheckerAsTransformer(f -> { - assertFalse(ArrayUtils.contains(f.names(), "foo")); - assertFalse(ArrayUtils.contains(f.names(), "bar")); - }), - new AddRandomColumnTransformer("foo").name("add_foo"), - new FrameCheckerAsTransformer(f -> { - assertTrue(ArrayUtils.contains(f.names(), "foo")); - assertFalse(ArrayUtils.contains(f.names(), "bar")); - }), - tracker, - new AddRandomColumnTransformer("bar").name("add_bar"), - new FrameCheckerAsTransformer(f -> { - assertTrue(ArrayUtils.contains(f.names(), "foo")); - assertTrue(ArrayUtils.contains(f.names(), "bar")); - }), - tracker, - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - tracker, - }; - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - chain.prepare(null); - Frame transformed = chain.transform(fr); - assertEquals("dumdum", transformed.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr.names(), new String[]{"foo", "bar"}), transformed.names()); - assertEquals(4, tracker.getTransformations().size()); - } finally { - Scope.exit(); - } - } - - @Test - - public void test_a_failing_transformer_fails_the_entire_chain() { - try { - Scope.enter(); - final Frame fr = Scope.track(dummyFrame()); - DataTransformer[] dts = new DataTransformer[] { - new AddRandomColumnTransformer("foo").name("add_foo"), - new FailingAssertionTransformer().name("failing_assertion"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - }; - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - AssertionError err = assertThrows(AssertionError.class, () -> chain.transform(fr)); - assertEquals("expected", err.getMessage()); - } finally { - Scope.exit(); - } - } - - - @Test - public void test_transformers_are_applied_in_sequential_order_using_context_tracker() { - try { - Scope.enter(); - final Frame fr = Scope.track(dummyFrame()); - DataTransformer[] dts = new DataTransformer[] { - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - }; - List dtIds = Arrays.stream(dts).map(DataTransformer::name).collect(Collectors.toList()); - final AtomicInteger dtIdx = new AtomicInteger(0); - PipelineContext context = new PipelineContext(new PipelineParameters(), new FrameTracker() { - @Override - public void apply(Frame transformed, Frame original, FrameType type, PipelineContext context, DataTransformer transformer) { - assertTrue(dtIds.contains(transformer.name())); - switch (transformer.name()) { - case "add_foo": - assertEquals(0, dtIdx.getAndIncrement()); - assertEquals(fr, original); - assertArrayEquals(ArrayUtils.append(fr.names(), new String[]{"foo"}), transformed.names()); - break; - case "add_bar": - assertEquals(1, dtIdx.getAndIncrement()); - assertArrayEquals(ArrayUtils.append(fr.names(), new String[]{"foo", "bar"}), transformed.names()); - break; - case 
"rename_dumdum": - assertEquals(2, dtIdx.getAndIncrement()); - assertArrayEquals(ArrayUtils.append(fr.names(), new String[]{"foo", "bar"}), transformed.names()); - assertEquals("dumdum", transformed.getKey().toString()); - break; - } - } - }, null); - - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - Frame transformed = chain.transform(fr); - assertEquals("dumdum", transformed.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr.names(), new String[]{"foo", "bar"}), transformed.names()); - } finally { - Scope.exit(); - } - } - - @Test - public void test_chain_can_be_applied_multiple_times() { - try { - Scope.enter(); - final Frame fr1 = Scope.track(dummyFrame()); - final Frame fr2 = Scope.track(oddFrame()); - DataTransformer[] dts = new DataTransformer[] { - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - }; - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - - Frame tr1 = chain.transform(fr1); - assertEquals("dumdum", tr1.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr1.names(), new String[]{"foo", "bar"}), tr1.names()); - - Frame tr1bis = chain.transform(fr1); - assertNotSame(tr1, tr1bis); - assertEquals(tr1.getKey(), tr1bis.getKey()); - assertArrayEquals(tr1.names(), tr1bis.names()); - assertFrameEquals(tr1, tr1bis, 1e-10); - - Frame tr2 = chain.transform(fr2); - assertEquals("dumdum", tr2.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr2.names(), new String[]{"foo", "bar"}), tr2.names()); - } finally { - Scope.exit(); - } - } - - - @Test - public void test_chain_can_be_finalized_by_completer() { - try { - Scope.enter(); - final Frame fr1 = Scope.track(dummyFrame()); - final Frame fr2 = Scope.track(oddFrame()); - DataTransformer[] dts = new DataTransformer[] { - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - }; - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - TransformerChain.Completer> dim = new TransformerChain.UnaryCompleter>() { - @Override - public Pair apply(Frame frame, PipelineContext context) { - return new Pair<>(frame.numRows(), (long)frame.numCols()); - } - }; - - assertEquals(new Pair<>(3L, 5L), chain.transform(fr1, FrameType.Test, null, dim)); - assertEquals(new Pair<>(3L, 7L), chain.transform(fr2, FrameType.Test, null, dim)); - } finally { - Scope.exit(); - } - - } - - @Test - public void test_multiple_frames_can_be_transformed_at_once_to_feed_a_completer() { - try { - Scope.enter(); - final Frame fr1 = Scope.track(dummyFrame()); - final Frame fr2 = Scope.track(oddFrame()); - DataTransformer[] dts = new DataTransformer[] { - new AddRandomColumnTransformer("foo").name("add_foo"), - new AddRandomColumnTransformer("bar").name("add_bar"), - new RenameFrameTransformer("dumdum").name("rename_dumdum"), - }; - TransformerChain chain = Scope.track_generic(new TransformerChain(dts).init()); - - Frame[] transformed = chain.transform( - new Frame[] {fr1, fr1, fr2}, - new FrameType[]{FrameType.Test, FrameType.Test, FrameType.Test}, - null, - (fs, c) -> fs - ); - Frame tr1 = transformed[0]; - assertEquals("dumdum", tr1.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr1.names(), new String[]{"foo", "bar"}), tr1.names()); - - Frame tr1bis = transformed[1]; - assertNotSame(tr1, tr1bis); 
- assertEquals(tr1.getKey(), tr1bis.getKey()); - assertArrayEquals(tr1.names(), tr1bis.names()); - assertFrameEquals(tr1, tr1bis, 1e-10); - - Frame tr2 = transformed[2]; - assertEquals("dumdum", tr2.getKey().toString()); - assertArrayEquals(ArrayUtils.append(fr2.names(), new String[]{"foo", "bar"}), tr2.names()); - - Set uniqueCols = chain.transform( - new Frame[] {fr1, fr1, fr2}, - new FrameType[]{FrameType.Test, FrameType.Test, FrameType.Test}, - null, - (Frame[] fs, PipelineContext c) -> - Arrays.stream(fs) - .flatMap(f -> Arrays.stream(f.names())) - .collect(Collectors.toSet()) - ); - assertEquals(new HashSet<>(Arrays.asList("one", "two", "three", "five", "seven", "nine", "foo", "bar")), uniqueCols); - } finally { - Scope.exit(); - } - } - - private Frame dummyFrame() { - return new TestFrameBuilder() - .withColNames("one", "two", "three") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM) - .withDataForCol(0, ard(1, 1, 1)) - .withDataForCol(1, ard(2, 2, 2)) - .withDataForCol(2, ard(3, 3, 3)) - .build(); - } - private Frame oddFrame() { - return new TestFrameBuilder() - .withColNames("one", "three", "five", "seven", "nine") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM) - .withDataForCol(0, ard(1, 1, 1)) - .withDataForCol(1, ard(3, 3, 3)) - .withDataForCol(2, ard(5, 5, 5)) - .withDataForCol(3, ard(7, 7, 7)) - .withDataForCol(4, ard(9, 9, 9)) - .build(); - } - -} diff --git a/h2o-core/src/test/java/hex/pipeline/transformers/KFoldColumnGeneratorTest.java b/h2o-core/src/test/java/hex/pipeline/transformers/KFoldColumnGeneratorTest.java deleted file mode 100644 index 8d6b6a6ee431..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/transformers/KFoldColumnGeneratorTest.java +++ /dev/null @@ -1,108 +0,0 @@ -package hex.pipeline.transformers; - -import hex.Model.Parameters.FoldAssignmentScheme; -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import hex.pipeline.PipelineModel.PipelineParameters; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.Scope; -import water.TestUtil; -import water.fvec.Frame; -import water.fvec.TestFrameBuilder; -import water.fvec.Vec; -import water.junit.rules.ScopeTracker; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.util.ArrayUtils; - -import static hex.pipeline.DataTransformer.FrameType.Training; -import static hex.pipeline.DataTransformer.FrameType.Validation; -import static org.junit.Assert.*; -import static water.TestUtil.ar; -import static water.TestUtil.ard; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class KFoldColumnGeneratorTest { - - @Rule - public ScopeTracker scope = new ScopeTracker(); - - @Test - public void test_transformer_modifies_only_training_frames() { - String foldc = "foldc"; - int nfolds = 3; - DataTransformer kfold = new KFoldColumnGenerator(foldc, FoldAssignmentScheme.AUTO, nfolds, 0); - final Frame fr = scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1)) - .withDataForCol(2, ar("y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y")) - .build()); - - Frame transformed = kfold.transform(fr); - assertSame(fr, transformed); - - Frame validTransformed = kfold.transform(fr, Validation, null); - assertSame(fr, 
validTransformed); - - Frame trainTransformed = kfold.transform(fr, Training, null); - TestUtil.printOutFrameAsTable(trainTransformed); - assertArrayEquals(new String[] {"one", "two", "target", foldc}, trainTransformed.names()); - assertEquals(0, (int) trainTransformed.vec(3).min()); - assertEquals(nfolds-1, (int) trainTransformed.vec(3).max()); - } - - - @Test - public void test_transformer_autodetection_in_pipeline_context() { - DataTransformer kfold = Scope.track_generic(new KFoldColumnGenerator()); - final Frame fr = scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1)) - .withDataForCol(2, ar("y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y")) - .build()); - - PipelineParameters params = new PipelineParameters(); - params._nfolds = 4; - params._fold_assignment = FoldAssignmentScheme.Stratified; - params._response_column = "target"; - PipelineContext context = new PipelineContext(params); - kfold.prepare(context); - assertFalse(ArrayUtils.contains(fr.names(), context._params._fold_column)); - Frame trainTransformed = kfold.transform(fr, Training, context); - TestUtil.printOutFrameAsTable(trainTransformed); - assertArrayEquals(new String[] {"one", "two", "target", context._params._fold_column}, trainTransformed.names()); - assertEquals(0, (int) trainTransformed.vec(3).min()); - assertEquals(params._nfolds-1, (int) trainTransformed.vec(3).max()); - //hard to verify that it's properly stratified on a smaller dataset given the algo being used, but trusting the utility method that is used is other places. 
- } - - @Test - public void test_transformer_transforms_context_train_frame_during_prepare() { - DataTransformer kfold = Scope.track_generic(new KFoldColumnGenerator()); - final Frame fr = scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1)) - .withDataForCol(2, ar("y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y")) - .build()); - - PipelineParameters params = new PipelineParameters(); - params._nfolds = 3; - params._response_column = "target"; - PipelineContext context = new PipelineContext(params); - context.setTrain(fr); - kfold.prepare(context); - assertFalse(ArrayUtils.contains(fr.names(), context._params._fold_column)); - assertTrue(ArrayUtils.contains(context.getTrain().names(), context._params._fold_column)); - TestUtil.printOutFrameAsTable(context.getTrain()); - } - -} diff --git a/h2o-core/src/test/java/hex/pipeline/transformers/ModelAsFeatureTransformerTest.java b/h2o-core/src/test/java/hex/pipeline/transformers/ModelAsFeatureTransformerTest.java deleted file mode 100644 index 013a75fe20ee..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/transformers/ModelAsFeatureTransformerTest.java +++ /dev/null @@ -1,157 +0,0 @@ -package hex.pipeline.transformers; - -import hex.Model; -import hex.pipeline.PipelineContext; -import hex.pipeline.PipelineModel; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.Key; -import water.Scope; -import water.TestUtil; -import water.fvec.Frame; -import water.fvec.TestFrameBuilder; -import water.fvec.Vec; -import water.junit.rules.ScopeTracker; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.test.dummy.DummyModel; -import water.test.dummy.DummyModelBuilder; -import water.test.dummy.DummyModelParameters; - -import static org.junit.Assert.*; -import static water.TestUtil.*; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class ModelAsFeatureTransformerTest { - - private static class DummyModelAsFeatureTransformer extends ModelAsFeatureTransformer { - - boolean cvSensitive = false; - - public DummyModelAsFeatureTransformer(DummyModelParameters params) { - super(params); - } - - public DummyModelAsFeatureTransformer(DummyModelParameters params, Key modelKey) { - super(params, modelKey); - } - - @Override - public boolean isCVSensitive() { - return cvSensitive; - } - } - - - @Rule - public ScopeTracker scope = new ScopeTracker(); - - @Test - public void test_delegate_model_trained_and_cached_in_prepare() { - DummyModelParameters params = makeModelParams(null); - DummyModelAsFeatureTransformer transformer = Scope.track_generic(new DummyModelAsFeatureTransformer(params)); - assertNull(transformer.getModel()); - - Frame fr = makeTrain(); - PipelineModel.PipelineParameters pParams = new PipelineModel.PipelineParameters(); - assignTrainingParams(pParams, fr); - PipelineContext context = new PipelineContext(pParams); - transformer.prepare(context); - DummyModel m1 = Scope.track_generic(transformer.getModel()); - assertNotNull(m1); - long m1Checksum = m1.checksum(); - - transformer.prepare(context); - DummyModel m2 = transformer.getModel(); - assertSame(m1, m2); - assertEquals("model shouldn't be modified during second prepare", m1Checksum, m2.checksum()); - } - - @Test - public void 
test_transform_delegates_to_provided_pretrained_model() { - Frame fr = makeTrain(); - DummyModelParameters params = makeModelParams(fr); - DummyModel model = Scope.track_generic(new DummyModelBuilder(params).trainModel().get()); - long oriChecksum = model.checksum(); - DummyModelAsFeatureTransformer transformer = Scope.track_generic(new DummyModelAsFeatureTransformer(params, model._key)); - assertNotNull(transformer.getModel()); - - PipelineContext context = new PipelineContext(new PipelineModel.PipelineParameters()); - transformer.prepare(context); - DummyModel m = transformer.getModel(); - assertSame(model, m); - assertEquals("model shouldn't be modified during prepare", oriChecksum, m.checksum()); - } - - - @Test - public void test_transform_delegates_to_internally_trained_model() { - Frame fr = makeTrain(); - DummyModelParameters params = makeModelParams(fr); - DummyModel model = Scope.track_generic(new DummyModelBuilder(params).trainModel().get()); - DummyModelAsFeatureTransformer transformer = Scope.track_generic(new DummyModelAsFeatureTransformer(params, model._key)); - assertNotNull(transformer.getModel()); - - PipelineContext context = new PipelineContext(new PipelineModel.PipelineParameters()); - transformer.prepare(context); - DummyModel m = transformer.getModel(); - assertSame(model, m); - assertEquals("model shouldn't be modified during prepare", model.checksum(), m.checksum()); - - Frame trans = transformer.transform(fr); - assertEquals(fr.getKey()+"_stats", trans.getKey().toString()); - } - - - @Test - public void test_data_input_model_params_are_detected_from_context() { - DummyModelParameters params = makeModelParams(null); - DummyModelAsFeatureTransformer transformer = Scope.track_generic(new DummyModelAsFeatureTransformer(params)); - assertNull(transformer.getModel()); - - PipelineModel.PipelineParameters pParams = new PipelineModel.PipelineParameters(); - pParams._response_column = "target"; - PipelineContext context = new PipelineContext(pParams); - Frame train = makeTrain(); - context.setTrain(train); - transformer.prepare(context); - DummyModel m = Scope.track_generic(transformer.getModel()); - assertNotNull(m); - assertEquals("target", m._parms._response_column); - assertEquals(train._key, m._parms._train); - - Frame trans = transformer.transform(train); - TestUtil.printOutFrameAsTable(trans); - assertEquals(train.getKey()+"_stats", trans.getKey().toString()); - } - - @Test - public void test_grid_like_scenario() { - - } - - private DummyModelParameters makeModelParams(Frame train) { - DummyModelParameters params = new DummyModelParameters(); - params._makeModel = true; - if (train != null) assignTrainingParams(params, train); - return params; - } - - private void assignTrainingParams(Model.Parameters params, Frame train) { - params._train = train.getKey(); - params._response_column = train.name(train._names.length - 1); - } - - private Frame makeTrain() { - return scope.track(new TestFrameBuilder() - .withColNames("one", "two", "target") - .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) - .withDataForCol(0, ard(1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3)) - .withDataForCol(1, ard(3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1)) - .withDataForCol(2, ar("y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y", "n", "y")) - .build()); - } - -} diff --git a/h2o-core/src/test/java/hex/pipeline/transformers/UnionTransformerTest.java 
b/h2o-core/src/test/java/hex/pipeline/transformers/UnionTransformerTest.java deleted file mode 100644 index aadf4489b048..000000000000 --- a/h2o-core/src/test/java/hex/pipeline/transformers/UnionTransformerTest.java +++ /dev/null @@ -1,118 +0,0 @@ -package hex.pipeline.transformers; - -import hex.pipeline.DataTransformer; -import hex.pipeline.PipelineContext; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import water.MRTask; -import water.Scope; -import water.fvec.*; -import water.junit.rules.ScopeTracker; -import water.runner.CloudSize; -import water.runner.H2ORunner; -import water.util.StringUtils; - -import static water.TestUtil.ar; -import static water.TestUtil.assertFrameEquals; - -@RunWith(H2ORunner.class) -@CloudSize(1) -public class UnionTransformerTest { - - @Rule - public ScopeTracker scope = new ScopeTracker(); - - @Test - public void test_union_transform_append() { - DataTransformer plus = new StrConcatColumnsTransformer("+"); - DataTransformer minus = new StrConcatColumnsTransformer("-"); - - DataTransformer union = new UnionTransformer(UnionTransformer.UnionStrategy.append, plus, minus); - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two") - .withVecTypes(Vec.T_STR, Vec.T_STR) - .withDataForCol(0, ar("un", "eins", "jeden")) - .withDataForCol(1, ar("deux", "zwei", "dva")) - .build()); - - final Frame expected = Scope.track(new TestFrameBuilder() - .withColNames("one", "two", "one+two", "one-two") - .withVecTypes(Vec.T_STR, Vec.T_STR, Vec.T_STR, Vec.T_STR) - .withDataForCol(0, ar("un", "eins", "jeden")) - .withDataForCol(1, ar("deux", "zwei", "dva")) - .withDataForCol(2, ar("un+deux", "eins+zwei", "jeden+dva")) - .withDataForCol(3, ar("un-deux", "eins-zwei", "jeden-dva")) - .build()); - - Frame transformed = Scope.track(union.transform(fr)); - assertFrameEquals(expected, transformed, 0); - } - - @Test - public void test_union_transform_replace() { - DataTransformer plus = new StrConcatColumnsTransformer("+"); - DataTransformer minus = new StrConcatColumnsTransformer("-"); - - DataTransformer union = new UnionTransformer(UnionTransformer.UnionStrategy.replace, plus, minus); - final Frame fr = Scope.track(new TestFrameBuilder() - .withColNames("one", "two") - .withVecTypes(Vec.T_STR, Vec.T_STR) - .withDataForCol(0, ar("un", "eins", "jeden")) - .withDataForCol(1, ar("deux", "zwei", "dva")) - .build()); - - final Frame expected = Scope.track(new TestFrameBuilder() - .withColNames("one+two", "one-two") - .withVecTypes(Vec.T_STR, Vec.T_STR) - .withDataForCol(0, ar("un+deux", "eins+zwei", "jeden+dva")) - .withDataForCol(1, ar("un-deux", "eins-zwei", "jeden-dva")) - .build()); - - Frame transformed = Scope.track(union.transform(fr)); - assertFrameEquals(expected, transformed, 0); - } - - public static class StrConcatColumnsTransformer extends DataTransformer { - - private String separator; - - public StrConcatColumnsTransformer(String separator) { - this.separator = separator; - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - return new VecsConcatenizer(separator).concat(fr); - } - - public static class VecsConcatenizer extends MRTask { - private final String _sep; - - public VecsConcatenizer(String separator) { - _sep = separator; - } - - @Override - public void map(Chunk[] cs, NewChunk nc) { - for (int row = 0; row < cs[0]._len; row++) { - StringBuilder tmpStr = new StringBuilder(); - for (int col = 0; col < cs.length; col++) { - Chunk chk = cs[col]; - if 
(chk.isNA(row)) continue; - String s = chk.stringAt(row); - tmpStr.append(s); - if (col+1 < cs.length) tmpStr.append(_sep); - } - nc.addStr(tmpStr.toString()); - } - } - - public Frame concat(Frame fr) { - String name = StringUtils.join(_sep, fr.names()); - return doAll(Vec.T_STR, fr).outputFrame(new String[]{name}, null); - } - } - } - -} diff --git a/h2o-core/src/test/java/water/AutoBufferTest.java b/h2o-core/src/test/java/water/AutoBufferTest.java index 9ab0aa09fa29..d7cc31fe431e 100644 --- a/h2o-core/src/test/java/water/AutoBufferTest.java +++ b/h2o-core/src/test/java/water/AutoBufferTest.java @@ -8,7 +8,6 @@ import java.io.Serializable; import java.util.Arrays; -import java.util.stream.Collectors; import static org.junit.Assert.*; diff --git a/h2o-core/src/test/java/water/util/IcedHashMapTest.java b/h2o-core/src/test/java/water/util/IcedHashMapTest.java index 6ad35f1b1884..bf9d0392ef6e 100644 --- a/h2o-core/src/test/java/water/util/IcedHashMapTest.java +++ b/h2o-core/src/test/java/water/util/IcedHashMapTest.java @@ -188,14 +188,6 @@ private void testWriteJSON(Map map, Type type) { testWriteJSON(map); } - @Test public void testLongFreezable() { - final Map map = Collections.unmodifiableMap(new HashMap() {{ - put(1L, Key.make("one")); - put(2L, Key.make("two")); - put(3L, Key.make("three")); - }}); - testWriteRead(map); - } @Test public void testStringFreezable() { final Map map = Collections.unmodifiableMap(new HashMap() {{ put("one", Key.make("one")); diff --git a/h2o-core/src/test/java/water/util/Sort.java b/h2o-core/src/test/java/water/util/Sort.java index cded20c82a6f..9ea2aa9f0c61 100644 --- a/h2o-core/src/test/java/water/util/Sort.java +++ b/h2o-core/src/test/java/water/util/Sort.java @@ -5,7 +5,6 @@ import hex.ModelMetrics; import org.junit.Ignore; -import water.Futures; import water.rapids.Merge; @Ignore @@ -47,8 +46,8 @@ protected int nModelsInParallel(int folds) { } @Override - protected ModelMetrics.MetricBuilder makeCVMetricBuilder(ModelBuilder cvModelBuilder, Futures fs) { - return null; + protected boolean makeCVMetrics(ModelBuilder cvModelBuilder) { + return false; } @Override diff --git a/h2o-core/testMultiNode.sh b/h2o-core/testMultiNode.sh index 4eb9d661fa70..d04e0982c5e0 100755 --- a/h2o-core/testMultiNode.sh +++ b/h2o-core/testMultiNode.sh @@ -99,8 +99,6 @@ JUNIT_RUNNER="water.junit.H2OTestRunner" # '/usr/bin/sort' needed to avoid windows native sort when run in cygwin (cd src/test/java; /usr/bin/find . -name '*.java' | cut -c3- | sed 's/.....$//' | sed -e 's/\//./g') | grep -v $JUNIT_TESTS_SLOW | grep -v $JUNIT_TESTS_BOOT | /usr/bin/sort > $OUTDIR/all_tests.txt -cat $OUTDIR/all_tests.txt | egrep "water\." > $OUTDIR/all_sorted_tests.txt -cat $OUTDIR/all_tests.txt | egrep -v "water\." 
>> $OUTDIR/all_sorted_tests.txt set -f # no globbing if [ foo"$DOONLY" = foo ]; then @@ -112,7 +110,7 @@ fi # Output the comma-separated list of ignored/dooonly tests # Ignored tests trump do-only tests -cat $OUTDIR/all_sorted_tests.txt | egrep -v "$IGNORE" > $OUTDIR/tests.not_ignored.txt +cat $OUTDIR/all_tests.txt | egrep -v "$IGNORE" > $OUTDIR/tests.not_ignored.txt cat $OUTDIR/tests.not_ignored.txt | egrep "$DOONLY" > $OUTDIR/tests.txt set +f diff --git a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderModel.java b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderModel.java index 7ce23ad65d12..6307d32693bd 100644 --- a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderModel.java +++ b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderModel.java @@ -448,12 +448,10 @@ private double defaultNoiseLevel(Frame fr, int targetIndex) { /** * Ideally there should be no need to deep copy columns that are not listed as input in _input_to_output_columns. - * However if we keep the original columns in the output, then they are deleted in the model integration: {@link hex.ModelBuilder#track}. - * On the other side, if copied as a "ShallowVec" (extending WrappedVec) to prevent deletion of data in trackFrame, - * then we expose WrappedVec to the client in all non-integration use cases, which is strongly discouraged. - * Catch-22 situation, so keeping the deepCopy for now. - * NOTE! New tracking keys logic should keep vecs from original training frame protected in any case (see {@link Scope#protect}), - * which should allow us to get rid of this deep copy, and always replace old vec by a new one when transforming it. + * However if we keep the original columns in the output, then they are deleted in the model integration: {@link hex.ModelBuilder#trackEncoded}. + * On the other side, if copied as a "ShallowVec" (extending WrappedVec) to prevent deletion of data in trackEncoded, + * then we expose WrappedVec to the client in all non-integration use cases, which is strongly discouraged. + * Catch-22 situation, so keeping the deepCopy for now: it occurs only for predictions, where the data are usually smaller.
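+ * A rough sketch of the two options, for illustration only ({@code ShallowVec} is hypothetical and does not exist in the codebase): + * current approach: {@code output.add(name, vec.makeCopy()); // deep copy: safe for clients, but duplicates the data} + * wrapping approach: {@code output.add(name, new ShallowVec(vec)); // cheap, but hands a WrappedVec to non-integration callers}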
 * @param fr * @return the working frame used to make predictions */ @@ -628,7 +626,7 @@ public Frame doApply(Frame fr, String columnToEncode, Frame encodings, String en Frame workingFrame = fr; int teColumnIdx = fr.find(columnToEncode); int foldColIdx; - if (_outOfFold == NO_FOLD) { + if (_outOfFold == NO_FOLD) { foldColIdx = fr.find(_foldColumn); } else { workingFrame = new Frame(fr); @@ -650,7 +648,7 @@ public Frame doApply(Frame fr, String columnToEncode, Frame encodings, String en maxFoldValue ); Scope.track(joinedFrame); - if (_outOfFold != NO_FOLD) { + if (_outOfFold != NO_FOLD) { joinedFrame.remove(foldColIdx); } diff --git a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderPreprocessor.java b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderPreprocessor.java new file mode 100644 index 000000000000..b40e925a8743 --- /dev/null +++ b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/TargetEncoderPreprocessor.java @@ -0,0 +1,61 @@ +package ai.h2o.targetencoding; + +import hex.Model; +import hex.ModelPreprocessor; +import water.DKV; +import water.Futures; +import water.Key; +import water.fvec.Frame; + +import java.util.Objects; + +import static ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy.*; + +public class TargetEncoderPreprocessor extends ModelPreprocessor { + + private TargetEncoderModel _targetEncoder; + + public TargetEncoderPreprocessor(TargetEncoderModel targetEncoder) { + super(Key.make(Objects.toString(targetEncoder._key)+"_preprocessor")); + this._targetEncoder = targetEncoder; + DKV.put(this); + } + + @Override + public Frame processTrain(Frame fr, Model.Parameters params) { + if (useFoldTransform(params)) { + return _targetEncoder.transformTraining(fr, params._cv_fold); // out-of-fold encoding for this CV fold's training frame + } else { + return _targetEncoder.transformTraining(fr); + } + } + + @Override + public Frame processValid(Frame fr, Model.Parameters params) { + if (useFoldTransform(params)) { + return _targetEncoder.transformTraining(fr); // the CV holdout still comes from the training data, so use training-style encodings + } else { + return _targetEncoder.transform(fr); + } + } + + @Override + public Frame processScoring(Frame fr, Model model) { + return _targetEncoder.transform(fr); // plain transform: no leakage handling needed at scoring time + } + + @Override + public Model asModel() { + return _targetEncoder; + } + + @Override + protected Futures remove_impl(Futures fs, boolean cascade) { + if (cascade && _targetEncoder != null) _targetEncoder.remove(); + return super.remove_impl(fs, cascade); + } + + private boolean useFoldTransform(Model.Parameters params) { + return params._is_cv_model && _targetEncoder._parms._data_leakage_handling == KFold; + } +} diff --git a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/FeatureInteractionTransformer.java b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/FeatureInteractionTransformer.java deleted file mode 100644 index 42d68cb475b6..000000000000 --- a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/FeatureInteractionTransformer.java +++ /dev/null @@ -1,43 +0,0 @@ -package ai.h2o.targetencoding.pipeline.transformers; - -import ai.h2o.targetencoding.interaction.InteractionSupport; -import hex.pipeline.DataTransformer; -import hex.pipeline.transformers.FeatureTransformer; -import hex.pipeline.PipelineContext; -import water.fvec.Frame; - -public class FeatureInteractionTransformer extends FeatureTransformer { - - private String[] _columns; - private String _interaction_column; - - private String[] 
_interaction_domain; - - protected FeatureInteractionTransformer() {} - - public FeatureInteractionTransformer(String[] columns) { - this(columns, null); - } - - public FeatureInteractionTransformer(String[] columns, String interactionColumn) { - _columns = columns; - _interaction_column = interactionColumn; - } - - @Override - protected void doPrepare(PipelineContext context) { - assert context != null; - assert context._params != null; - Frame train = new Frame(context.getTrain()); - // FIXME: InteractionSupport should be improved to not systematically modify frames in-place - int interactionCol = InteractionSupport.addFeatureInteraction(train, _columns); - _interaction_domain = train.vec(interactionCol).domain(); - train.remove(interactionCol); - } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - InteractionSupport.addFeatureInteraction(fr, _columns, _interaction_domain); //FIXME: same as above. Also should be able to specify the interaction column name. - return fr; - } -} diff --git a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java deleted file mode 100644 index b5b1a0680a86..000000000000 --- a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java +++ /dev/null @@ -1,64 +0,0 @@ -package ai.h2o.targetencoding.pipeline.transformers; - -import ai.h2o.targetencoding.TargetEncoderModel; -import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; -import hex.Model; -import hex.pipeline.transformers.ModelAsFeatureTransformer; -import hex.pipeline.PipelineContext; -import water.Key; -import water.fvec.Frame; - -import static ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy.KFold; - -public class TargetEncoderFeatureTransformer extends ModelAsFeatureTransformer { - - public TargetEncoderFeatureTransformer(TargetEncoderParameters params) { - super(params); - } - - public TargetEncoderFeatureTransformer(TargetEncoderParameters params, Key modelKey) { - super(params, modelKey); - } - - @Override - public boolean isCVSensitive() { - return _params._data_leakage_handling == KFold; - } - - @Override - protected void prepareModelParams(PipelineContext context) { - super.prepareModelParams(context); - // TODO: future improvement: move some of the decision logic in `ai.h2o.automl.preprocessing.TargetEncoding` to this class - // especially the logic related with the dynamic column selection based on cardinality. - // By parametrizing this here, it allows us to consider parameters like `_columnCardinalityThreshold` as hyper-parameters in pipeline grids. 
- } - - @Override - protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { - assert type != null; - assert context != null || type == FrameType.Test; - validateTransform(); - switch (type) { - case Training: - if (useFoldTransform(context._params)) { - return getModel().transformTraining(fr, context._params._cv_fold); - } else { - return getModel().transformTraining(fr); - } - case Validation: - if (useFoldTransform(context._params)) { - return getModel().transformTraining(fr); - } else { - return getModel().transform(fr); - } - case Test: - default: - return getModel().transform(fr); - } - } - - private boolean useFoldTransform(Model.Parameters params) { - return isCVSensitive() && params._cv_fold >= 0; - } - -} diff --git a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformerTest.java b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderPreprocessorTest.java similarity index 66% rename from h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformerTest.java rename to h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderPreprocessorTest.java index 696ee3cdaf35..1981d2450e80 100644 --- a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformerTest.java +++ b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderPreprocessorTest.java @@ -1,22 +1,22 @@ -package ai.h2o.targetencoding.pipeline.transformers; +package ai.h2o.targetencoding; import ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy; import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; import hex.Model; import hex.Model.Parameters.CategoricalEncodingScheme; -import hex.ModelBuilder; import hex.genmodel.MojoModel; +import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel; import hex.genmodel.easy.EasyPredictModelWrapper; import hex.genmodel.easy.RowData; import hex.genmodel.easy.prediction.BinomialModelPrediction; -import hex.pipeline.Pipeline; -import hex.pipeline.PipelineModel; +import hex.tree.gbm.GBM; import hex.tree.gbm.GBMModel; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; +import water.Key; import water.Scope; import water.fvec.Frame; import water.fvec.TestFrameBuilder; @@ -24,20 +24,26 @@ import water.runner.CloudSize; import water.runner.H2ORunner; import water.util.ArrayUtils; +import water.util.RandomUtils; import java.io.File; import java.io.FileOutputStream; +import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Random; +import java.util.stream.IntStream; import java.util.stream.Stream; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; import static water.TestUtil.*; @RunWith(H2ORunner.class) @CloudSize(1) -public class TargetEncoderFeatureTransformerTest { +public class TargetEncoderPreprocessorTest { private static String TO_ENCODE = "categorical"; private static String ENCODED = "categorical_te"; @@ -56,24 +62,21 @@ public void test_model_building_with_CV_and_TE_KFold_strategy() { Frame train = makeTrainFrame(true); Frame valid = makeValidFrame(); - TargetEncoderParameters teParams = makeTEParams(train, DataLeakageHandlingStrategy.KFold, false, false); + 
TargetEncoderModel teModel = trainTE(train, DataLeakageHandlingStrategy.KFold, false, false); + Scope.track_generic(teModel); + TargetEncoderPreprocessor tePreproc = new TargetEncoderPreprocessor(teModel); + Scope.track_generic(tePreproc); - PipelineModel pModel = buildPipeline(train, null, teParams, CategoricalEncodingScheme.AUTO); - Scope.track_generic(pModel); + Model model = buildModel(train, null, tePreproc, CategoricalEncodingScheme.AUTO); + Scope.track_generic(model); int expectedCVModels = 3; //3 folds -> 3 cv models - assertEquals(expectedCVModels, pModel._output._cross_validation_models.length); - //assertions on the pipeline model (build on a clean frame) - assertFalse(ArrayUtils.contains(pModel._output._names, ENCODED)); - assertTrue(ArrayUtils.contains(pModel._output._names, TO_ENCODE)); - assertTrue(ArrayUtils.contains(pModel._output._names, NOT_ENCODED)); - // assertions on the estimator model (build on an encoded frame) - Model eModel = pModel._output.getEstimatorModel(); - assertTrue(ArrayUtils.contains(eModel._output._names, ENCODED)); - assertFalse(ArrayUtils.contains(eModel._output._names, TO_ENCODE)); - assertTrue(ArrayUtils.contains(eModel._output._names, NOT_ENCODED)); - - Frame preds = pModel.score(valid); + assertEquals(expectedCVModels, model._output._cross_validation_models.length); + assertTrue(ArrayUtils.contains(model._output._names, ENCODED)); + assertFalse(ArrayUtils.contains(model._output._names, TO_ENCODE)); + assertTrue(ArrayUtils.contains(model._output._names, NOT_ENCODED)); + + Frame preds = model.score(valid); Scope.track(preds); } finally { Scope.exit(); @@ -88,22 +91,19 @@ public void test_model_building_without_CV_and_with_TE_None_strategy() { Frame train = makeTrainFrame(false); Frame valid = makeValidFrame(); - TargetEncoderParameters teParams = makeTEParams(train, DataLeakageHandlingStrategy.None, false, false); + TargetEncoderModel teModel = trainTE(train, DataLeakageHandlingStrategy.None, false, false); + Scope.track_generic(teModel); + TargetEncoderPreprocessor tePreproc = new TargetEncoderPreprocessor(teModel); + Scope.track_generic(tePreproc); - PipelineModel pModel = buildPipeline(train, valid, teParams, CategoricalEncodingScheme.AUTO); - Scope.track_generic(pModel); + Model model = buildModel(train, valid, tePreproc, CategoricalEncodingScheme.AUTO); + Scope.track_generic(model); - //assertions on the pipeline model (build on a clean frame) - assertFalse(ArrayUtils.contains(pModel._output._names, ENCODED)); - assertTrue(ArrayUtils.contains(pModel._output._names, TO_ENCODE)); - assertTrue(ArrayUtils.contains(pModel._output._names, NOT_ENCODED)); - // assertions on the estimator model (build on an encoded frame) - Model eModel = pModel._output.getEstimatorModel(); - assertTrue(ArrayUtils.contains(eModel._output._names, ENCODED)); - assertFalse(ArrayUtils.contains(eModel._output._names, TO_ENCODE)); - assertTrue(ArrayUtils.contains(eModel._output._names, NOT_ENCODED)); + assertTrue(ArrayUtils.contains(model._output._names, ENCODED)); + assertFalse(ArrayUtils.contains(model._output._names, TO_ENCODE)); + assertTrue(ArrayUtils.contains(model._output._names, NOT_ENCODED)); - Frame preds = pModel.score(valid); + Frame preds = model.score(valid); Scope.track(preds); } finally { Scope.exit(); @@ -150,9 +150,9 @@ private Frame makeValidFrame() { return row; } - private TargetEncoderParameters makeTEParams(Frame train, DataLeakageHandlingStrategy strategy, boolean encodeAll, boolean keepOriginalCategoricalPredictors) { + private TargetEncoderModel 
trainTE(Frame train, DataLeakageHandlingStrategy strategy, boolean encodeAll, boolean keepOriginalCategoricalPredictors) { TargetEncoderParameters params = new TargetEncoderParameters(); - params._keep_original_categorical_columns = keepOriginalCategoricalPredictors; + params._keep_original_categorical_columns = keepOriginalCategoricalPredictors; params._train = train._key; params._response_column = TARGET; params._fold_column = ArrayUtils.contains(train.names(), FOLDC) ? FOLDC : null; @@ -160,32 +160,31 @@ private TargetEncoderParameters makeTEParams(Frame train, DataLeakageHandlingStr params._data_leakage_handling = strategy; params._noise = 0; params._seed = 42; - return params; + + TargetEncoder te = new TargetEncoder(params); + return te.trainModel().get(); } - private PipelineModel buildPipeline(Frame train, Frame valid, TargetEncoderParameters teParams, CategoricalEncodingScheme categoricalEncoding) { - GBMModel.GBMParameters eparams = new GBMModel.GBMParameters(); - eparams._min_rows = 1; - eparams._max_depth = 1; - eparams._categorical_encoding = categoricalEncoding; + private Model buildModel(Frame train, Frame valid, TargetEncoderPreprocessor preprocessor, CategoricalEncodingScheme categoricalEncoding) { + GBMModel.GBMParameters params = new GBMModel.GBMParameters(); + params._seed = 987; + params._train = train._key; + params._valid = valid == null ? null : valid._key; + params._response_column = TARGET; + params._preprocessors = preprocessor == null ? null : new Key[] {preprocessor._key}; + params._min_rows = 1; + params._max_depth = 1; + params._categorical_encoding = categoricalEncoding; if (ArrayUtils.contains(train.names(), FOLDC)) { - eparams._keep_cross_validation_models = true; - eparams._keep_cross_validation_predictions = true; + params._fold_column = FOLDC; + params._keep_cross_validation_models = true; + params._keep_cross_validation_predictions = true; } - PipelineModel.PipelineParameters pparams = new PipelineModel.PipelineParameters(); - TargetEncoderFeatureTransformer teTrans = new TargetEncoderFeatureTransformer(teParams).init(); - pparams.setTransformers(teTrans); - pparams._estimatorParams = eparams; - pparams._seed = 987; - pparams._train = train._key; - pparams._valid = valid == null ? null : valid._key; - pparams._response_column = TARGET; - pparams._fold_column = ArrayUtils.contains(train.names(), FOLDC) ?
FOLDC : null; - Pipeline pipeline = ModelBuilder.make(pparams); - PipelineModel pmodel = Scope.track_generic(pipeline.trainModel().get()); - return pmodel; + GBM gbm = new GBM(params); + GBMModel model = gbm.trainModel().get(); + return model; } @@ -222,7 +221,7 @@ public void test_pubdev_7775() throws Exception { Scope.enter(); Frame train = makeTrainFrame(true); //without the fold column, the test pass: reordering issue - Model model = buildPipeline(train, null, null, CategoricalEncodingScheme.OneHotExplicit); + Model model = buildModel(train, null, null, CategoricalEncodingScheme.OneHotExplicit); Scope.track_generic(model); File mojoFile = folder.newFile(model._key+".zip"); diff --git a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderRGSTest.java b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderRGSTest.java index 8c90ac5f16be..601a833f965d 100644 --- a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderRGSTest.java +++ b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderRGSTest.java @@ -16,7 +16,7 @@ import water.Key; import water.Scope; import water.TestUtil; -import water.api.GridSearchHandler.SchemaModelParametersBuilderFactory; +import water.api.GridSearchHandler.DefaultModelParametersBuilderFactory; import water.fvec.Frame; import java.util.Arrays; @@ -49,7 +49,7 @@ public void getTargetEncodingMapByTrainingTEBuilder() { TargetEncoderParameters parameters = new TargetEncoderParameters(); - SchemaModelParametersBuilderFactory modelParametersBuilderFactory = new SchemaModelParametersBuilderFactory<>(); + DefaultModelParametersBuilderFactory modelParametersBuilderFactory = new DefaultModelParametersBuilderFactory<>(); RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new RandomDiscreteValueSearchCriteria(); RandomDiscreteValueWalker walker = new RandomDiscreteValueWalker<>(parameters, hpGrid, modelParametersBuilderFactory, hyperSpaceSearchCriteria); @@ -105,8 +105,8 @@ public void regularGSOverTEParameters_parallel() { parameters._response_column = responseColumn; parameters._ignored_columns = ignoredColumns(trainingFrame, "home.dest", "embarked", parameters._response_column); - SchemaModelParametersBuilderFactory modelParametersBuilderFactory = - new SchemaModelParametersBuilderFactory<>(); + DefaultModelParametersBuilderFactory modelParametersBuilderFactory = + new DefaultModelParametersBuilderFactory<>(); RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new RandomDiscreteValueSearchCriteria(); RandomDiscreteValueWalker walker = new RandomDiscreteValueWalker<>( diff --git a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderTestSuite.java b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderTestSuite.java index ab704e488517..1a2bc20a090c 100644 --- a/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderTestSuite.java +++ b/h2o-extensions/target-encoder/src/test/java/ai/h2o/targetencoding/TargetEncoderTestSuite.java @@ -1,7 +1,6 @@ package ai.h2o.targetencoding; -import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformerTest; import org.junit.Ignore; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -22,7 +21,7 @@ TargetEncoderBroadcastJoinTest.class, TargetEncodingImmutabilityTest.class, TargetEncoderMojoIntegrationTest.class, - TargetEncoderFeatureTransformerTest.class, + TargetEncoderPreprocessorTest.class, 
TargetEncoderTest.class, TargetEncoderMojoWriterTest.class, TargetEncoderRGSTest.class diff --git a/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoost.java b/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoost.java index 35fe91003cfc..9a7c32891c5d 100755 --- a/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoost.java +++ b/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoost.java @@ -809,7 +809,7 @@ protected CVModelBuilder makeCVModelBuilder(ModelBuilder[] modelBuilder } } - @Override protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { + @Override public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) { if( _parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0) return; // No exciting changes to stopping conditions // Extract stopping conditions from each CV model, and compute the best stopping answer _parms._stopping_rounds = 0; diff --git a/h2o-py/docs/modeling.rst b/h2o-py/docs/modeling.rst index 722241b9185e..9ceecb83b361 100644 --- a/h2o-py/docs/modeling.rst +++ b/h2o-py/docs/modeling.rst @@ -144,6 +144,12 @@ Unsupervised :show-inheritance: :members: +:mod:`H2OGenericEstimator` +-------------------------- +.. autoclass:: h2o.estimators.generic.H2OGenericEstimator + :show-inheritance: + :members: + :mod:`H2OGeneralizedLowRankEstimator` ------------------------------------- .. autoclass:: h2o.estimators.glrm.H2OGeneralizedLowRankEstimator @@ -190,12 +196,6 @@ Miscellaneous :show-inheritance: :members: -:mod:`H2OGenericEstimator` --------------------------- -.. autoclass:: h2o.estimators.generic.H2OGenericEstimator - :show-inheritance: - :members: - :mod:`H2OSingularValueDecompositionEstimator` --------------------------------------------- .. autoclass:: h2o.estimators.svd.H2OSingularValueDecompositionEstimator diff --git a/h2o-py/h2o/__init__.py b/h2o-py/h2o/__init__.py index c9e8937ac5d1..70ffd42e142d 100644 --- a/h2o-py/h2o/__init__.py +++ b/h2o-py/h2o/__init__.py @@ -96,12 +96,10 @@ def _read_txt_from_whl(name, fallback): def _init_(): from .display import ReplHook, in_py_repl from .backend.connection import register_session_hook - from .schemas import register_schemas if in_py_repl(): replhook = ReplHook() register_session_hook('open', replhook.__enter__) register_session_hook('close', replhook.__exit__) - register_schemas() _init_() diff --git a/h2o-py/h2o/backend/connection.py b/h2o-py/h2o/backend/connection.py index 7167cce9a630..bc99c25f5b93 100644 --- a/h2o-py/h2o/backend/connection.py +++ b/h2o-py/h2o/backend/connection.py @@ -28,10 +28,12 @@ import requests from requests.auth import AuthBase -from h2o.backend import H2OLocalServer +from h2o.backend import H2OCluster, H2OLocalServer from h2o.display import print2 from h2o.exceptions import H2OConnectionError, H2OServerError, H2OResponseError, H2OValueError -from h2o.schemas import H2OErrorV3, define_classes_from_schema, get_schema_handler +from h2o.model.metrics import make_metrics +from h2o.schemas import H2OMetadataV3, H2OErrorV3, H2OModelBuilderErrorV3, define_classes_from_schema +from h2o.two_dim_table import H2OTwoDimTable from h2o.utils.metaclass import CallableString, backwards_compatibility, h2o_meta from h2o.utils.shared_utils import stringify_list, stringify_dict, as_resource from h2o.utils.typechecks import (assert_is_type, assert_matches, assert_satisfies, is_type, numeric) @@ -675,7 +677,7 @@ def _test_connection(self, max_retries=5, messages=None): if self._local_server and not self._local_server.is_running(): 
raise H2OServerError("Local server was unable to start") try: - define_classes_from_schema(self) + define_classes_from_schema(_classes_defined_from_schema_, self) cld = self.request("GET /3/Cloud") if self.name and cld.cloud_name != self.name: @@ -887,6 +889,7 @@ def __exit__(self, *args): class H2OResponse(dict): + """Temporary...""" def __new__(cls, keyvals): # This method is called by the simplejson.json(object_pairs_hook=) @@ -900,10 +903,15 @@ def __new__(cls, keyvals): if k == "__schema" and is_type(v, str): schema = v break - if schema is not None: - handler = get_schema_handler(schema) - if handler is not None: - return handler(keyvals) + if schema == "MetadataV3": return H2OMetadataV3.make(keyvals) + if schema == "CloudV3": return H2OCluster.make(keyvals) + if schema == "H2OErrorV3": return H2OErrorV3.make(keyvals) + if schema == "H2OModelBuilderErrorV3": return H2OModelBuilderErrorV3.make(keyvals) + if schema == "TwoDimTableV3": return H2OTwoDimTable.make(keyvals) + if schema and schema.startswith("ModelMetrics"): + metrics = make_metrics(schema, keyvals) + if metrics is not None: + return metrics return super(H2OResponse, cls).__new__(cls, keyvals) # def __getattr__(self, key): @@ -913,6 +921,9 @@ def __new__(cls, keyvals): # return None +_classes_defined_from_schema_ = [H2OCluster, H2OErrorV3, H2OModelBuilderErrorV3] + + # Find the exception that occurs on invalid JSON input JSONDecodeError, _r = None, None try: diff --git a/h2o-py/h2o/display.py b/h2o-py/h2o/display.py index 063c64ff273f..6815c19a1f79 100644 --- a/h2o-py/h2o/display.py +++ b/h2o-py/h2o/display.py @@ -14,10 +14,10 @@ import tabulate # noinspection PyUnresolvedReferences -from h2o.utils.compatibility import * # NOQA -from h2o.utils.compatibility import str2 as str, bytes2 as bytes -from h2o.utils.shared_utils import can_use_pandas -from h2o.utils.threading import local_context, local_context_safe, local_env +from .utils.compatibility import * # NOQA +from .utils.compatibility import str2 as str, bytes2 as bytes +from .utils.shared_utils import can_use_pandas +from .utils.threading import local_context, local_context_safe, local_env __no_export = set(dir()) # all variables defined above this are not exported diff --git a/h2o-py/h2o/estimators/__init__.py b/h2o-py/h2o/estimators/__init__.py index 6a55f7c90af5..766e1678b950 100644 --- a/h2o-py/h2o/estimators/__init__.py +++ b/h2o-py/h2o/estimators/__init__.py @@ -28,7 +28,6 @@ from .model_selection import H2OModelSelectionEstimator from .naive_bayes import H2ONaiveBayesEstimator from .pca import H2OPrincipalComponentAnalysisEstimator -from .pipeline import H2OPipeline from .psvm import H2OSupportVectorMachineEstimator from .random_forest import H2ORandomForestEstimator from .rulefit import H2ORuleFitEstimator @@ -67,8 +66,8 @@ def create_estimator(algo, **params): "H2OExtendedIsolationForestEstimator", "H2OGeneralizedAdditiveEstimator", "H2OGradientBoostingEstimator", "H2OGenericEstimator", "H2OGeneralizedLinearEstimator", "H2OGeneralizedLowRankEstimator", "H2OInfogram", "H2OIsolationForestEstimator", "H2OIsotonicRegressionEstimator", "H2OKMeansEstimator", "H2OModelSelectionEstimator", - "H2ONaiveBayesEstimator", "H2OPrincipalComponentAnalysisEstimator", "H2OPipeline", - "H2OSupportVectorMachineEstimator", "H2ORandomForestEstimator", "H2ORuleFitEstimator", - "H2OStackedEnsembleEstimator", "H2OSingularValueDecompositionEstimator", "H2OTargetEncoderEstimator", - "H2OUpliftRandomForestEstimator", "H2OWord2vecEstimator", "H2OXGBoostEstimator" + 
"H2ONaiveBayesEstimator", "H2OPrincipalComponentAnalysisEstimator", "H2OSupportVectorMachineEstimator", + "H2ORandomForestEstimator", "H2ORuleFitEstimator", "H2OStackedEnsembleEstimator", + "H2OSingularValueDecompositionEstimator", "H2OTargetEncoderEstimator", "H2OUpliftRandomForestEstimator", + "H2OWord2vecEstimator", "H2OXGBoostEstimator" ) diff --git a/h2o-py/h2o/estimators/adaboost.py b/h2o-py/h2o/estimators/adaboost.py index 5323d3ec84e9..467d4f07045d 100644 --- a/h2o-py/h2o/estimators/adaboost.py +++ b/h2o-py/h2o/estimators/adaboost.py @@ -264,3 +264,4 @@ def seed(self, seed): assert_is_type(seed, None, int) self._parms["seed"] = seed + diff --git a/h2o-py/h2o/estimators/aggregator.py b/h2o-py/h2o/estimators/aggregator.py index f95b1083ac23..274e84852536 100644 --- a/h2o-py/h2o/estimators/aggregator.py +++ b/h2o-py/h2o/estimators/aggregator.py @@ -408,6 +408,7 @@ def export_checkpoints_dir(self, export_checkpoints_dir): assert_is_type(export_checkpoints_dir, None, str) self._parms["export_checkpoints_dir"] = export_checkpoints_dir + @property def aggregated_frame(self): if (self._model_json is not None @@ -423,4 +424,3 @@ def mapping_frame(self): if mj.get("output", {}).get("mapping_frame", {}).get("name") is not None: mapping_frame_name = mj["output"]["mapping_frame"]["name"] return H2OFrame.get_frame(mapping_frame_name) - diff --git a/h2o-py/h2o/estimators/anovaglm.py b/h2o-py/h2o/estimators/anovaglm.py index ea01bac719c5..6fdb66c88f7c 100644 --- a/h2o-py/h2o/estimators/anovaglm.py +++ b/h2o-py/h2o/estimators/anovaglm.py @@ -819,6 +819,7 @@ def type(self, type): assert_is_type(type, None, int) self._parms["type"] = type + @property def Lambda(self): """DEPRECATED. Use ``self.lambda_`` instead""" @@ -835,4 +836,3 @@ def result(self): :return: the H2OFrame that contains information about the model building process like for modelselection and anovaglm. 
""" return H2OFrame._expr(expr=ExprNode("result", ASTId(self.key)))._frame(fill_cache=True) - diff --git a/h2o-py/h2o/estimators/coxph.py b/h2o-py/h2o/estimators/coxph.py index 9e88edb22ee2..c0771c142795 100644 --- a/h2o-py/h2o/estimators/coxph.py +++ b/h2o-py/h2o/estimators/coxph.py @@ -570,6 +570,7 @@ def single_node_mode(self, single_node_mode): assert_is_type(single_node_mode, None, bool) self._parms["single_node_mode"] = single_node_mode + @property def baseline_hazard_frame(self): if (self._model_json is not None @@ -583,4 +584,3 @@ def baseline_survival_frame(self): and self._model_json.get("output", {}).get("baseline_survival", {}).get("name") is not None): baseline_survival_name = self._model_json["output"]["baseline_survival"]["name"] return H2OFrame.get_frame(baseline_survival_name) - diff --git a/h2o-py/h2o/estimators/decision_tree.py b/h2o-py/h2o/estimators/decision_tree.py index 8fe69a97cf25..e598396b2a82 100644 --- a/h2o-py/h2o/estimators/decision_tree.py +++ b/h2o-py/h2o/estimators/decision_tree.py @@ -186,3 +186,4 @@ def min_rows(self, min_rows): assert_is_type(min_rows, None, int) self._parms["min_rows"] = min_rows + diff --git a/h2o-py/h2o/estimators/deeplearning.py b/h2o-py/h2o/estimators/deeplearning.py index e78ffc557676..8c71881eea64 100644 --- a/h2o-py/h2o/estimators/deeplearning.py +++ b/h2o-py/h2o/estimators/deeplearning.py @@ -3260,6 +3260,7 @@ def gainslift_bins(self, gainslift_bins): self._parms["gainslift_bins"] = gainslift_bins + class H2OAutoEncoderEstimator(H2ODeepLearningEstimator): """ :examples: @@ -3279,4 +3280,3 @@ class H2OAutoEncoderEstimator(H2ODeepLearningEstimator): def __init__(self, **kwargs): super(H2OAutoEncoderEstimator, self).__init__(**kwargs) self.autoencoder = True - diff --git a/h2o-py/h2o/estimators/extended_isolation_forest.py b/h2o-py/h2o/estimators/extended_isolation_forest.py index 8eb5f9d42f3f..5be8ae966918 100644 --- a/h2o-py/h2o/estimators/extended_isolation_forest.py +++ b/h2o-py/h2o/estimators/extended_isolation_forest.py @@ -333,3 +333,4 @@ def disable_training_metrics(self, disable_training_metrics): assert_is_type(disable_training_metrics, None, bool) self._parms["disable_training_metrics"] = disable_training_metrics + diff --git a/h2o-py/h2o/estimators/gam.py b/h2o-py/h2o/estimators/gam.py index bcba5dcf4a4c..7eb2f88fe072 100644 --- a/h2o-py/h2o/estimators/gam.py +++ b/h2o-py/h2o/estimators/gam.py @@ -1650,4 +1650,3 @@ def get_gam_knot_column_names(self): raise H2OValueError("Knot locations are not available. 
Please re-run with store_knot_locations=True") return self._model_json['output']['gam_knot_column_names'] - diff --git a/h2o-py/h2o/estimators/gbm.py b/h2o-py/h2o/estimators/gbm.py index 93a22a1bf3ff..2452b5136390 100644 --- a/h2o-py/h2o/estimators/gbm.py +++ b/h2o-py/h2o/estimators/gbm.py @@ -2239,3 +2239,4 @@ def auto_rebalance(self, auto_rebalance): assert_is_type(auto_rebalance, None, bool) self._parms["auto_rebalance"] = auto_rebalance + diff --git a/h2o-py/h2o/estimators/generic.py b/h2o-py/h2o/estimators/generic.py index 60483aee7239..37067ea0a6db 100644 --- a/h2o-py/h2o/estimators/generic.py +++ b/h2o-py/h2o/estimators/generic.py @@ -100,6 +100,7 @@ def path(self, path): assert_is_type(path, None, str) self._parms["path"] = path + @staticmethod def from_file(file=str, model_id=None): """ @@ -126,4 +127,3 @@ def from_file(file=str, model_id=None): model.train() return model - diff --git a/h2o-py/h2o/estimators/glm.py b/h2o-py/h2o/estimators/glm.py index e8302692feb1..fd4062c78b81 100644 --- a/h2o-py/h2o/estimators/glm.py +++ b/h2o-py/h2o/estimators/glm.py @@ -2599,4 +2599,3 @@ def makeGLMModel(model, coefs, threshold=.5): m = H2OGeneralizedLinearEstimator() m._resolve_model(model_json["model_id"]["name"], model_json) return m - diff --git a/h2o-py/h2o/estimators/glrm.py b/h2o-py/h2o/estimators/glrm.py index d79b9dbed0ac..8574c8a6f186 100644 --- a/h2o-py/h2o/estimators/glrm.py +++ b/h2o-py/h2o/estimators/glrm.py @@ -1054,6 +1054,7 @@ def export_checkpoints_dir(self, export_checkpoints_dir): assert_is_type(export_checkpoints_dir, None, str) self._parms["export_checkpoints_dir"] = export_checkpoints_dir + def transform_frame(self, fr): """ GLRM performs A=X*Y during training. When a new dataset is given, GLRM will perform Anew = Xnew*Y. When @@ -1061,4 +1062,3 @@ def transform_frame(self, fr): :return: an H2OFrame that contains Xnew. """ return H2OFrame._expr(expr=ExprNode("transform", ASTId(self.key), ASTId(fr.key)))._frame(fill_cache=True) - diff --git a/h2o-py/h2o/estimators/infogram.py b/h2o-py/h2o/estimators/infogram.py index 2b65c3bb54de..6fe54b08c1a7 100644 --- a/h2o-py/h2o/estimators/infogram.py +++ b/h2o-py/h2o/estimators/infogram.py @@ -860,6 +860,7 @@ def top_n_features(self, top_n_features): assert_is_type(top_n_features, None, int) self._parms["top_n_features"] = top_n_features + def _extract_x_from_model(self): """ extract admissible features from an Infogram model. 
@@ -1189,4 +1190,3 @@ def train_subset_models(self, model_class, y, training_frame, test_frame, protec if protected_columns is None or len(protected_columns) == 0: return make_leaderboard(models, leaderboard_frame=test_frame) return disparate_analysis(models, test_frame, protected_columns, reference, favorable_class) - diff --git a/h2o-py/h2o/estimators/isolation_forest.py b/h2o-py/h2o/estimators/isolation_forest.py index b6f931f30cf5..64f2ae58afc9 100644 --- a/h2o-py/h2o/estimators/isolation_forest.py +++ b/h2o-py/h2o/estimators/isolation_forest.py @@ -744,3 +744,4 @@ def validation_response_column(self, validation_response_column): assert_is_type(validation_response_column, None, str) self._parms["validation_response_column"] = validation_response_column + diff --git a/h2o-py/h2o/estimators/isotonicregression.py b/h2o-py/h2o/estimators/isotonicregression.py index d8f05f5189c1..d5d72eaee640 100644 --- a/h2o-py/h2o/estimators/isotonicregression.py +++ b/h2o-py/h2o/estimators/isotonicregression.py @@ -291,3 +291,4 @@ def fold_column(self, fold_column): assert_is_type(fold_column, None, str) self._parms["fold_column"] = fold_column + diff --git a/h2o-py/h2o/estimators/kmeans.py b/h2o-py/h2o/estimators/kmeans.py index 506f452d7dd5..c63249385eb4 100644 --- a/h2o-py/h2o/estimators/kmeans.py +++ b/h2o-py/h2o/estimators/kmeans.py @@ -741,3 +741,4 @@ def cluster_size_constraints(self, cluster_size_constraints): assert_is_type(cluster_size_constraints, None, [int]) self._parms["cluster_size_constraints"] = cluster_size_constraints + diff --git a/h2o-py/h2o/estimators/model_selection.py b/h2o-py/h2o/estimators/model_selection.py index ae7a1d42e370..8632b1a12fb9 100644 --- a/h2o-py/h2o/estimators/model_selection.py +++ b/h2o-py/h2o/estimators/model_selection.py @@ -1326,6 +1326,7 @@ def multinode_mode(self, multinode_mode): assert_is_type(multinode_mode, None, bool) self._parms["multinode_mode"] = multinode_mode + def get_regression_influence_diagnostics(self, predictor_size=None): """ Get the regression influence diagnostics frames for all models with different number of predictors. If a @@ -1549,4 +1550,3 @@ def get_best_model_predictors(self): :return: a list of best predictors subset """ return self._model_json["output"]["best_predictors_subset"] - diff --git a/h2o-py/h2o/estimators/naive_bayes.py b/h2o-py/h2o/estimators/naive_bayes.py index 8c43df51e7e7..31123f55087e 100644 --- a/h2o-py/h2o/estimators/naive_bayes.py +++ b/h2o-py/h2o/estimators/naive_bayes.py @@ -906,3 +906,4 @@ def auc_type(self, auc_type): assert_is_type(auc_type, None, Enum("auto", "none", "macro_ovr", "weighted_ovr", "macro_ovo", "weighted_ovo")) self._parms["auc_type"] = auc_type + diff --git a/h2o-py/h2o/estimators/pca.py b/h2o-py/h2o/estimators/pca.py index bc324cd7e60d..7ecca50f3795 100644 --- a/h2o-py/h2o/estimators/pca.py +++ b/h2o-py/h2o/estimators/pca.py @@ -508,6 +508,7 @@ def export_checkpoints_dir(self, export_checkpoints_dir): assert_is_type(export_checkpoints_dir, None, str) self._parms["export_checkpoints_dir"] = export_checkpoints_dir + def init_for_pipeline(self): """ Returns H2OPCA object which implements fit and transform method to be used in sklearn.Pipeline properly. 
@@ -532,4 +533,3 @@ def init_for_pipeline(self): var_names = list(dict(inspect.getmembers(H2OPCA.__init__.__code__))['co_varnames']) parameters = {k: v for k, v in self._parms.items() if k in var_names} return H2OPCA(**parameters) - diff --git a/h2o-py/h2o/estimators/pipeline.py b/h2o-py/h2o/estimators/pipeline.py deleted file mode 100644 index 6d317e7cef94..000000000000 --- a/h2o-py/h2o/estimators/pipeline.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -# -# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_python.py -# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details) -# - -import h2o -from h2o.base import Keyed -from h2o.display import H2ODisplay, repr_def -from h2o.expr import ASTId, ExprNode -from h2o.schemas import H2OSchema, register_schema_handler -from h2o.estimators.estimator_base import H2OEstimator -from h2o.exceptions import H2OValueError -from h2o.frame import H2OFrame -from h2o.utils.typechecks import assert_is_type, Enum, numeric - - -class H2OPipeline(H2OEstimator): - """ - Pipeline - - """ - - algo = "pipeline" - supervised_learning = None - - def __init__(self, - model_id=None, # type: Optional[Union[None, str, H2OEstimator]] - ): - """ - :param model_id: Destination id for this model; auto-generated if not specified. - Defaults to ``None``. - :type model_id: Union[None, str, H2OEstimator], optional - """ - super(H2OPipeline, self).__init__() - self._parms = {} - self._id = self._parms['model_id'] = model_id - - @staticmethod - def get_transformer(id): - assert_is_type(id, str) - return h2o.api("GET /3/Pipeline/DataTransformer/%s" % id) - - @property - def transformers(self): - trs_json = self._model_json['output']['transformers'] - return None if (trs_json is None) else [H2OPipeline.get_transformer(k['name']) for k in trs_json] - - @property - def estimator_model(self): - m_json = self._model_json['output']['estimator'] - return None if (m_json is None or m_json['name'] is None) else h2o.get_model(m_json['name']) - - def transform(self, fr): - """ - Applies all the pipeline transformers to the given input frame. - :return: the transformed frame, as it would be passed to `estimator_model`, if calling `predict` instead. - """ - return H2OFrame._expr(expr=ExprNode("transform", ASTId(self.key), ASTId(fr.key)))._frame(fill_cache=True) - - -class H2ODataTransformer(Keyed, H2ODisplay): - @classmethod - def make(cls, kvs): - dt = H2ODataTransformer(**{k: v for k, v in kvs if k not in H2OSchema._ignored_schema_keys_}) - dt._json = kvs - return dt - - def __init__(self, key=None, name=None, description=None): - self._json = None - self._id = key['name'] - self._name = name - self._description = description - - @property - def key(self): - return self._id - - def _repr_(self): - return repr(self._json) - - def _str_(self, verbosity=None): - return repr_def(self) - - -# self-register transformer class: done as soon as `h2o.estimators` is loaded, which means as soon as h2o.h2o is... 
-register_schema_handler("DataTransformerV3", H2ODataTransformer) - diff --git a/h2o-py/h2o/estimators/psvm.py b/h2o-py/h2o/estimators/psvm.py index 5bd3ece1ab3c..0261660f7652 100644 --- a/h2o-py/h2o/estimators/psvm.py +++ b/h2o-py/h2o/estimators/psvm.py @@ -553,3 +553,4 @@ def seed(self, seed): assert_is_type(seed, None, int) self._parms["seed"] = seed + diff --git a/h2o-py/h2o/estimators/rulefit.py b/h2o-py/h2o/estimators/rulefit.py index c7718e3b7933..529b371780ea 100644 --- a/h2o-py/h2o/estimators/rulefit.py +++ b/h2o-py/h2o/estimators/rulefit.py @@ -408,4 +408,3 @@ def predict_rules(self, frame, rule_ids): from h2o.expr import ExprNode assert_is_type(frame, H2OFrame) return H2OFrame._expr(expr=ExprNode("rulefit.predict.rules", self, frame, rule_ids)) - diff --git a/h2o-py/h2o/estimators/stackedensemble.py b/h2o-py/h2o/estimators/stackedensemble.py index bb2fb4c58946..6a2117c007b2 100644 --- a/h2o-py/h2o/estimators/stackedensemble.py +++ b/h2o-py/h2o/estimators/stackedensemble.py @@ -925,6 +925,7 @@ def gainslift_bins(self, gainslift_bins): assert_is_type(gainslift_bins, None, int) self._parms["gainslift_bins"] = gainslift_bins + def metalearner(self): """Print the metalearner of an H2OStackedEnsembleEstimator. @@ -1046,4 +1047,3 @@ def extend_parms(parms): raise H2OResponseError("Meta learner didn't get to be trained in time. " "Try increasing max_runtime_secs or setting it to 0 (unlimited).") return self - diff --git a/h2o-py/h2o/estimators/svd.py b/h2o-py/h2o/estimators/svd.py index 84d3bbaae3b6..996790e57af9 100644 --- a/h2o-py/h2o/estimators/svd.py +++ b/h2o-py/h2o/estimators/svd.py @@ -432,6 +432,7 @@ def export_checkpoints_dir(self, export_checkpoints_dir): assert_is_type(export_checkpoints_dir, None, str) self._parms["export_checkpoints_dir"] = export_checkpoints_dir + def init_for_pipeline(self): """ Returns H2OSVD object which implements fit and transform method to be used in sklearn.Pipeline properly. 
@@ -457,4 +458,3 @@ def init_for_pipeline(self): var_names = list(dict(inspect.getmembers(H2OSVD.__init__.__code__))['co_varnames']) parameters = {k: v for k, v in self._parms.items() if k in var_names} return H2OSVD(**parameters) - diff --git a/h2o-py/h2o/estimators/targetencoder.py b/h2o-py/h2o/estimators/targetencoder.py index 0345642d6512..e878dad4fb6e 100644 --- a/h2o-py/h2o/estimators/targetencoder.py +++ b/h2o-py/h2o/estimators/targetencoder.py @@ -452,4 +452,3 @@ def transform(self, frame, blending=None, inflection_point=None, smoothing=None, output = h2o.api("GET /3/TargetEncoderTransform", data=params) return h2o.get_frame(output["name"]) - diff --git a/h2o-py/h2o/estimators/uplift_random_forest.py b/h2o-py/h2o/estimators/uplift_random_forest.py index efa3bd71f347..7ebb1b13f133 100644 --- a/h2o-py/h2o/estimators/uplift_random_forest.py +++ b/h2o-py/h2o/estimators/uplift_random_forest.py @@ -668,3 +668,4 @@ def stopping_tolerance(self, stopping_tolerance): assert_is_type(stopping_tolerance, None, numeric) self._parms["stopping_tolerance"] = stopping_tolerance + diff --git a/h2o-py/h2o/estimators/word2vec.py b/h2o-py/h2o/estimators/word2vec.py index 828c9fa74c7d..f4930384ae8b 100644 --- a/h2o-py/h2o/estimators/word2vec.py +++ b/h2o-py/h2o/estimators/word2vec.py @@ -419,6 +419,7 @@ def export_checkpoints_dir(self, export_checkpoints_dir): assert_is_type(export_checkpoints_dir, None, str) self._parms["export_checkpoints_dir"] = export_checkpoints_dir + @staticmethod def from_external(external=H2OFrame): """ @@ -458,4 +459,3 @@ def _determine_vec_size(pre_trained): pre_trained.frame_id) return pre_trained.dim[1] - 1 - diff --git a/h2o-py/h2o/estimators/xgboost.py b/h2o-py/h2o/estimators/xgboost.py index 789042901e62..569181c9bd7f 100644 --- a/h2o-py/h2o/estimators/xgboost.py +++ b/h2o-py/h2o/estimators/xgboost.py @@ -2532,6 +2532,7 @@ def score_eval_metric_only(self, score_eval_metric_only): assert_is_type(score_eval_metric_only, None, bool) self._parms["score_eval_metric_only"] = score_eval_metric_only + @staticmethod def available(): """ @@ -2598,4 +2599,3 @@ def convert_H2OXGBoostParams_2_XGBoostParams(self): paramsSet = self.full_parameters return nativeXGBoostParams, paramsSet['ntrees']['actual_value'] - diff --git a/h2o-py/h2o/expr.py b/h2o-py/h2o/expr.py index 628ffa2a4b6c..6c6a4c340a6f 100644 --- a/h2o-py/h2o/expr.py +++ b/h2o-py/h2o/expr.py @@ -18,9 +18,10 @@ import tabulate import h2o -from h2o.exceptions import H2OConnectionError -from h2o.expr_optimizer import optimize +from h2o.backend.connection import H2OConnectionError from h2o.utils.shared_utils import _is_fr, _py_tmp_key +from h2o.model.model_base import ModelBase +from h2o.expr_optimizer import optimize class ExprNode(object): @@ -190,7 +191,7 @@ def _arg_to_expr(arg): return "[%d:%s]" % (start, str(stop - start)) else: return "[%d:%s:%d]" % (start, str((stop - start + step - 1) // step), step) - if isinstance(arg, h2o.model.ModelBase): + if isinstance(arg, ModelBase): return arg.model_id # Number representation without Py2 L suffix enforced if isinstance(arg, numbers.Integral): diff --git a/h2o-py/h2o/h2o.py b/h2o-py/h2o/h2o.py index 457a0394c258..542cb117d872 100644 --- a/h2o-py/h2o/h2o.py +++ b/h2o-py/h2o/h2o.py @@ -17,7 +17,14 @@ from .base import Keyed from .estimators import create_estimator from .estimators.generic import H2OGenericEstimator -from .exceptions import H2OError, H2OConnectionError, H2OValueError, H2ODependencyWarning, H2ODeprecationWarning +from .exceptions import H2OError, H2ODeprecationWarning 
+from .estimators.gbm import H2OGradientBoostingEstimator +from .estimators.glm import H2OGeneralizedLinearEstimator +from .estimators.xgboost import H2OXGBoostEstimator +from .estimators.infogram import H2OInfogram +from .estimators.deeplearning import H2OAutoEncoderEstimator, H2ODeepLearningEstimator +from .estimators.extended_isolation_forest import H2OExtendedIsolationForestEstimator +from .exceptions import H2OConnectionError, H2OValueError, H2ODependencyWarning from .expr import ExprNode from .frame import H2OFrame from .grid.grid_search import H2OGridSearch @@ -510,8 +517,8 @@ def load_grid(grid_file_path, load_params_references=False): :examples: >>> from collections import OrderedDict - >>> from h2o.grid import H2OGridSearch - >>> from h2o.estimators import H2OGradientBoostingEstimator + >>> from h2o.grid.grid_search import H2OGridSearch + >>> from h2o.estimators.gbm import H2OGradientBoostingEstimator >>> train = h2o.import_file("http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv") # Run GBM Grid Search >>> ntrees_opts = [1, 3] @@ -555,8 +562,8 @@ def save_grid(grid_directory, grid_id, save_params_references=False, export_cros :examples: >>> from collections import OrderedDict - >>> from h2o.grid import H2OGridSearch - >>> from h2o.estimators import H2OGradientBoostingEstimator + >>> from h2o.grid.grid_search import H2OGridSearch + >>> from h2o.estimators.gbm import H2OGradientBoostingEstimator >>> train = h2o.import_file("http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv") # Run GBM Grid Search >>> ntrees_opts = [1, 3] @@ -1051,7 +1058,6 @@ def models(): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator, H2OXGBoostEstimator >>> airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> airlines["Year"]= airlines["Year"].asfactor() >>> airlines["Month"]= airlines["Month"].asfactor() @@ -1077,7 +1083,6 @@ def get_model(model_id): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> airlines["Year"]= airlines["Year"].asfactor() >>> airlines["Month"]= airlines["Month"].asfactor() @@ -1117,7 +1122,7 @@ def get_grid(grid_id): :examples: - >>> from h2o.grid import H2OGridSearch + >>> from h2o.grid.grid_search import H2OGridSearch >>> from h2o.estimators import H2OGradientBoostingEstimator >>> airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> x = ["DayofMonth", "Month"] @@ -1187,7 +1192,6 @@ def no_progress(): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> h2o.no_progress() >>> airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> x = ["DayofMonth", "Month"] @@ -1204,7 +1208,6 @@ def show_progress(): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> h2o.no_progress() >>> airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> x = ["DayofMonth", "Month"] @@ -1391,7 +1394,6 @@ def download_pojo(model, path="", get_jar=True, jar_name=""): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> h2o_df = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv.zip") >>> h2o_df['CAPSULE'] = 
h2o_df['CAPSULE'].asfactor() >>> from h2o.estimators.glm import H2OGeneralizedLinearEstimator @@ -1493,7 +1495,7 @@ def save_model(model, path="", force=False, export_cross_validation_predictions= :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator + >>> from h2o.estimators.glm import H2OGeneralizedLinearEstimator >>> h2o_df = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv.zip") >>> my_model = H2OGeneralizedLinearEstimator(family = "binomial") >>> my_model.train(y = "CAPSULE", @@ -1529,7 +1531,7 @@ def download_model(model, path="", export_cross_validation_predictions=False, fi :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator + >>> from h2o.estimators.glm import H2OGeneralizedLinearEstimator >>> h2o_df = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv.zip") >>> my_model = H2OGeneralizedLinearEstimator(family = "binomial") >>> my_model.train(y = "CAPSULE", @@ -1575,7 +1577,6 @@ def load_model(path): :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> training_data = h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip") >>> predictors = ["Origin", "Dest", "Year", "UniqueCarrier", ... "DayOfWeek", "Month", "Distance", "FlightNum"] @@ -1622,7 +1623,6 @@ def export_file(frame, path, force=False, sep=",", compression=None, parts=1, he :examples: - >>> from h2o.estimators import H2OGeneralizedLinearEstimator >>> h2o_df = h2o.import_file("http://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate.csv") >>> h2o_df['CAPSULE'] = h2o_df['CAPSULE'].asfactor() >>> rand_vec = h2o_df.runif(1234) @@ -2026,7 +2026,6 @@ def make_metrics(predicted, actual, domain=None, distribution=None, weights=None the new thresholds from the predicted data. 
:examples: - >>> from h2o.estimators import H2OGradientBoostingEstimator >>> fr = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv.zip") >>> fr["CAPSULE"] = fr["CAPSULE"].asfactor() >>> fr["RACE"] = fr["RACE"].asfactor() @@ -2386,7 +2385,7 @@ def print_mojo(mojo_path, format="json", tree_index=None): :example: >>> import json - >>> from h2o.estimators import H2OGradientBoostingEstimator + >>> from h2o.estimators.gbm import H2OGradientBoostingEstimator >>> prostate = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv") >>> prostate["CAPSULE"] = prostate["CAPSULE"].asfactor() >>> gbm_h2o = H2OGradientBoostingEstimator(ntrees = 5, diff --git a/h2o-py/h2o/pipeline/__init__.py b/h2o-py/h2o/pipeline/__init__.py index 285e84550e68..082f127c5374 100644 --- a/h2o-py/h2o/pipeline/__init__.py +++ b/h2o-py/h2o/pipeline/__init__.py @@ -1,3 +1,9 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# +# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_python.py +# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details) +# from .mojo_pipeline import H2OMojoPipeline __all__ = ( diff --git a/h2o-py/h2o/schemas/__init__.py b/h2o-py/h2o/schemas/__init__.py index 2a645a046694..048943138f71 100644 --- a/h2o-py/h2o/schemas/__init__.py +++ b/h2o-py/h2o/schemas/__init__.py @@ -1,53 +1,6 @@ -from functools import partial - from .error import H2OErrorV3, H2OModelBuilderErrorV3 from .metadata import H2OMetadataV3 -from .schema import H2OSchema +from .schema import H2OSchema, define_classes_from_schema -__all__ = ['H2OSchema', +__all__ = ['H2OSchema', 'define_classes_from_schema', 'H2OErrorV3', 'H2OModelBuilderErrorV3', 'H2OMetadataV3'] - - -__schema_handlers = [] - - -def get_schema_handler(schema): - for s, h in __schema_handlers: - if s == schema: - return h - elif callable(s) and s(schema): - return partial(h, schema) - - -def register_schema_handler(schema, handler): - """ - :param Union[str, Callable] schema: a string representing a schema name, or a predicate on a schema name. - :param Callable handler: a function taking the schema payload as parameter, represented as a list or key-value tuples. - If the handler is a factory object with a `make` method, then that method will be used as the handler. - if the schema is a predicate, then the handler function must accept 2 parameters: the schema name + the schema payload. 
- """ - if hasattr(handler, "make") and callable(handler.make): - handler = handler.make - __schema_handlers.append((schema, handler)) - - -def register_schemas(): - from h2o.backend import H2OCluster - from h2o.model.metrics import make_metrics - from h2o.two_dim_table import H2OTwoDimTable - - for (schema, handler) in [ - ("MetadataV3", H2OMetadataV3), - ("CloudV3", H2OCluster), - ("H2OErrorV3", H2OErrorV3), - ("H2OModelBuilderErrorV3", H2OModelBuilderErrorV3), - ("TwoDimTableV3", H2OTwoDimTable), - (lambda s: s.startswith("ModelMetrics"), make_metrics), - ]: - register_schema_handler(schema, handler) - - -def define_classes_from_schema(conn): - from h2o.backend import H2OCluster - for cls in [H2OCluster, H2OErrorV3, H2OModelBuilderErrorV3]: - cls.define_from_schema(conn) diff --git a/h2o-py/h2o/schemas/schema.py b/h2o-py/h2o/schemas/schema.py index 19703a0c48f4..45fc047dc5df 100644 --- a/h2o-py/h2o/schemas/schema.py +++ b/h2o-py/h2o/schemas/schema.py @@ -11,6 +11,11 @@ from h2o.exceptions import H2OConnectionError, H2OServerError +def define_classes_from_schema(classes, connection): + for cls in classes: + cls.define_from_schema(connection) + + class H2OSchema(object): _ignored_schema_keys_ = {"__meta", "_exclude_fields", "__schema"} diff --git a/h2o-py/h2o/sklearn/__init__.py b/h2o-py/h2o/sklearn/__init__.py index 14f4f0d4ffef..dd5157884a7c 100644 --- a/h2o-py/h2o/sklearn/__init__.py +++ b/h2o-py/h2o/sklearn/__init__.py @@ -197,7 +197,6 @@ def h2o_connection(**init_args): _excluded_estimators = ( # e.g. abstract classes 'H2OEstimator', 'H2OTransformer', - 'H2OPipeline', 'H2OInfogram', 'H2OANOVAGLMEstimator', # fully disabled as it does not support `predict` method. 'H2OModelSelectionEstimator', # fully disabled as it does no support `predict` method. diff --git a/h2o-py/h2o/utils/shared_utils.py b/h2o-py/h2o/utils/shared_utils.py index b3e0358452c6..67b622fd62c6 100644 --- a/h2o-py/h2o/utils/shared_utils.py +++ b/h2o-py/h2o/utils/shared_utils.py @@ -44,6 +44,7 @@ def __subclasshook__(cls, C): return NotImplemented +from h2o.backend.server import H2OLocalServer from h2o.exceptions import H2OValueError from h2o.utils.typechecks import assert_is_type, is_type, numeric from h2o.utils.threading import local_env @@ -446,8 +447,6 @@ def mojo_predict_csv(input_csv_path, mojo_zip_path, output_csv_path=None, genmod :param extra_cmd_args: Optional, a list of additional arguments to append to genmodel.jar's command line. 
:return: List of computed predictions """ - from h2o.backend.server import H2OLocalServer - default_java_options = '-Xmx4g -XX:ReservedCodeCacheSize=256m' prediction_output_file = 'prediction.csv' diff --git a/h2o-py/tests/testdir_algos/automl/pyunit_automl_preprocessing.py b/h2o-py/tests/testdir_algos/automl/pyunit_automl_preprocessing.py index 863c4420c50e..db5c4b8f7ea1 100644 --- a/h2o-py/tests/testdir_algos/automl/pyunit_automl_preprocessing.py +++ b/h2o-py/tests/testdir_algos/automl/pyunit_automl_preprocessing.py @@ -21,26 +21,19 @@ def import_dataset(seed=0, mode='binary'): return pu.ns(train=fr[0], test=fr[1], target=target) -def check_mojo_pojo_availability(model): +def check_mojo_pojo_availability(model_id): + model = h2o.get_model(model_id) if model.algo in ['stackedensemble']: assert not model.have_mojo, "Model %s should not support MOJO" % model.model_id # because base models don't assert not model.have_pojo, "Model %s should not support POJO" % model.model_id - elif model.algo in ['deeplearning']: + elif model.algo in ['glm', 'deeplearning']: assert model.have_mojo, "Model %s should support MOJO" % model.model_id assert model.have_pojo, "Model %s should support POJO" % model.model_id else: - assert model.algo in ['pipeline'] assert not model.have_mojo, "Model %s should not support MOJO" % model.model_id assert not model.have_pojo, "Model %s should not support POJO" % model.model_id -def check_predict(model, test): - predictions = model.predict(test) - assert predictions is not None - assert predictions.nrows == test.nrows - # print(predictions) - - def test_target_encoding_binary(): ds = import_dataset(mode='binary') aml = H2OAutoML(project_name="automl_with_te_binary", @@ -50,17 +43,13 @@ def test_target_encoding_binary(): aml.train(y=ds.target, training_frame=ds.train, leaderboard_frame=ds.test) lb = aml.leaderboard print(lb) - model_ids = list(h2o.as_list(lb['model_id'])['model_id']) - assert any(m.startswith("Pipeline") for m in model_ids), "at least a Pipeline model should have been trained" # we can't really verify from client if TE was correctly applied... so just using a poor man's check: mem_keys = h2o.ls().key # print(mem_keys) - assert any(k == "default_TE_1_model" for k in mem_keys), "a TE model should have been trained" + assert any(k.startswith("TargetEncoding_AutoML") for k in mem_keys) for mid in get_partitioned_model_names(lb).all: - model = h2o.get_model(mid) - check_mojo_pojo_availability(model) - check_predict(model, ds.test) - + check_mojo_pojo_availability(mid) + def test_target_encoding_multiclass(): ds = import_dataset(mode='multiclass') @@ -71,16 +60,12 @@ def test_target_encoding_multiclass(): aml.train(y=ds.target, training_frame=ds.train, leaderboard_frame=ds.test) lb = aml.leaderboard print(lb) - model_ids = list(h2o.as_list(lb['model_id'])['model_id']) - assert any(m.startswith("Pipeline") for m in model_ids), "at least a Pipeline model should have been trained" # we can't really verify from client if TE was correctly applied... 
so just using a poor man's check: mem_keys = h2o.ls().key # print(mem_keys) - assert any(k == "default_TE_1_model" for k in mem_keys), "a TE model should have been trained" + assert any(k.startswith("TargetEncoding_AutoML") for k in mem_keys) for mid in get_partitioned_model_names(lb).all: - model = h2o.get_model(mid) - check_mojo_pojo_availability(model) - check_predict(model, ds.test) + check_mojo_pojo_availability(mid) def test_target_encoding_regression(): @@ -92,16 +77,12 @@ def test_target_encoding_regression(): aml.train(y=ds.target, training_frame=ds.train, leaderboard_frame=ds.test) lb = aml.leaderboard print(lb) - model_ids = list(h2o.as_list(lb['model_id'])['model_id']) - assert any(m.startswith("Pipeline") for m in model_ids), "at least a Pipeline model should have been trained" # we can't really verify from client if TE was correctly applied... so just using a poor man's check: mem_keys = h2o.ls().key # print(mem_keys) - assert any(k == "default_TE_1_model" for k in mem_keys), "a TE model should have been trained" + assert any(k.startswith("TargetEncoding_AutoML") for k in mem_keys) for mid in get_partitioned_model_names(lb).all: - model = h2o.get_model(mid) - check_mojo_pojo_availability(model) - check_predict(model, ds.test) + check_mojo_pojo_availability(mid) pu.run_tests([ diff --git a/h2o-py/tests/testdir_algos/glm/pyunit_PUBDEV_8150_warning_alpha_grid.py b/h2o-py/tests/testdir_algos/glm/pyunit_PUBDEV_8150_warning_alpha_grid.py index 49d6cd4cda8b..383363fe470b 100644 --- a/h2o-py/tests/testdir_algos/glm/pyunit_PUBDEV_8150_warning_alpha_grid.py +++ b/h2o-py/tests/testdir_algos/glm/pyunit_PUBDEV_8150_warning_alpha_grid.py @@ -1,6 +1,4 @@ from builtins import range -import contextlib -from io import StringIO import sys sys.path.insert(1,"../../../") import h2o @@ -8,6 +6,11 @@ from h2o.estimators.glm import H2OGeneralizedLinearEstimator from h2o.grid.grid_search import H2OGridSearch +try: # redirect python output + from StringIO import StringIO # for python 2 +except ImportError: + from io import StringIO # for python 3 + # This test is used to make sure when a user tries to set alpha in the hyper-parameter of gridsearch, a warning # should appear to tell the user to set the alpha array as a parameter in the algorithm. def grid_alpha_search(): @@ -22,17 +25,23 @@ def grid_alpha_search(): hyper_parameters = {'alpha': [0, 0.5]} # set hyper_parameters for grid search print("Create models with lambda_search") - err = StringIO() - with contextlib.redirect_stderr(err): - model_h2o_grid_search = H2OGridSearch(H2OGeneralizedLinearEstimator(family="tweedie", Lambda=0.5), - hyper_parameters) - model_h2o_grid_search.train(x=x, y=y, training_frame=hdf) + buffer = StringIO() # redirect output + sys.stderr=buffer + model_h2o_grid_search = H2OGridSearch(H2OGeneralizedLinearEstimator(family="tweedie", Lambda=0.5), + hyper_parameters) + model_h2o_grid_search.train(x=x, y=y, training_frame=hdf) + sys.stderr=sys.__stderr__ # redirect printout back to normal path # check and make sure we get the correct warning message warn_phrase = "Adding alpha array to hyperparameter runs slower with gridsearch." - warns = err.getvalue() - print("*** captured warning message: {0}".format(warns)) - assert warn_phrase in warns + try: # for python 2.7 + assert len(buffer.buflist)==warnNumber + print(buffer.buflist[0]) + assert warn_phrase in buffer.buflist[0] + except: # for python 3.
+ warns = buffer.getvalue() + print("*** captured warning message: {0}".format(warns)) + assert warn_phrase in warns if __name__ == "__main__": pyunit_utils.standalone_test(grid_alpha_search) diff --git a/h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py b/h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py index 562c7b57d9e9..79c5b6c66ca3 100644 --- a/h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py +++ b/h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py @@ -29,7 +29,6 @@ def grid_parallel(): gs.train(x=list(range(4)), y=4, training_frame=train, fold_column="fold_assignment") assert gs is not None # only six models are trained, since CV is not possible with min_rows=100 - print(gs.model_ids) assert len(gs.model_ids) == 6 diff --git a/h2o-py/tests_rest_smoke/testdir_multi_jvm/test_rest_api.py b/h2o-py/tests_rest_smoke/testdir_multi_jvm/test_rest_api.py index 6510800f0472..13c93693ae8e 100644 --- a/h2o-py/tests_rest_smoke/testdir_multi_jvm/test_rest_api.py +++ b/h2o-py/tests_rest_smoke/testdir_multi_jvm/test_rest_api.py @@ -23,7 +23,7 @@ algos = ['coxph', 'kmeans', 'deeplearning', 'drf', 'glm', 'gbm', 'pca', 'naivebayes', 'glrm', 'svd', 'isotonicregression', 'psvm', 'aggregator', 'word2vec', 'stackedensemble', 'xgboost', 'isolationforest', 'gam', 'generic', 'targetencoder', 'rulefit', 'extendedisolationforest', 'anovaglm', 'modelselection', - 'upliftdrf', 'infogram', 'dt', 'adaboost', 'pipeline'] + 'upliftdrf', 'infogram', 'dt', 'adaboost'] algo_additional_default_params = { 'grep' : { 'regex' : '.*' }, 'kmeans' : { 'k' : 2 }, diff --git a/h2o-r/h2o-package/R/classes.R b/h2o-r/h2o-package/R/classes.R index 245f51c6b627..296105062f62 100755 --- a/h2o-r/h2o-package/R/classes.R +++ b/h2o-r/h2o-package/R/classes.R @@ -868,30 +868,6 @@ setClass("H2OSegmentModelsFuture", slots = c(job_key = "character", segment_mode #' @export setClass("H2OSegmentModels", slots = c(segment_models_id = "character")) -#' H2O Data Transformer -#' -#' A representation of a transformer used in an H2O Pipeline -#' @slot id the unique identifier for the transformer. -#' @slot name the readable name for the transformer and its variants. -#' @slot description a description of what the transformer does on data. -#' @export -setClass("H2ODataTransformer", slots = c(id = "character", name = "character", description = "character")) - -#' @rdname h2o.keyof -setMethod("h2o.keyof", signature("H2ODataTransformer"), function(object) object@id) - -#' H2O Pipeline -#' -#' A representation of a pipeline model consisting in a sequence of transformers applied to data -#' and usually followed by a final estimator model. -#' @slot transformers the list of H2O Data Transformers in the pipeline. -#' @slot estimator_model the final estimator model. 
-setClass("H2OPipeline", contains="H2OModel", - slots = c( - transformers = "list", - estimator_model = "H2OModel" - )) - #' H2O Grid #' #' A class to contain the information about grid results diff --git a/h2o-r/h2o-package/R/kvstore.R b/h2o-r/h2o-package/R/kvstore.R index 55300d5c5045..aba10a750f82 100644 --- a/h2o-r/h2o-package/R/kvstore.R +++ b/h2o-r/h2o-package/R/kvstore.R @@ -239,7 +239,7 @@ h2o.getModel <- function(model_id) { names(model$random_coefficients) <- model$random_coefficients_table[,1] } } else { # with AnovaGLM - coefLen <- length(model$coefficients_table) + coefLen = length(model$coefficients_table) model$coefficients <- vector("list", coefLen) for (index in 1:coefLen) { model$coefficients[[index]] <- model$coefficients_table[[index]] diff --git a/h2o-r/h2o-package/R/models.R b/h2o-r/h2o-package/R/models.R index 2a344d173295..dcff26351981 100755 --- a/h2o-r/h2o-package/R/models.R +++ b/h2o-r/h2o-package/R/models.R @@ -682,17 +682,6 @@ setMethod("h2o.transform", signature("H2OWordEmbeddingModel"), function(model, w }) -#' -#' Transform the given data frame using the model if the latter supports transformations. -#' -#' @param model A trained model representing the transformation strategy (currently supported algorithms are `glrm` and `pipeline`). -#' @param data An H2OFrame on which the transformation is applied. -#' @return an H2OFrame object representing the transformed data. -#' @export -setMethod("h2o.transform", signature("H2OModel"), function(model, data) { - if (!model@algorithm %in% c("glrm", "pipeline")) stop("h2o.transform is not available for this type of model.") - return(.newExpr("transform", model@model_id, h2o.getId(data))) -}) #' #' @rdname predict.H2OModel diff --git a/h2o-r/h2o-package/R/pipeline.R b/h2o-r/h2o-package/R/pipeline.R index 85961b5a05e3..626536b79209 100644 --- a/h2o-r/h2o-package/R/pipeline.R +++ b/h2o-r/h2o-package/R/pipeline.R @@ -56,13 +56,8 @@ h2o.pipeline <- function(model_id = NULL) } else { model$estimator_model <- NULL } - model$transformers <- unlist(lapply(model$transformers, function(k) .h2o.fetch_datatransformer(k$name))) + model$transformers <- unlist(lapply(model$transformers, function(dt) new("H2ODataTransformer", id=dt$id, description=dt$description))) # class(model) <- "H2OPipeline" return(model) } -.h2o.fetch_datatransformer <- function(id) { - resp <- .h2o.__remoteSend(method="GET", h2oRestApiVersion=3, page=paste0("Pipeline/DataTransformer/", id)) - tr <- new("H2ODataTransformer", id=resp$key$name, name=resp$name, description=resp$description) - return (tr) -} diff --git a/h2o-r/tests/testdir_algos/automl/runit_automl_preprocessing.R b/h2o-r/tests/testdir_algos/automl/runit_automl_preprocessing.R index 7f23d08a9655..a7a4dd200501 100644 --- a/h2o-r/tests/testdir_algos/automl/runit_automl_preprocessing.R +++ b/h2o-r/tests/testdir_algos/automl/runit_automl_preprocessing.R @@ -23,12 +23,9 @@ automl.preprocessing.suite <- function() { preprocessing = list("target_encoding"), seed = 1 ) - print(h2o.get_leaderboard(aml)) - model_ids <- as.vector(aml@leaderboard$model_id) - expect_equal(sum(grepl("Pipeline_", model_ids)) > 0, TRUE) - + print(h2o.get_leaderboard(aml)) keys <- h2o.ls()$key - expect_true(any(grepl("default_TE_1_model", keys))) + expect_true(any(grepl("TargetEncoding_AutoML", keys))) } diff --git a/h2o-r/tests/testdir_misc/runit_connect.R b/h2o-r/tests/testdir_misc/runit_connect.R index 2e76ed275a72..46495aa5b4dd 100644 --- a/h2o-r/tests/testdir_misc/runit_connect.R +++ b/h2o-r/tests/testdir_misc/runit_connect.R 
@@ -8,7 +8,7 @@ to_src <- c("aggregator.R", "classes.R", "connection.R","config.R", "constants.R
             "coxph.R", "coxphutils.R", "gbm.R", "glm.R", "gam.R", "glrm.R", "kmeans.R", "deeplearning.R", "randomforest.R",
             "generic.R", "naivebayes.R", "pca.R", "svd.R", "locate.R", "grid.R", "word2vec.R", "w2vutils.R", "stackedensemble.R",
             "rulefit.R", "predict.R", "xgboost.R", "isolationforest.R", "psvm.R", "segment.R", "tf-idf.R", "explain.R", "permutation_varimp.R",
-            "extendedisolationforest.R", "upliftrandomforest.R", "pipeline.R")
+            "extendedisolationforest.R", "upliftrandomforest.R")
 
 src_path <- paste(h2oRDir,"h2o-package","R",sep=.Platform$file.sep)
 invisible(lapply(to_src,function(x){source(paste(src_path, x, sep = .Platform$file.sep))}))
diff --git a/h2o-test-accuracy/src/test/java/water/TestCase.java b/h2o-test-accuracy/src/test/java/water/TestCase.java
index d55a698f00d5..3a904a1aa833 100644
--- a/h2o-test-accuracy/src/test/java/water/TestCase.java
+++ b/h2o-test-accuracy/src/test/java/water/TestCase.java
@@ -9,7 +9,6 @@
 import hex.grid.Grid;
 import hex.grid.GridSearch;
 import hex.grid.HyperSpaceSearchCriteria;
-import hex.grid.SimpleParametersBuilderFactory;
 import hex.tree.SharedTreeModel;
 import hex.tree.drf.DRF;
 import hex.tree.drf.DRFModel;
@@ -211,7 +210,7 @@ public TestCaseResult execute() throws Exception, AssertionError {
       // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
       Job gs = GridSearch.startGridSearch(
         null,params,hyperParms,
-        new SimpleParametersBuilderFactory<>(),
+        new GridSearch.SimpleParametersBuilderFactory<>(),
         searchCriteria, 0
       );
       grid = gs.get();
diff --git a/h2o-test-support/src/main/java/water/test/dummy/DummyModel.java b/h2o-test-support/src/main/java/water/test/dummy/DummyModel.java
index b418bab71fc9..35c8f146dee8 100644
--- a/h2o-test-support/src/main/java/water/test/dummy/DummyModel.java
+++ b/h2o-test-support/src/main/java/water/test/dummy/DummyModel.java
@@ -3,10 +3,8 @@
 import hex.Model;
 import hex.ModelMetrics;
 import hex.ModelMetricsBinomial;
-import hex.ModelMetricsRegression;
 import water.Futures;
 import water.Key;
-import water.fvec.Frame;
 
 public class DummyModel extends Model {
   public DummyModel(Key selfKey, DummyModelParameters parms, DummyModelOutput output) {
@@ -14,17 +12,10 @@ public DummyModel(Key selfKey, DummyModelParameters parms, DummyMode
   }
   @Override
   public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
-    if (domain == null) return new ModelMetricsRegression.MetricBuilderRegression();
     return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
  }
  @Override
  protected double[] score0(double[] data, double[] preds) { return preds; }
-
-  @Override
-  public Frame transform(Frame fr) {
-    return fr == null ?
-            null : fr.toTwoDimTable(0, 10, true).asFrame(Key.make(fr._key+"_stats"), true);
-  }
-
   @Override
   protected Futures remove_impl(Futures fs, boolean cascade) {
     super.remove_impl(fs, cascade);
diff --git a/h2o-test-support/src/main/java/water/test/dummy/DummyModelBuilder.java b/h2o-test-support/src/main/java/water/test/dummy/DummyModelBuilder.java
index 59b5a41ceaf5..05fd08c0df91 100644
--- a/h2o-test-support/src/main/java/water/test/dummy/DummyModelBuilder.java
+++ b/h2o-test-support/src/main/java/water/test/dummy/DummyModelBuilder.java
@@ -30,7 +30,7 @@ public void computeImpl() {
     init(true);
     Model model = null;
     try {
-      model = new DummyModel(dest(), _parms, new DummyModelOutput(DummyModelBuilder.this, msg));
+      model = new DummyModel(dest(), _parms, new DummyModelOutput(DummyModelBuilder.this, train(), msg));
       model.delete_and_lock(_job);
       model.update(_job);
     } finally {
diff --git a/h2o-test-support/src/main/java/water/test/dummy/DummyModelOutput.java b/h2o-test-support/src/main/java/water/test/dummy/DummyModelOutput.java
index e7320ab84f9a..e8ee91998ee9 100644
--- a/h2o-test-support/src/main/java/water/test/dummy/DummyModelOutput.java
+++ b/h2o-test-support/src/main/java/water/test/dummy/DummyModelOutput.java
@@ -7,8 +7,8 @@
 public class DummyModelOutput extends Model.Output {
   public final String _msg;
-  public DummyModelOutput(ModelBuilder b, String msg) {
-    super(b, b.train());
+  public DummyModelOutput(ModelBuilder b, Frame train, String msg) {
+    super(b, train);
     _msg = msg;
   }
   @Override