Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better call routing for errors #1214

Merged
merged 19 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ export(bag_mars)
export(bag_mlp)
export(bag_tree)
export(bart)
export(bartMachine_interval_calc)
export(boost_tree)
export(case_weights_allowed)
export(cforest_train)
Expand Down
4 changes: 2 additions & 2 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ make_form_call <- function(object, env = NULL) {
}

# TODO we need something to indicate that case weights are being used.
make_xy_call <- function(object, target, env) {
make_xy_call <- function(object, target, env, call = rlang::caller_env()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make this comment once, but it applies throughout the PR: rlang is imported with `@import rlang`, so we don't need to namespace rlang calls like these!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It all comes back to PSOCK clusters: we started namespacing everything in these packages because importing alone was not enough.

fit_args <- object$method$fit$args
uses_weights <- has_weights(env)

Expand All @@ -283,7 +283,7 @@ make_xy_call <- function(object, target, env) {
data.frame = rlang::expr(maybe_data_frame(x)),
matrix = rlang::expr(maybe_matrix(x)),
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
cli::cli_abort("Invalid data type target: {target}.")
cli::cli_abort("Invalid data type target: {target}.", call = call)
)
if (uses_weights) {
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
Expand Down
10 changes: 6 additions & 4 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
}


map_glmnet_coefs <- function(x) {
map_glmnet_coefs <- function(x, call = rlang::caller_env()) {
coefs <- coef(x)
# If parsnip is used to fit the model, glmnet should be attached and this will
# work. If an object is loaded from a new session, they will need to load the
# package.
if (is.null(coefs)) {
cli::cli_abort(
"Please load the {.pkg glmnet} package before running {.fun autoplot}."
"Please load the {.pkg glmnet} package before running {.fun autoplot}.",
call = call
)
}
p <- x$dim[1]
Expand Down Expand Up @@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) {
dplyr::slice(seq_len(top_n))
}

autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
call = rlang::caller_env(), ...) {
tidy_coefs <-
map_glmnet_coefs(x) %>%
map_glmnet_coefs(x, call = call) %>%
dplyr::filter(penalty >= min_penalty)

actual_min_penalty <- min(tidy_coefs$penalty)
Expand Down
48 changes: 0 additions & 48 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,61 +130,13 @@ update.bart <-
)
}


#' Developer functions for predictions via BART models
#' @export
#' @keywords internal
#' @name bart-internal
#' @inheritParams predict.model_fit
#' @param obj A parsnip object.
#' @param ci Confidence (TRUE) or prediction interval (FALSE)
#' @param level Confidence level.
#' @param std_err Attach column for standard error of prediction or not.
#' @return A tibble with columns `.pred_lower` and `.pred_upper`. When a
#'   prediction interval is requested (`ci = FALSE`) and the model
#'   specification asks for standard errors, a `.std_err` column is added.
bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) {
  # `identical()` keeps the condition a length-one logical even if `mode` is
  # missing from the spec.
  if (identical(obj$spec$mode, "classification")) {
    cli::cli_abort(
      "Prediction intervals are not possible for classification"
    )
  }

  # The flag is an optional entry of the spec. Treat a missing (NULL) or NA
  # value as "no standard errors" rather than letting a zero-length condition
  # fail inside `if ()`.
  get_std_err <-
    isTRUE(obj$spec$method$pred$pred_int$extras$std_error)

  if (ci) {
    cl <-
      rlang::call2(
        "calc_credible_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        ci_conf = level
      )
  } else {
    cl <-
      rlang::call2(
        "calc_prediction_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        pi_conf = level
      )
  }
  res <- rlang::eval_tidy(cl)

  # Only the prediction-interval branch reads `all_prediction_samples` and
  # `interval` from the result; the credible-interval result is used as is.
  .std_error <- NULL
  if (!ci) {
    if (get_std_err) {
      # Row-wise standard deviation of the prediction samples.
      .std_error <-
        apply(res$all_prediction_samples, 1, stats::sd, na.rm = TRUE)
    }
    res <- res$interval
  }

  res <- tibble::as_tibble(res)
  names(res) <- c(".pred_lower", ".pred_upper")
  if (!is.null(.std_error)) {
    res$.std_err <- .std_error
  }
  res
}

