Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-15855: core pipeline API #16039

Merged
merged 11 commits into from
Feb 12, 2024
37 changes: 25 additions & 12 deletions h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package hex.Infogram;

import hex.*;
import hex.Infogram.InfogramModel.InfogramModelOutput;
import hex.Infogram.InfogramModel.InfogramParameters;
import hex.ModelMetrics.MetricBuilder;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
Expand All @@ -18,8 +21,8 @@
import static water.util.ArrayUtils.sort;
import static water.util.ArrayUtils.sum;

public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, hex.Infogram.InfogramModel.InfogramParameters,
hex.Infogram.InfogramModel.InfogramModelOutput> {
public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, InfogramParameters,
InfogramModelOutput> {
static final double NORMALIZE_ADMISSIBLE_INDEX = 1.0/Math.sqrt(2.0);
boolean _buildCore; // true to find core predictors, false to find admissible predictors
String[] _topKPredictors; // contain the names of top predictors to consider for infogram
Expand All @@ -45,14 +48,14 @@ public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, hex.Infog
Model.Parameters.FoldAssignmentScheme _foldAssignmentOrig = null;
String _foldColumnOrig = null;

public Infogram(boolean startup_once) { super(new hex.Infogram.InfogramModel.InfogramParameters(), startup_once);}
public Infogram(boolean startup_once) { super(new InfogramParameters(), startup_once);}

public Infogram(hex.Infogram.InfogramModel.InfogramParameters parms) {
public Infogram(InfogramParameters parms) {
super(parms);
init(false);
}

public Infogram(hex.Infogram.InfogramModel.InfogramParameters parms, Key<hex.Infogram.InfogramModel> key) {
public Infogram(InfogramParameters parms, Key<hex.Infogram.InfogramModel> key) {
super(parms, key);
init(false);
}
Expand All @@ -71,18 +74,23 @@ protected int nModelsInParallel(int folds) {
* This is called before cross-validation is carried out
*/
@Override
public void computeCrossValidation() {
protected void cv_init() {
Copy link
Contributor Author

@sebhrusen sebhrusen Feb 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changes in the algos are mainly due to the fact that the CV API now only exposes protected hooks at various places in the model building cycle, otherwise it breaks the pipeline logic that needs a very strict behaviour when building CV models (esp. as it needs full control over the frames being used at that time).
Algos are therefore encouraged to override only those small hooks, and the ModelBuilder itself remains algo-agnostic.

super.cv_init();
info("cross-validation", "cross-validation infogram information is stored in frame with key" +
" labeled as admissible_score_key_cv and the admissible features in admissible_features_cv.");
if (error_count() > 0) {
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(Infogram.this);
}
super.computeCrossValidation();
}

@Override
protected MetricBuilder makeCVMetricBuilder(ModelBuilder<InfogramModel, InfogramParameters, InfogramModelOutput> cvModelBuilder, Futures fs) {
return null; //infogram does not support scoring
}

// find the best alpha/lambda values used to build the main model moving forward by looking at the devianceValid
@Override
public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
int nBuilders = cvModelBuilders.length;
double[][] cmiRaw = new double[nBuilders][];
List<List<String>> columns = new ArrayList<>();
Expand All @@ -103,7 +111,12 @@ public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
}
_cvDone = true; // cv is done and we are going to build main model next
}


@Override
protected void cv_mainModelScores(int N, MetricBuilder[] mbs, ModelBuilder<InfogramModel, InfogramParameters, InfogramModelOutput>[] cvModelBuilders) {
//infogram does not support scoring
}

public void calculateMeanInfogramInfo(double[][] cmiRaw, List<List<String>> columns,
long[] nObs) {
int nFolds = cmiRaw.length;
Expand Down Expand Up @@ -304,7 +317,7 @@ public final void buildModel() {
try {
boolean validPresent = _parms.valid() != null;
prepareModelTrainingFrame(); // generate training frame with predictors and sensitive features (if specified)
InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new hex.Infogram.InfogramModel.InfogramModelOutput(Infogram.this));
InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new InfogramModelOutput(Infogram.this));
_model = model.delete_and_lock(_job);
_model._output._start_time = System.currentTimeMillis();
_cmiRaw = new double[_numModels];
Expand Down Expand Up @@ -359,7 +372,7 @@ public final void buildModel() {
* relevance >= relevance_threshold. Derive _admissible_index as distance from point with cmi = 1 and
* relevance = 1. In addition, all arrays are sorted on _admissible_index.
*/
private void copyCMIRelevance(InfogramModel.InfogramModelOutput modelOutput) {
private void copyCMIRelevance(InfogramModelOutput modelOutput) {
modelOutput._cmi_raw = new double[_cmi.length];
System.arraycopy(_cmiRaw, 0, modelOutput._cmi_raw, 0, modelOutput._cmi_raw.length);
modelOutput._admissible_index = new double[_cmi.length];
Expand All @@ -375,7 +388,7 @@ private void copyCMIRelevance(InfogramModel.InfogramModelOutput modelOutput) {
modelOutput._admissible_index, modelOutput._admissible, modelOutput._all_predictor_names);
}

public void copyCMIRelevanceValid(InfogramModel.InfogramModelOutput modelOutput) {
public void copyCMIRelevanceValid(InfogramModelOutput modelOutput) {
modelOutput._cmi_raw_valid = new double[_cmiValid.length];
System.arraycopy(_cmiRawValid, 0, modelOutput._cmi_raw_valid, 0, modelOutput._cmi_raw_valid.length);
modelOutput._admissible_index_valid = new double[_cmiValid.length];
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static DataInfo makeDataInfo(Frame train, Frame valid, DeepLearningParameters pa
}
}

@Override public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
@Override protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
_parms._overwrite_with_best_model = false;

if( _parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0) return; // No exciting changes to stopping conditions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
@Override
public ModelMetrics.MetricBuilder<?> getMetricBuilder() {
throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, " +
"retrieve your metrics by calling getOrMakeMetrics method.");
"retrieve your metrics by calling makeModelMetrics method.");
}
}

Expand Down
69 changes: 15 additions & 54 deletions h2o-algos/src/main/java/hex/glm/GLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import hex.util.LinearAlgebraUtils;
import hex.util.LinearAlgebraUtils.BMulTask;
import hex.util.LinearAlgebraUtils.FindMaxIndex;
import jsr166y.CountedCompleter;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.*;
Expand Down Expand Up @@ -119,7 +118,8 @@ public boolean isSupervised() {
public ModelCategory[] can_build() {
return new ModelCategory[]{
ModelCategory.Regression,
ModelCategory.Binomial,
ModelCategory.Binomial,
ModelCategory.Multinomial
};
}

Expand Down Expand Up @@ -148,13 +148,12 @@ public ModelCategory[] can_build() {
* (builds N+1 models, all have train+validation metrics, the main model has N-fold cross-validated validation metrics)
*/
@Override
public void computeCrossValidation() {
protected void cv_init() {
// init computes global list of lambdas
init(true);
_cvRuns = true;
if (error_count() > 0)
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GLM.this);
super.computeCrossValidation();
}


Expand Down Expand Up @@ -293,7 +292,7 @@ private double[] alignSubModelsAcrossCVModels(ModelBuilder[] cvModelBuilders) {
* 4. unlock the n-folds models (they are changed here, so the unlocking happens here)
*/
@Override
public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
setMaxRuntimeSecsForMainModel();
double bestTestDev = Double.POSITIVE_INFINITY;
double[] alphasAndLambdas = alignSubModelsAcrossCVModels(cvModelBuilders);
Expand Down Expand Up @@ -372,12 +371,6 @@ public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
break;
}
}
for (int i = 0; i < cvModelBuilders.length; ++i) {
GLM g = (GLM) cvModelBuilders[i];
if (g._toRemove != null)
for (Key k : g._toRemove)
Keyed.remove(k);
}

for (int i = 0; i < cvModelBuilders.length; ++i) {
GLM g = (GLM) cvModelBuilders[i];
Expand Down Expand Up @@ -1543,11 +1536,11 @@ private void buildModel() {

protected static final long WORK_TOTAL = 1000000;

transient Key [] _toRemove;

private Key[] removeLater(Key ...k){
_toRemove = _toRemove == null?k:ArrayUtils.append(_toRemove,k);
return k;
@Override
protected void cleanUp() {
if (_parms._lambda_search && _parms._is_cv_model)
keepUntilCompletion(_dinfo.getWeightsVec()._key);
super.cleanUp();
}

@Override protected GLMDriver trainModelImpl() { return _driver = new GLMDriver(); }
Expand Down Expand Up @@ -1576,23 +1569,6 @@ public final class GLMDriver extends Driver implements ProgressMonitor {
private transient GLMTask.GLMIterationTask _gramInfluence;
private transient double[][] _cholInvInfluence;

private void doCleanup() {
try {
if (_parms._lambda_search && _parms._is_cv_model)
Scope.untrack(removeLater(_dinfo.getWeightsVec()._key));
if (_parms._HGLM) {
Key[] vecKeys = _toRemove;
for (int index = 0; index < vecKeys.length; index++) {
Vec tempVec = DKV.getGet(vecKeys[index]);
tempVec.remove();
}
}
} catch (Exception e) {
Log.err("Error while cleaning up GLM " + _result);
Log.err(e);
}
}

private transient Cholesky _chol;
private transient L1Solver _lslvr;

Expand Down Expand Up @@ -3564,9 +3540,8 @@ private Vec[] genGLMVectors(DataInfo dinfo, double[] nb) {
sumExp += Math.exp(nb[i * N + P] - maxRow);
}
Vec[] vecs = dinfo._adaptedFrame.anyVec().makeDoubles(2, new double[]{sumExp, maxRow});
if (_parms._lambda_search && _parms._is_cv_model) {
Scope.untrack(vecs[0]._key, vecs[1]._key);
removeLater(vecs[0]._key, vecs[1]._key);
if (_parms._lambda_search) {
track(vecs[0]); track(vecs[1]);
}
return vecs;
}
Expand Down Expand Up @@ -3848,7 +3823,7 @@ private void checkCoeffsBounds() {
* - column 2: zi, intermediate values
* - column 3: eta = X*beta, intermediate values
*/
public void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM
private void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM
int moreColnum = 3 + _parms._random_columns.length;
Vec[] vecs = _dinfo._adaptedFrame.anyVec().makeZeros(moreColnum);
String[] colNames = new String[moreColnum];
Expand All @@ -3861,25 +3836,11 @@ public void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response f
vecs[index] = _parms.train().vec(randColIndices[index - 3]).makeCopy();
}
_dinfo.addResponse(colNames, vecs);
for (int index = 0; index < moreColnum; index++) {
Scope.untrack(vecs[index]._key);
removeLater(vecs[index]._key);
}
}

@Override
public void onCompletion(CountedCompleter caller) {
doCleanup();
super.onCompletion(caller);
Frame wdataZiEta = new Frame(Key.make("wdataZiEta"+Key.rand()), colNames, vecs);
DKV.put(wdataZiEta);
track(wdataZiEta);
}

@Override
public boolean onExceptionalCompletion(Throwable t, CountedCompleter caller) {
doCleanup();
return super.onExceptionalCompletion(t, caller);
}


@Override
public boolean progress(double[] beta, GradientInfo ginfo) {
_state._iter++;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/kmeans/KMeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ else if( user_points.numRows() != _parms._k)
if (expensive && error_count() == 0) checkMemoryFootPrint();
}

public void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){
protected void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){
super.cv_makeAggregateModelMetrics(mbs);
((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._within_sumsqe = null;
((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._size = null;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ public double initialValue() {
return _parms._parallel_main_model_building;
}

@Override public void cv_computeAndSetOptimalParameters(ModelBuilder<M, P, O>[] cvModelBuilders) {
@Override protected void cv_computeAndSetOptimalParameters(ModelBuilder<M, P, O>[] cvModelBuilders) {
// Extract stopping conditions from each CV model, and compute the best stopping answer
if (!cv_initStoppingParameters())
return; // No exciting changes to stopping conditions
Expand Down
5 changes: 3 additions & 2 deletions h2o-bindings/bin/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def get_customizations_for(language, algo, property=None, default=None):
tokens = property.split('.')
value = customizations
for token in tokens:
value = value.get(token)
if value is None:
if token in value:
value = value.get(token)
else:
return default
return value
else:
Expand Down
28 changes: 28 additions & 0 deletions h2o-bindings/bin/custom/R/gen_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

extensions = dict(
required_params=[],
frame_params=[],
validate_required_params="",
set_required_params="",
module="""
.h2o.fill_pipeline <- function(model, parameters, allparams) {
if (!is.null(model$estimator)) {
model$estimator_model <- h2o.getModel(model$estimator$name)
} else {
model$estimator_model <- NULL
}
model$transformers <- unlist(lapply(model$transformers, function(dt) new("H2ODataTransformer", id=dt$id, description=dt$description)))
# class(model) <- "H2OPipeline"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this commented? If needed, you can assign multiple classes, e.g., class(model) <- c("H2OPipeline", "H2OModel").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! right I forgot about the multiple inheritance. I think I commented it out because it didn't seem necessary (single algo don't have dedicated class) and it broke some behaviour somewhere (can't remember what exactly).
The funny part is that the class is defined as follow:

setClass("H2OPipeline", contains="H2OModel",

but afair, it was still not recognized as a model somewhere…
I will give a try to your suggestion, and see if it breaks some R tests.

return(model)
}
"""
)

doc = dict(
preamble="""
Build a pipeline model given a list of transformers and a final model.

Currently R model pipelines, as produced by AutoML for example,
are only available as read-only models that can not be constructed and trained directly by the end-user.
""",
)
56 changes: 56 additions & 0 deletions h2o-bindings/bin/custom/python/gen_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
supervised_learning = None # actually depends on the estimator model in the pipeline, leave it to None for now as it is needed only for training and we don't support pipeline as input yet


# in future update, we'll want to expose parameters applied to each transformer
def module_extensions():
class H2ODataTransformer(H2ODisplay):
@classmethod
def make(cls, kvs):
dt = H2ODataTransformer(**{k: v for k, v in kvs if k not in H2OSchema._ignored_schema_keys_})
dt._json = kvs
return dt

def __init__(self, id=None, description=None):
self._json = None
self.id = id
self.description = description

def _repr_(self):
return repr(self._json)

def _str_(self, verbosity=None):
return repr_def(self)


# self-register transformer class: done as soon as `h2o.estimators` is loaded, which means as soon as h2o.h2o is...
register_schema_handler("DataTransformerV3", H2ODataTransformer)


def class_extensions():
@property
def transformers(self):
return self._model_json['output']['transformers']

@property
def estimator_model(self):
m_json = self._model_json['output']['estimator']
return None if (m_json is None or m_json['name'] is None) else h2o.get_model(m_json['name'])

def transform(self, fr):
"""
Applies all the pipeline transformers to the given input frame.
:return: the transformed frame, as it would be passed to `estimator_model`, if calling `predict` instead.
"""
return H2OFrame._expr(expr=ExprNode("transform", ASTId(self.key), ASTId(fr.key)))._frame(fill_cache=True)


extensions = dict(
__imports__="""
import h2o
from h2o.display import H2ODisplay, repr_def
from h2o.expr import ASTId, ExprNode
from h2o.schemas import H2OSchema, register_schema_handler
""",
__class__=class_extensions,
__module__=module_extensions,
)
4 changes: 2 additions & 2 deletions h2o-bindings/bin/gen_R.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def get_schema_params(pname):
"verbose",
"destination_key"] # destination_key is only for SVD
bulk_params = list(zip(*filter(lambda t: not t[0] in bulk_pnames_skip, zip(sig_pnames, sig_params))))
bulk_pnames = list(bulk_params[0])
sig_bulk_params = list(bulk_params[1])
bulk_pnames = list(bulk_params[0]) if bulk_params else []
sig_bulk_params = list(bulk_params[1]) if bulk_params else []
sig_bulk_params.append("segment_columns = NULL")
sig_bulk_params.append("segment_models_id = NULL")
sig_bulk_params.append("parallelism = 1")
Expand Down
Loading
Loading