diff --git a/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java b/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java index b606027c6bfa..b04e9b9dc16e 100644 --- a/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java +++ b/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java @@ -3,6 +3,8 @@ import hex.Model; import hex.ModelBuilder; import hex.ModelCategory; +import hex.deeplearning.DeepLearning; +import hex.deeplearning.DeepLearningModel; import hex.glm.GLM; import hex.glm.GLMModel; import hex.tree.drf.DRF; @@ -169,6 +171,8 @@ private ModelBuilder chooseWeakLearner(Frame frame) { return getGLMWeakLearner(frame); case GBM: return getGBMWeakLearner(frame); + case DEEP_LEARNING: + return getDeepLearningWeakLearner(frame); default: case DRF: return getDRFWeakLearner(frame); @@ -212,6 +216,17 @@ private GBM getGBMWeakLearner(Frame frame) { return new GBM(parms); } + private DeepLearning getDeepLearningWeakLearner(Frame frame) { + DeepLearningModel.DeepLearningParameters parms = new DeepLearningModel.DeepLearningParameters(); + parms._train = frame._key; + parms._response_column = _parms._response_column; + parms._weights_column = _weightsName; + parms._seed = _parms._seed; + parms._epochs = 10; + parms._hidden = new int[]{2}; + return new DeepLearning(parms); + } + public TwoDimTable createModelSummaryTable() { List colHeaders = new ArrayList<>(); List colTypes = new ArrayList<>(); diff --git a/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java b/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java index 24d848b27b16..3727ad2c0547 100644 --- a/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java +++ b/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java @@ -10,7 +10,7 @@ public class AdaBoostModel extends Model { private static final Logger LOG = Logger.getLogger(AdaBoostModel.class); - public enum Algorithm {DRF, GLM, GBM, AUTO} + public enum Algorithm {DRF, GLM, GBM, DEEP_LEARNING,AUTO} public AdaBoostModel(Key selfKey, AdaBoostParameters parms, AdaBoostOutput output) { diff --git a/h2o-algos/src/main/java/hex/schemas/AdaBoostV3.java b/h2o-algos/src/main/java/hex/schemas/AdaBoostV3.java index 1a1edb52189c..f2b9f8ea3219 100644 --- a/h2o-algos/src/main/java/hex/schemas/AdaBoostV3.java +++ b/h2o-algos/src/main/java/hex/schemas/AdaBoostV3.java @@ -29,7 +29,7 @@ public static final class AdaBoostParametersV3 extends ModelParametersSchemaV3