#' @export
#' @rdname bart-internal
#' @keywords internal
dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) {
types <- c("numeric", "class", "prob", "conf_int", "pred_int")
Expand Down
13 changes: 9 additions & 4 deletions R/condense_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#'
#' @return A control object with the same elements and classes of `ref`, with
#' values of `x`.
#' @param call The execution environment of a currently running function, e.g.
#' `caller_env()`. The function will be mentioned in error messages as the
#' source of the error. See the call argument of [rlang::abort()] for more
#' information.
#' @keywords internal
#' @export
#'
Expand All @@ -20,16 +24,17 @@
#'
#' ctrl <- condense_control(ctrl, control_parsnip())
#' str(ctrl)
condense_control <- function(x, ref) {
condense_control <- function(x, ref, call = rlang::caller_env()) {
topepo marked this conversation as resolved.
Show resolved Hide resolved
mismatch <- setdiff(names(ref), names(x))
if (length(mismatch)) {
cli::cli_abort(
c(
"Object of class {.cls class(x)[1]} cannot be coerced to
object of class {.cls class(ref)[1]}.",
"Object of class {.cls {class(x)[1]}} cannot be coerced to
object of class {.cls {class(ref)[1]}}.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even better, I changed it to {.obj_type_friendly {x}}

"i" = "{cli::qty(mismatch)} The argument{?s} {.arg {mismatch}}
{?is/are} missing."
)
),
call = call
)
}
res <- x[names(ref)]
Expand Down
10 changes: 8 additions & 2 deletions R/contr_one_hot.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#' This contrast function produces a model matrix with indicator columns for
#' each level of each factor.
#'
#' @param n A vector of character factor levels or the number of unique levels.
#' @param n A vector of character factor levels (of length >=1) or the number
#' of unique levels (>= 1).
#' @param contrasts This argument is for backwards compatibility and only the
#' default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
Expand All @@ -24,9 +25,14 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
}

if (is.character(n)) {
check_character(n, empty = FALSE)
topepo marked this conversation as resolved.
Show resolved Hide resolved
if (length(n) < 1) {
cli::cli_abort("A character vector for {.arg n} cannot be empty.")
topepo marked this conversation as resolved.
Show resolved Hide resolved
}
names <- n
n <- length(names)
} else if (is.numeric(n)) {
check_number_whole(n, min = 1)
n <- as.integer(n)

if (length(n) != 1L) {
Expand All @@ -35,7 +41,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {

names <- as.character(seq_len(n))
} else {
cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.")
check_number_whole(n, min = 1)
}

out <- diag(n)
Expand Down
28 changes: 17 additions & 11 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@
na.action = na.omit,
indicators = "traditional",
composition = "data.frame",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame}, {.val matrix}, or
{.val dgCMatrix}."
{.val dgCMatrix}.",
call = call
)
}

if (sparsevctrs::has_sparse_elements(data)) {
cli::cli_abort(
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead."
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead.",
call = call
)
}

Expand Down Expand Up @@ -84,7 +87,7 @@

w <- as.vector(model.weights(mod_frame))
if (!is.null(w) && !is.numeric(w)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}

# TODO: Do we actually use the offset when fitting?
Expand Down Expand Up @@ -175,10 +178,12 @@
.convert_form_to_xy_new <- function(object,
new_data,
na.action = na.pass,
composition = "data.frame") {
composition = "data.frame",
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame} or {.val matrix}."
"{.arg composition} should be either {.val data.frame} or {.val matrix}.",
call = call
)
}

Expand Down Expand Up @@ -244,9 +249,10 @@
y,
weights = NULL,
y_name = "..y",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (is.vector(x)) {
cli::cli_abort("{.arg x} cannot be a vector.")
cli::cli_abort("{.arg x} cannot be a vector.", call = call)
}

if (remove_intercept) {
Expand Down Expand Up @@ -279,10 +285,10 @@

if (!is.null(weights)) {
if (!is.numeric(weights)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}
if (length(weights) != nrow(x)) {
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.")
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call)
}

form <- patch_formula_environment_with_case_weights(
Expand Down
16 changes: 9 additions & 7 deletions R/descriptors.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,23 @@ NULL

# Descriptor retrievers --------------------------------------------------------

get_descr_form <- function(formula, data) {
get_descr_form <- function(formula, data, call = rlang::caller_env()) {
if (inherits(data, "tbl_spark")) {
res <- get_descr_spark(formula, data)
} else {
res <- get_descr_df(formula, data)
res <- get_descr_df(formula, data, call = call)
}
res
}

get_descr_df <- function(formula, data) {
get_descr_df <- function(formula, data, call = rlang::caller_env()) {

tmp_dat <-
.convert_form_to_xy_fit(formula,
data,
indicators = "none",
remove_intercept = TRUE)
remove_intercept = TRUE,
call = call)

if(is.factor(tmp_dat$y)) {
.lvls <- function() {
Expand All @@ -136,7 +137,8 @@ get_descr_df <- function(formula, data) {
formula,
data,
indicators = "traditional",
remove_intercept = TRUE
remove_intercept = TRUE,
call = call
)$x
)
}
Expand Down Expand Up @@ -263,7 +265,7 @@ get_descr_spark <- function(formula, data) {
)
}

get_descr_xy <- function(x, y) {
get_descr_xy <- function(x, y, call = rlang::caller_env()) {

.lvls <- if (is.factor(y)) {
function() table(y, dnn = NULL)
Expand Down Expand Up @@ -291,7 +293,7 @@ get_descr_xy <- function(x, y) {
}

.dat <- function() {
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE)$data
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE, call = call)$data
}

.x <- function() {
Expand Down
5 changes: 3 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ fit.model_spec <-
}

if (all(c("x", "y") %in% names(dots))) {
cli::cli_abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
}
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
Expand Down Expand Up @@ -307,7 +307,8 @@ fit_xy.model_spec <-

if (object$engine == "spark") {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
"spark objects can only be used with the formula interface to {.fn fit}
with a spark data object."
)
}

Expand Down
9 changes: 5 additions & 4 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ form_form <-

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_form(env$formula, env$data)
data_stats <- get_descr_form(env$formula, env$data, call = call)
scoped_descrs(data_stats)
}

Expand Down Expand Up @@ -86,7 +86,7 @@ xy_xy <- function(object,

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_xy(env$x, env$y)
data_stats <- get_descr_xy(env$x, env$y, call = call)
scoped_descrs(data_stats)
}

Expand All @@ -96,7 +96,7 @@ xy_xy <- function(object,
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

fit_call <- make_xy_call(object, target, env)
fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)

Expand Down Expand Up @@ -141,7 +141,8 @@ form_xy <- function(object, control, env,
...,
composition = target,
indicators = indicators,
remove_intercept = remove_intercept
remove_intercept = remove_intercept,
call = call
)
env$x <- data_obj$x
env$y <- data_obj$y
Expand Down
Loading
Loading