diff --git a/h2o-py/h2o/estimators/rulefit.py b/h2o-py/h2o/estimators/rulefit.py index ba6823c558ab..639e34129c53 100644 --- a/h2o-py/h2o/estimators/rulefit.py +++ b/h2o-py/h2o/estimators/rulefit.py @@ -308,7 +308,7 @@ def max_num_rules(self): >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] >>> y = "survived" >>> rfit = H2ORuleFitEstimator(max_rule_length=10, - ... max_num_rules=-2, + ... max_num_rules=3, ... seed=1) >>> rfit.train(training_frame=df, x=x, y=y) >>> print(rfit.rule_importance()) @@ -523,9 +523,9 @@ def rule_importance(self): >>> y = "survived" >>> rfit = H2ORuleFitEstimator(max_rule_length=10, ... max_num_rules=100, - ... algorithm="gbm", ... seed=1) >>> rfit.train(training_frame=df, x=x, y=y) + >>> rule_importance = rfit.rule_importance() >>> print(rfit.rule_importance()) """ if self._model_json["algo"] != "rulefit": @@ -549,18 +549,19 @@ def predict_rules(self, frame, rule_ids): >>> import h2o >>> h2o.init() >>> from h2o.estimators import H2ORuleFitEstimator - >>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv" - >>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"}) - >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] - >>> y = "survived" - >>> rfit = H2ORuleFitEstimator(max_rule_length=10, - ... max_num_rules=100, - ... rule_generation_ntrees=60, - ... seed=1) - >>> rfit.train(training_frame=df, x=x, y=y) - >>> rules_to_predict = ['rule_1', 'rule_2'] # Replace with actual rule IDs - >>> predictions = rfit.predict_rules(frame=df, rule_ids=rules_to_predict) - >>> print(predictions) + >>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/iris/iris_train.csv" + >>> df = h2o.import_file(path=f, col_types={'species': "enum"}) + >>> x = df.columns + >>> y = "species" + >>> x.remove(y) + >>> train, test = df.split_frame(ratios=[.8], seed=1234) + >>> rfit = H2ORuleFitEstimator(min_rule_length=4, + ... max_rule_length=5, + ... max_num_rules=3, + ... seed=1234, + ... model_type="rules") + >>> rfit.train(training_frame=train, x=x, y=y, validation_frame=test) + >>> print(rfit.predict_rules(train, ['M0T38N5_Iris-virginica'])) """ from h2o.frame import H2OFrame from h2o.utils.typechecks import assert_is_type