Skip to content

Commit

Permalink
rlang to cli for abort and warn, Fixes #1141
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesHWade committed Aug 17, 2024
1 parent ffb7570 commit 583d774
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 38 deletions.
14 changes: 8 additions & 6 deletions R/predict_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@
#' @export predict_class.model_fit
#' @export
predict_class.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "classification")
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
if (object$spec$mode != "classification") {
cli::cli_abort("{.code predict.model_fit()} is for predicting factor outcomes.")
}

check_spec_pred_type(object, "class")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$class$pre))
if (!is.null(object$spec$method$pred$class$pre)) {
new_data <- object$spec$method$pred$class$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$class)
Expand Down Expand Up @@ -56,6 +58,6 @@ predict_class.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_class <- function(object, ...)
predict_class <- function(object, ...) {
UseMethod("predict_class")

}
31 changes: 17 additions & 14 deletions R/predict_classprob.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@
#' @export predict_classprob.model_fit
#' @export
predict_classprob.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "classification")
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
if (object$spec$mode != "classification") {
cli::cli_abort("{.code predict.model_fit()} is for predicting factor outcomes.")
}

check_spec_pred_type(object, "prob")
check_spec_levels(object)

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$prob$pre))
if (!is.null(object$spec$method$pred$prob$pre)) {
new_data <- object$spec$method$pred$prob$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$prob)
Expand All @@ -33,11 +35,13 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
}

# check and sort names
if (!is.data.frame(res) & !inherits(res, "tbl_spark"))
rlang::abort("The was a problem with the probability predictions.")
if (!is.data.frame(res) & !inherits(res, "tbl_spark")) {
cli::cli_abort("The was a problem with the probability predictions.")
}

if (!is_tibble(res) & !inherits(res, "tbl_spark"))
if (!is_tibble(res) & !inherits(res, "tbl_spark")) {
res <- as_tibble(res)
}

res
}
Expand All @@ -46,17 +50,16 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_classprob <- function(object, ...)
predict_classprob <- function(object, ...) {
UseMethod("predict_classprob")
}

check_spec_levels <- function(spec) {
if ("class" %in% spec$lvl) {
rlang::abort(
glue::glue(
"The outcome variable `{spec$preproc$y_var}` has a level called 'class'. ",
"This value is reserved for parsnip's classification internals; please ",
"change the levels, perhaps with `forcats::fct_relevel()`."
),
cli::cli_abort(
"The outcome variable {.var {spec$preproc$y_var}} has a level called {.val class}.
This value is reserved for parsnip's classification internals; please
change the levels, perhaps with {.fn forcats::fct_relevel}.",
call = NULL
)
}
Expand Down
24 changes: 15 additions & 9 deletions R/predict_numeric.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,33 @@
#' @export predict_numeric.model_fit
#' @export
predict_numeric.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "regression")
rlang::abort(glue::glue("`predict_numeric()` is for predicting numeric outcomes. ",
"Use `predict_class()` or `predict_classprob()` for ",
"classification models."))
if (object$spec$mode != "regression") {
cli::cli_abort(
"{.code predict_numeric()} is for predicting numeric outcomes.
Use {.code predict_class()} or {.code predict_classprob()} for
classification models."
)
}

check_spec_pred_type(object, "numeric")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$numeric$pre))
if (!is.null(object$spec$method$pred$numeric$pre)) {
new_data <- object$spec$method$pred$numeric$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$numeric)

res <- eval_tidy(pred_call)

# post-process the predictions
if (!is.null(object$spec$method$pred$numeric$post)) {
res <- object$spec$method$pred$numeric$post(res, object)
Expand All @@ -36,8 +40,9 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
if (is.vector(res)) {
res <- unname(res)
} else {
if (!inherits(res, "tbl_spark"))
if (!inherits(res, "tbl_spark")) {
res <- as.data.frame(res)
}
}
res
}
Expand All @@ -47,5 +52,6 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_numeric.model_fit
predict_numeric <- function(object, ...)
predict_numeric <- function(object, ...) {
UseMethod("predict_numeric")
}
21 changes: 13 additions & 8 deletions R/predict_time.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,33 @@
#' @export predict_time.model_fit
#' @export
predict_time.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "censored regression")
rlang::abort(glue::glue("`predict_time()` is for predicting time outcomes. ",
"Use `predict_class()` or `predict_classprob()` for ",
"classification models."))
if (object$spec$mode != "censored regression") {
cli::cli_abort(
"{.code predict_time()} is for predicting time outcomes.
Use {.code predict_class()} or {.code predict_classprob()} for
classification models."
)
}

check_spec_pred_type(object, "time")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$time$pre))
if (!is.null(object$spec$method$pred$time$pre)) {
new_data <- object$spec$method$pred$time$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$time)

res <- eval_tidy(pred_call)

# post-process the predictions
if (!is.null(object$spec$method$pred$time$post)) {
res <- object$spec$method$pred$time$post(res, object)
Expand All @@ -45,5 +49,6 @@ predict_time.model_fit <- function(object, new_data, ...) {
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_time.model_fit
predict_time <- function(object, ...)
predict_time <- function(object, ...) {
UseMethod("predict_time")
}
2 changes: 1 addition & 1 deletion tests/testthat/test-predict_formats.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ test_that('predict(type = "prob") with level "class" (see #720)', {
)

expect_error(
regexp = "variable `boop` has a level called 'class'",
regexp = 'variable `boop` has a level called "class"',
predict(mod, type = "prob", new_data = x)
)
})
Expand Down

0 comments on commit 583d774

Please sign in to comment.