diff --git a/NEWS.md b/NEWS.md index 709a7c492..7fda15fa1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* Improved errors in cases where the outcome column is mis-specified. (#1003) + # parsnip 1.1.1 * Fixed bug where prediction on rank deficient `lm()` models produced `.pred_res` instead of `.pred`. (#985) diff --git a/R/misc.R b/R/misc.R index eb4e30a3c..225fdf915 100644 --- a/R/misc.R +++ b/R/misc.R @@ -354,6 +354,16 @@ check_outcome <- function(y, spec) { return(invisible(NULL)) } + has_no_outcome <- if (is.atomic(y)) {is.null(y)} else {length(y) == 0} + if (isTRUE(has_no_outcome)) { + cli::cli_abort( + c("!" = "{.fun {class(spec)[1]}} was unable to find an outcome.", + "i" = "Ensure that you have specified an outcome column and that it \\ + hasn't been removed in pre-processing."), + call = NULL + ) + } + if (spec$mode == "regression") { outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))} if (!outcome_is_numeric) { diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 79f6fda71..fd4ededa4 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -140,6 +140,33 @@ Error in `check_outcome()`: ! For a regression model, the outcome should be `numeric`, not a `factor`. +--- + + Code + check_outcome(NULL, reg_spec) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(tibble::new_tibble(list(), nrow = 10), reg_spec) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + fit(reg_spec, ~mpg, mtcars) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + --- Code @@ -148,6 +175,33 @@ Error in `check_outcome()`: ! For a classification model, the outcome should be a `factor`, not a `integer`. +--- + + Code + check_outcome(NULL, class_spec) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(tibble::new_tibble(list(), nrow = 10), class_spec) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + fit(class_spec, ~mpg, mtcars) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + --- Code diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index f18631472..88c4325ba 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -205,6 +205,21 @@ test_that('check_outcome works as expected', { check_outcome(factor(1:2), reg_spec) ) + expect_snapshot( + error = TRUE, + check_outcome(NULL, reg_spec) + ) + + expect_snapshot( + error = TRUE, + check_outcome(tibble::new_tibble(list(), nrow = 10), reg_spec) + ) + + expect_snapshot( + error = TRUE, + fit(reg_spec, ~ mpg, mtcars) + ) + class_spec <- logistic_reg() expect_no_error( @@ -220,6 +235,21 @@ test_that('check_outcome works as expected', { check_outcome(1:2, class_spec) ) + expect_snapshot( + error = TRUE, + check_outcome(NULL, class_spec) + ) + + expect_snapshot( + error = TRUE, + check_outcome(tibble::new_tibble(list(), nrow = 10), class_spec) + ) + + expect_snapshot( + error = TRUE, + fit(class_spec, ~ mpg, mtcars) + ) + # Fake specification to avoid having to load {censored} cens_spec <- logistic_reg() cens_spec$mode <- "censored regression"