Skip to content

Commit

Permalink
updated rlang messages in the predict functions to use cli. Fixes #1138
Browse files Browse the repository at this point in the history
  • Loading branch information
shum461 committed Aug 15, 2024
1 parent aa788b8 commit 3cfd404
Showing 1 changed file with 32 additions and 45 deletions.
77 changes: 32 additions & 45 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
#' @export
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

Expand All @@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)

type <- check_pred_type(object, type)
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
cli::cli_warn("{.arg opts} is only used with type = 'raw' and was ignored.")
}
check_pred_type_dots(object, type, ...)

Expand All @@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
hazard = predict_hazard(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
cli::cli_abort("I don't know about type = {.arg {type}}")
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
Expand All @@ -198,38 +198,34 @@ check_pred_type <- function(object, type, ...) {
regression = "numeric",
classification = "class",
"censored regression" = "time",
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
cli::cli_abort("{.arg type} should be 'regression', 'censored regression', or 'classification'."))
}
if (!(type %in% pred_types))
rlang::abort(
glue::glue(
"`type` should be one of: ",
glue_collapse(pred_types, sep = ", ", last = " and ")
)
)
cli::cli_abort(
"{.arg type} should be one of:{.arg {pred_types}}")

switch(
type,
"numeric" = if (object$spec$mode != "regression") {
rlang::abort("For numeric predictions, the object should be a regression model.")
cli::cli_abort("For numeric predictions, the object should be a regression model.")
},
"class" = if (object$spec$mode != "classification") {
rlang::abort("For class predictions, the object should be a classification model.")
cli::cli_abort("For class predictions, the object should be a classification model.")
},
"prob" = if (object$spec$mode != "classification") {
rlang::abort("For probability predictions, the object should be a classification model.")
cli::cli_abort("For probability predictions, the object should be a classification model.")
},
"time" = if (object$spec$mode != "censored regression") {
rlang::abort("For event time predictions, the object should be a censored regression.")
cli::cli_abort("For event time predictions, the object should be a censored regression.")
},
"survival" = if (object$spec$mode != "censored regression") {
rlang::abort("For survival probability predictions, the object should be a censored regression.")
cli::cli_abort("For survival probability predictions, the object should be a censored regression.")
},
"hazard" = if (object$spec$mode != "censored regression") {
rlang::abort("For hazard predictions, the object should be a censored regression.")
cli::cli_abort("For hazard predictions, the object should be a censored regression.")
},
"linear_pred" = if (object$spec$mode != "censored regression") {
rlang::abort("For the linear predictor, the object should be a censored regression.")
cli::cli_abort("For the linear predictor, the object should be a censored regression.")
}
)

Expand Down Expand Up @@ -349,56 +345,47 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

other_args <- c("interval", "level", "std_error", "quantile",
"time", "eval_time", "increasing")

eval_time_types <- c("survival", "hazard")

is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
rlang::abort(
glue::glue(
"The ellipses are not used to pass args to the model function's ",
"predict function. These arguments cannot be used: {bad_args}",
)
cli::cli_abort(
"The ellipses are not used to pass args to the model function's
predict function. These arguments cannot be used: {.val bad_args}",
)
}

# ----------------------------------------------------------------------------
# places where eval_time should not be given
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"`eval_time` should only be passed to `predict()` when `type` is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
)
cli::cli_abort("{.arg eval_time} should only be passed to {.fn predict} when {.arg type} is one of:
{.val {eval_time_types}}")


}
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"'time' should only be passed to `predict()` when 'type' is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
)
cli::cli_abort("{.arg time} should only be passed to {.fn predict} when {.arg type} is one of:
{.val {eval_time_types}}")
}
# when eval_time should be passed
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"When using `type` values of 'survival' or 'hazard',",
"a numeric vector `eval_time` should also be given."
)
)
cli::cli_abort(
"When using `type` values of 'survival' or 'hazard' a numeric vector `eval_time` should also be given.")

}

# `increasing` only applies to linear_pred for censored regression
if (any(nms == "increasing") &
!(type == "linear_pred" &
object$spec$mode == "censored regression")) {
rlang::abort(
paste(
"The 'increasing' argument only applies to predictions of",
"type 'linear_pred' for the mode censored regression."
cli::cli_abort(
"The {.arg increasing} argument only applies to predictions of
type 'linear_pred' for the mode censored regression."
)
)

}

invisible(TRUE)
Expand Down

0 comments on commit 3cfd404

Please sign in to comment.