diff --git a/R/tunable.R b/R/tunable.R index 85c8bff29..271c470e2 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -248,7 +248,7 @@ tunable.linear_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_linear_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -260,7 +260,7 @@ tunable.logistic_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_logistic_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -272,7 +272,7 @@ tunable.multinomial_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_multinomial_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -295,7 +295,7 @@ tunable.boost_tree <- function(x, ...) { res$call_info[res$name == "sample_size"] <- list(list(pkg = "dials", fun = "sample_prop")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -310,7 +310,7 @@ tunable.rand_forest <- function(x, ...) { } else if (x$engine == "aorsf") { res <- add_engine_parameters(res, aorsf_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -319,7 +319,7 @@ tunable.mars <- function(x, ...) { if (x$engine == "earth") { res <- add_engine_parameters(res, earth_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -333,7 +333,7 @@ tunable.decision_tree <- function(x, ...) { partykit_engine_args %>% dplyr::mutate(component = "decision_tree")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -343,7 +343,7 @@ tunable.svm_poly <- function(x, ...) { res$call_info[res$name == "degree"] <- list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } @@ -357,7 +357,7 @@ tunable.mlp <- function(x, ...) { res$call_info[res$name == "epochs"] <- list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -366,7 +366,7 @@ tunable.survival_reg <- function(x, ...) { if (x$engine == "flexsurvspline") { res <- add_engine_parameters(res, flexsurvspline_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } # nocov end diff --git a/tests/testthat/_snaps/tunable.md b/tests/testthat/_snaps/tunable.md new file mode 100644 index 000000000..9f8b6ba3b --- /dev/null +++ b/tests/testthat/_snaps/tunable.md @@ -0,0 +1,542 @@ +# tunable.linear_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("lm")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +# tunable.logistic_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glm")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +# tunable.multinom_reg() + + Code + tunable(spec) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("keras")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("nnet")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + 3 dfmax model_spec multinom_reg engine + +# tunable.boost_tree() + + Code + tunable(spec) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("xgboost")) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("C5.0")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 trees model_spec boost_tree main + 2 min_n model_spec boost_tree main + 3 sample_size model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 7 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("xgboost", feval = tune())) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + +# tunable.rand_forest() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("ranger")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("randomForest")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("ranger", min.bucket = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +# tunable.mars() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + +--- + + Code + tunable(spec %>% set_engine("earth")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + +--- + + Code + tunable(spec %>% set_engine("earth", minspan = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + +# tunable.decision_tree() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("rpart")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("C5.0")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 min_n model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("rpart", parms = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + +# tunable.svm_poly() + + Code + tunable(spec) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + +--- + + Code + tunable(spec %>% set_engine("kernlab")) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + +--- + + Code + tunable(spec %>% set_engine("kernlab", tol = tune())) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + +# tunable.mlp() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("keras")) + Output + # A tibble: 5 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 dropout model_spec mlp main + 4 epochs model_spec mlp main + 5 activation model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("nnet")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 6 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + 4 dropout model_spec mlp main + 5 learn_rate model_spec mlp main + 6 activation model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("keras", ragged = tune())) + Output + # A tibble: 5 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 dropout model_spec mlp main + 4 epochs model_spec mlp main + 5 activation model_spec mlp main + +# tunable.survival_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("survival")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("survival", parms = tune())) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + diff --git a/tests/testthat/test-tunable.R b/tests/testthat/test-tunable.R new file mode 100644 index 000000000..8e8dd6ab7 --- /dev/null +++ b/tests/testthat/test-tunable.R @@ -0,0 +1,116 @@ +# general pattern, for each tunable method: +# define `spec`, run `show_engines()` with only parsnip loaded, +# snapshot test `tunable()` output for each unique engine. +# +# note that, as implemented, parsnip can return `tunable()` information +# for engines that it cannot fit without first loading an extension package. +# +# the specific contents of call_info are just hard-coded tibbles in the +# source, so snapshot testing only for their presence rather than contents. + +test_that("tunable.linear_reg()", { + spec <- linear_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("lm"))) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) +}) + +test_that("tunable.logistic_reg()", { + spec <- logistic_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("glm"))) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) +}) + +test_that("tunable.multinom_reg()", { + spec <- multinom_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + expect_snapshot(tunable(spec %>% set_engine("keras"))) + expect_snapshot(tunable(spec %>% set_engine("nnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) +}) + +test_that("tunable.boost_tree()", { + spec <- boost_tree() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("xgboost"))) + expect_snapshot(tunable(spec %>% set_engine("C5.0"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("xgboost", feval = tune()))) +}) + +test_that("tunable.rand_forest()", { + spec <- rand_forest() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("ranger"))) + expect_snapshot(tunable(spec %>% set_engine("randomForest"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("ranger", min.bucket = tune()))) +}) + +test_that("tunable.mars()", { + spec <- mars() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("earth"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("earth", minspan = tune()))) +}) + +test_that("tunable.decision_tree()", { + spec <- decision_tree() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("rpart"))) + expect_snapshot(tunable(spec %>% set_engine("C5.0"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("rpart", parms = tune()))) +}) + +test_that("tunable.svm_poly()", { + spec <- svm_poly() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("kernlab"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("kernlab", tol = tune()))) +}) + +test_that("tunable.mlp()", { + spec <- mlp() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("keras"))) + expect_snapshot(tunable(spec %>% set_engine("nnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("keras", ragged = tune()))) +}) + + +test_that("tunable.survival_reg()", { + spec <- survival_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("survival"))) + + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("survival", parms = tune()))) +})