Skip to content

Commit

Permalink
ht/gradle build
Browse files Browse the repository at this point in the history
  • Loading branch information
hannah-tillman committed Oct 1, 2024
1 parent 72904b5 commit 4e216a8
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions h2o-py/h2o/estimators/rulefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def max_num_rules(self):
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=-2,
... max_num_rules=3,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
Expand Down Expand Up @@ -523,9 +523,9 @@ def rule_importance(self):
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... algorithm="gbm",
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> rule_importance = rfit.rule_importance()
>>> print(rfit.rule_importance())
"""
if self._model_json["algo"] != "rulefit":
Expand All @@ -549,18 +549,19 @@ def predict_rules(self, frame, rule_ids):
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... rule_generation_ntrees=60,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> rules_to_predict = ['rule_1', 'rule_2'] # Replace with actual rule IDs
>>> predictions = rfit.predict_rules(frame=df, rule_ids=rules_to_predict)
>>> print(predictions)
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/iris/iris_train.csv"
>>> df = h2o.import_file(path=f, col_types={'species': "enum"})
>>> x = df.columns
>>> y = "species"
>>> x.remove(y)
>>> train, test = df.split_frame(ratios=[.8], seed=1234)
>>> rfit = H2ORuleFitEstimator(min_rule_length=4,
... max_rule_length=5,
... max_num_rules=3,
... seed=1234,
... model_type="rules")
>>> rfit.train(training_frame=train, x=x, y=y, validation_frame=test)
>>> print(rfit.predict_rules(train, ['M0T38N5_Iris-virginica']))
"""
from h2o.frame import H2OFrame
from h2o.utils.typechecks import assert_is_type
Expand Down

0 comments on commit 4e216a8

Please sign in to comment.