From eb526fa1092b1c72e280ce5bd78f05392691938c Mon Sep 17 00:00:00 2001 From: "Simon P. Couch" <simonpatrickcouch@gmail.com> Date: Thu, 4 Apr 2024 13:40:39 -0500 Subject: [PATCH] clarify case weight support in `show_model_info()` (#1102) --- R/aaa_models.R | 14 ++-- tests/testthat/_snaps/registration.md | 98 +++++++++++++++++++++++++++ tests/testthat/test_registration.R | 21 ++---- 3 files changed, 113 insertions(+), 20 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index a5bd0a553..321f7de3b 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -991,7 +991,7 @@ show_model_info <- function(model) { ) %>% dplyr::select(engine, mode, has_wts) - engines %>% + engine_weight_info <- engines %>% dplyr::left_join(weight_info, by = c("engine", "mode")) %>% dplyr::mutate( engine = paste0(engine, has_wts), @@ -1005,9 +1005,15 @@ show_model_info <- function(model) { lab = paste0(" ", mode, engine, "\n") ) %>% dplyr::ungroup() %>% - dplyr::pull(lab) %>% - cat(sep = "") - cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "") + dplyr::pull(lab) + + cat(engine_weight_info, sep = "") + + if (!all(weight_info$has_wts == "")) { + cat("\n", cli::symbol$sup_1, "The model can use case weights.", sep = "") + } + + cat("\n\n") } else { cat(" no registered engines.\n\n") } diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md index 690e25445..723bacac8 100644 --- a/tests/testthat/_snaps/registration.md +++ b/tests/testthat/_snaps/registration.md @@ -6,3 +6,101 @@ Error in `check_mode_for_new_engine()`: ! "regression" is not a known mode for model `sponge()`. +# showing model info + + Code + show_model_info("rand_forest") + Output + Information for `rand_forest` + modes: unknown, classification, regression, censored regression + + engines: + classification: randomForest, ranger1, spark + regression: randomForest, ranger1, spark + + 1The model can use case weights. + + arguments: + ranger: + mtry --> mtry + trees --> num.trees + min_n --> min.node.size + randomForest: + mtry --> mtry + trees --> ntree + min_n --> nodesize + spark: + mtry --> feature_subset_strategy + trees --> num_trees + min_n --> min_instances_per_node + + fit modules: + engine mode + ranger classification + ranger regression + randomForest classification + randomForest regression + spark classification + spark regression + + prediction modules: + mode engine methods + classification randomForest class, prob, raw + classification ranger class, conf_int, prob, raw + classification spark class, prob + regression randomForest numeric, raw + regression ranger conf_int, numeric, raw + regression spark numeric + + +--- + + Code + show_model_info("mlp") + Output + Information for `mlp` + modes: unknown, classification, regression + + engines: + classification: brulee, keras, nnet + regression: brulee, keras, nnet + + + arguments: + keras: + hidden_units --> hidden_units + penalty --> penalty + dropout --> dropout + epochs --> epochs + activation --> activation + nnet: + hidden_units --> size + penalty --> decay + epochs --> maxit + brulee: + hidden_units --> hidden_units + penalty --> penalty + epochs --> epochs + dropout --> dropout + learn_rate --> learn_rate + activation --> activation + + fit modules: + engine mode + keras regression + keras classification + nnet regression + nnet classification + brulee regression + brulee classification + + prediction modules: + mode engine methods + classification brulee class, prob + classification keras class, prob, raw + classification nnet class, prob, raw + regression brulee numeric + regression keras numeric, raw + regression nnet numeric, raw + + diff --git a/tests/testthat/test_registration.R b/tests/testthat/test_registration.R index 414e59818..57e9c054c 100644 --- a/tests/testthat/test_registration.R +++ b/tests/testthat/test_registration.R @@ -496,21 +496,10 @@ test_that('adding a new predict method', { test_that('showing model info', { - expect_output( - show_model_info("rand_forest"), - "Information for `rand_forest`" - ) - expect_output( - show_model_info("rand_forest"), - "trees --> ntree" - ) - expect_output( - show_model_info("rand_forest"), - "fit modules:" - ) - expect_output( - show_model_info("rand_forest"), - "prediction modules:" - ) + expect_snapshot(show_model_info("rand_forest")) + + # ensure that we don't mention case weight support when the + # notation would be ambiguous (#1000) + expect_snapshot(show_model_info("mlp")) })