Skip to content

Commit

Permalink
transition to cli from rlang in predict() source (#1148)
Browse files Browse the repository at this point in the history
------

Co-authored-by: Simon P. Couch <[email protected]>
  • Loading branch information
shum461 and simonpcouch authored Aug 28, 2024
1 parent d68b765 commit 146bd6b
Showing 1 changed file with 73 additions and 48 deletions.
121 changes: 73 additions & 48 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
#' @export
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

Expand All @@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)

type <- check_pred_type(object, type)
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
cli::cli_warn("{.arg opts} is only used with `type = 'raw'` and was ignored.")
}
check_pred_type_dots(object, type, ...)

Expand All @@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
hazard = predict_hazard(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
cli::cli_abort("Unknown prediction {.arg type} '{type}'.")
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
Expand All @@ -191,45 +191,69 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
res
}

check_pred_type <- function(object, type, ...) {
check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
if (is.null(type)) {
type <-
switch(object$spec$mode,
regression = "numeric",
classification = "class",
"censored regression" = "time",
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
switch(
object$spec$mode,
regression = "numeric",
classification = "class",
"censored regression" = "time",
cli::cli_abort(
"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
call = call
)
)
}
if (!(type %in% pred_types))
rlang::abort(
glue::glue(
"`type` should be one of: ",
glue_collapse(pred_types, sep = ", ", last = " and ")
)
cli::cli_abort(
"{.arg type} should be one of:{.arg {pred_types}}",
call = call
)

switch(
type,
"numeric" = if (object$spec$mode != "regression") {
rlang::abort("For numeric predictions, the object should be a regression model.")
cli::cli_abort(
"For numeric predictions, the object should be a regression model.",
call = call
)
},
"class" = if (object$spec$mode != "classification") {
rlang::abort("For class predictions, the object should be a classification model.")
cli::cli_abort(
"For class predictions, the object should be a classification model.",
call = call
)
},
"prob" = if (object$spec$mode != "classification") {
rlang::abort("For probability predictions, the object should be a classification model.")
cli::cli_abort(
"For probability predictions, the object should be a classification model.",
call = call
)
},
"time" = if (object$spec$mode != "censored regression") {
rlang::abort("For event time predictions, the object should be a censored regression.")
cli::cli_abort(
"For event time predictions, the object should be a censored regression.",
call = call
)
},
"survival" = if (object$spec$mode != "censored regression") {
rlang::abort("For survival probability predictions, the object should be a censored regression.")
cli::cli_abort(
"For survival probability predictions, the object should be a censored regression.",
call = call
)
},
"hazard" = if (object$spec$mode != "censored regression") {
rlang::abort("For hazard predictions, the object should be a censored regression.")
cli::cli_abort(
"For hazard predictions, the object should be a censored regression.",
call = call
)
},
"linear_pred" = if (object$spec$mode != "censored regression") {
rlang::abort("For the linear predictor, the object should be a censored regression.")
cli::cli_abort(
"For the linear predictor, the object should be a censored regression.",
call = call
)
}
)

Expand Down Expand Up @@ -349,56 +373,57 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

other_args <- c("interval", "level", "std_error", "quantile",
"time", "eval_time", "increasing")

eval_time_types <- c("survival", "hazard")

is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
rlang::abort(
glue::glue(
"The ellipses are not used to pass args to the model function's ",
"predict function. These arguments cannot be used: {bad_args}",
)
cli::cli_abort(
"The ellipses are not used to pass args to the model function's
predict function. These arguments cannot be used: {.val bad_args}",
call = call
)
}

# ----------------------------------------------------------------------------
# places where eval_time should not be given
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"`eval_time` should only be passed to `predict()` when `type` is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
)
cli::cli_abort(
"{.arg eval_time} should only be passed to {.fn predict} when \\
{.arg type} is one of {.or {.val {eval_time_types}}}.",
call = call
)


}
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"'time' should only be passed to `predict()` when 'type' is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
cli::cli_abort(
"{.arg time} should only be passed to {.fn predict} when {.arg type} is
one of {.or {.val {eval_time_types}}}.",
call = call
)
}
# when eval_time should be passed
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"When using `type` values of 'survival' or 'hazard',",
"a numeric vector `eval_time` should also be given."
)
)
cli::cli_abort(
"When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric
vector {.arg eval_time} should also be given.",
call = call
)
}

# `increasing` only applies to linear_pred for censored regression
if (any(nms == "increasing") &
!(type == "linear_pred" &
object$spec$mode == "censored regression")) {
rlang::abort(
paste(
"The 'increasing' argument only applies to predictions of",
"type 'linear_pred' for the mode censored regression."
)
cli::cli_abort(
"{.arg increasing} only applies to predictions of
type 'linear_pred' for the mode censored regression.",
call = call
)

}

invisible(TRUE)
Expand Down

0 comments on commit 146bd6b

Please sign in to comment.