Skip to content

Commit

Permalink
feat: add select_features to ranger (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
bommert authored Dec 18, 2024
1 parent 597f5e1 commit 09bbb2f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 2 deletions.
10 changes: 9 additions & 1 deletion R/LearnerClassifRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
properties = c("weights", "twoclass", "multiclass", "importance", "oob_error", "hotstart_backward", "missings"),
properties = c("weights", "twoclass", "multiclass", "importance", "oob_error", "hotstart_backward", "missings", "selected_features"),
packages = c("mlr3learners", "ranger"),
label = "Random Forest",
man = "mlr3learners::mlr_learners_classif.ranger"
Expand Down Expand Up @@ -117,6 +117,14 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
}
self$model$prediction.error
}

#' @description
#' The set of features used for node splitting in the forest.
#'
#' @return `character()`.
selected_features = function() {
ranger_selected_features(self)
}
),

private = list(
Expand Down
10 changes: 9 additions & 1 deletion R/LearnerRegrRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
param_set = ps,
predict_types = c("response", "se", "quantiles"),
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
properties = c("weights", "importance", "oob_error", "hotstart_backward", "missings"),
properties = c("weights", "importance", "oob_error", "hotstart_backward", "missings", "selected_features"),
packages = c("mlr3learners", "ranger"),
label = "Random Forest",
man = "mlr3learners::mlr_learners_regr.ranger"
Expand Down Expand Up @@ -99,6 +99,14 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
}
self$model$prediction.error
}

#' @description
#' The set of features used for node splitting in the forest.
#'
#' @return `character()`.
selected_features = function() {
ranger_selected_features(self)
}
),

private = list(
Expand Down
23 changes: 23 additions & 0 deletions R/helpers_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,26 @@ convert_ratio = function(pv, target, ratio, n) {
stopf("Hyperparameters '%s' and '%s' are mutually exclusive", target, ratio)
)
}




ranger_selected_features = function(self) {
if (is.null(self$model)) {
stopf("No model stored")
}

splitvars = ranger::treeInfo(object = self$model, tree = 1)$splitvarName
i = 2
while (i <= self$model$num.trees &&
!all(self$state$feature_names %in% splitvars)) {
sv = ranger::treeInfo(object = self$model, tree = i)$splitvarName
splitvars = union(splitvars, sv)
i = i + 1
}

# order the names of the selected features in the same order as in the task
self$state$feature_names[self$state$feature_names %in% splitvars]
}


9 changes: 9 additions & 0 deletions tests/testthat/test_classif_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,12 @@ test_that("default_values", {
values = default_values(learner, search_space, task)
expect_names(names(values), permutation.of = c("replace", "sample.fraction", "num.trees", "mtry.ratio"))
})

test_that("selected_features", {
learner = lrn("classif.ranger")
expect_error(learner$selected_features())

task = tsk("iris")
learner$train(task)
expect_set_equal(learner$selected_features(), c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"))
})
9 changes: 9 additions & 0 deletions tests/testthat/test_regr_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,12 @@ test_that("quantile prediction", {
expect_names(names(tab), identical.to = c("row_ids", "truth", "q0.1", "q0.5", "q0.9", "response"))
expect_equal(tab$response, tab$q0.5)
})

test_that("selected_features", {
learner = lrn("regr.ranger")
expect_error(learner$selected_features())

task$select(c("am", "cyl", "wt"))
learner$train(task)
expect_set_equal(learner$selected_features(), c("am", "cyl", "wt"))
})

0 comments on commit 09bbb2f

Please sign in to comment.