From eb526fa1092b1c72e280ce5bd78f05392691938c Mon Sep 17 00:00:00 2001
From: "Simon P. Couch" <simonpatrickcouch@gmail.com>
Date: Thu, 4 Apr 2024 13:40:39 -0500
Subject: [PATCH] clarify case weight support in `show_model_info()` (#1102)

---
 R/aaa_models.R                        | 14 ++--
 tests/testthat/_snaps/registration.md | 98 +++++++++++++++++++++++++++
 tests/testthat/test_registration.R    | 21 ++----
 3 files changed, 113 insertions(+), 20 deletions(-)

diff --git a/R/aaa_models.R b/R/aaa_models.R
index a5bd0a553..321f7de3b 100644
--- a/R/aaa_models.R
+++ b/R/aaa_models.R
@@ -991,7 +991,7 @@ show_model_info <- function(model) {
       ) %>%
       dplyr::select(engine, mode, has_wts)
 
-      engines %>%
+      engine_weight_info <- engines %>%
         dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
       dplyr::mutate(
         engine = paste0(engine, has_wts),
@@ -1005,9 +1005,15 @@ show_model_info <- function(model) {
         lab = paste0("   ", mode, engine, "\n")
       ) %>%
       dplyr::ungroup() %>%
-      dplyr::pull(lab) %>%
-      cat(sep = "")
-    cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
+      dplyr::pull(lab)
+
+    cat(engine_weight_info, sep = "")
+
+    if (!all(weight_info$has_wts == "")) {
+      cat("\n", cli::symbol$sup_1, "The model can use case weights.", sep = "")
+    }
+
+    cat("\n\n")
   } else {
     cat(" no registered engines.\n\n")
   }
diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md
index 690e25445..723bacac8 100644
--- a/tests/testthat/_snaps/registration.md
+++ b/tests/testthat/_snaps/registration.md
@@ -6,3 +6,101 @@
       Error in `check_mode_for_new_engine()`:
       ! "regression" is not a known mode for model `sponge()`.
 
+# showing model info
+
+    Code
+      show_model_info("rand_forest")
+    Output
+      Information for `rand_forest`
+       modes: unknown, classification, regression, censored regression 
+      
+       engines: 
+         classification: randomForest, ranger1, spark
+         regression:     randomForest, ranger1, spark
+      
+      1The model can use case weights.
+      
+       arguments: 
+         ranger:       
+            mtry  --> mtry
+            trees --> num.trees
+            min_n --> min.node.size
+         randomForest: 
+            mtry  --> mtry
+            trees --> ntree
+            min_n --> nodesize
+         spark:        
+            mtry  --> feature_subset_strategy
+            trees --> num_trees
+            min_n --> min_instances_per_node
+      
+       fit modules:
+               engine           mode
+               ranger classification
+               ranger     regression
+         randomForest classification
+         randomForest     regression
+                spark classification
+                spark     regression
+      
+       prediction modules:
+                   mode       engine                    methods
+         classification randomForest           class, prob, raw
+         classification       ranger class, conf_int, prob, raw
+         classification        spark                class, prob
+             regression randomForest               numeric, raw
+             regression       ranger     conf_int, numeric, raw
+             regression        spark                    numeric
+      
+
+---
+
+    Code
+      show_model_info("mlp")
+    Output
+      Information for `mlp`
+       modes: unknown, classification, regression 
+      
+       engines: 
+         classification: brulee, keras, nnet
+         regression:     brulee, keras, nnet
+      
+      
+       arguments: 
+         keras:  
+            hidden_units --> hidden_units
+            penalty      --> penalty
+            dropout      --> dropout
+            epochs       --> epochs
+            activation   --> activation
+         nnet:   
+            hidden_units --> size
+            penalty      --> decay
+            epochs       --> maxit
+         brulee: 
+            hidden_units --> hidden_units
+            penalty      --> penalty
+            epochs       --> epochs
+            dropout      --> dropout
+            learn_rate   --> learn_rate
+            activation   --> activation
+      
+       fit modules:
+         engine           mode
+          keras     regression
+          keras classification
+           nnet     regression
+           nnet classification
+         brulee     regression
+         brulee classification
+      
+       prediction modules:
+                   mode engine          methods
+         classification brulee      class, prob
+         classification  keras class, prob, raw
+         classification   nnet class, prob, raw
+             regression brulee          numeric
+             regression  keras     numeric, raw
+             regression   nnet     numeric, raw
+      
+
diff --git a/tests/testthat/test_registration.R b/tests/testthat/test_registration.R
index 414e59818..57e9c054c 100644
--- a/tests/testthat/test_registration.R
+++ b/tests/testthat/test_registration.R
@@ -496,21 +496,10 @@ test_that('adding a new predict method', {
 
 
 test_that('showing model info', {
-  expect_output(
-    show_model_info("rand_forest"),
-    "Information for `rand_forest`"
-  )
-  expect_output(
-    show_model_info("rand_forest"),
-    "trees --> ntree"
-  )
-  expect_output(
-    show_model_info("rand_forest"),
-    "fit modules:"
-  )
-  expect_output(
-    show_model_info("rand_forest"),
-    "prediction modules:"
-  )
+  expect_snapshot(show_model_info("rand_forest"))
+
+  # ensure that we don't mention case weight support when the
+  # notation would be ambiguous (#1000)
+  expect_snapshot(show_model_info("mlp"))
 })