Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use type checkers in fit.R #1182

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ check_dup_names <- function(x, y, call = rlang::caller_env()) {
#' @return A data frame, matrix, or sparse matrix.
#' @export
maybe_matrix <- function(x) {
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl = match.call())
check_inherits(x, c("data.frame", "matrix", "dgCMatrix"))
if (is.data.frame(x)) {
non_num_cols <- vapply(x, function(x) !is.numeric(x), logical(1))
if (any(non_num_cols)) {
Expand Down
89 changes: 37 additions & 52 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@
#' a "reverse Kaplan-Meier" curve that models the probability of censoring. This
#' may be used later to compute inverse probability censoring weights for
#' performance measures.
#'
#'
#' Sparse data is supported, with the use of the `x` argument in `fit_xy()`. See
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
#' compatibility.
#'
#'
#' @examplesIf !parsnip:::is_cran_check()
#' # Although `glm()` only has a formula interface, different
#' # methods for specifying the model can be used
Expand Down Expand Up @@ -121,21 +121,17 @@ fit.model_spec <-
control <- condense_control(control, control_parsnip())
check_case_weights(case_weights, object)

if (!inherits(formula, "formula")) {
msg <- "The {.arg formula} argument must be a formula, but it is a \\
{.cls {class(formula)[1]}}."

if (inherits(formula, "recipe")) {
msg <-
c(
msg,
"i" = "To fit a model with a recipe preprocessor, please use a \\
if (inherits(formula, "recipe")) {
cli::cli_abort(
c(
"The {.arg formula} argument must be a formula.",
"i" = "To fit a model with a recipe preprocessor, please use a \\
{.help [workflow](workflows::workflow)}."
)
}

cli::cli_abort(msg)
)
)
}
check_formula(formula)


if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
Expand Down Expand Up @@ -179,7 +175,7 @@ fit.model_spec <-
eval_env$weights <- wts

data <- materialize_sparse_tibble(data, object, "data")

fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)

Expand Down Expand Up @@ -297,10 +293,11 @@ fit_xy.model_spec <-
# TODO case weights: pass in eval_env not individual elements
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)

if (object$engine == "spark")
if (object$engine == "spark") {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
)
}

# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)
Expand Down Expand Up @@ -373,59 +370,47 @@ eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {

# ------------------------------------------------------------------------------

inher <- function(x, cls, cl) {
if (!is.null(x) && !inherits(x, cls)) {

call <- match.call()
obj <- deparse(call[["x"]])

if (length(cls) > 1)
cli::cli_abort(
"{.arg {obj}} should be one of the following classes: {.cls {cls}}.")

else
cli::cli_abort("{.arg {obj}} should be a {.cls {cls}} object")
}
invisible(x)
}

# ------------------------------------------------------------------------------

check_interface <- function(formula, data, cl, model) {
inher(formula, "formula", cl)
inher(data, c("data.frame", "dgCMatrix", "tbl_spark"), cl)
check_interface <- function(formula, data, cl, model, call = caller_env()) {
check_inherits(formula, "formula", call = call)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
check_inherits(data, c("data.frame", "dgCMatrix", "tbl_spark"), call = call)

# Determine the `fit()` interface
form_interface <- !is.null(formula) & !is.null(data)

if (form_interface)
return("formula")
cli::cli_abort("Error when checking the interface.")
cli::cli_abort("Error when checking the interface.", call = call)
}

check_xy_interface <- function(x, y, cl, model) {
check_xy_interface <- function(x, y, cl, model, call = caller_env()) {

sparse_ok <- allow_sparse(model)
sparse_x <- inherits(x, "dgCMatrix")
if (!sparse_ok & sparse_x) {
cli::cli_abort("Sparse matrices not supported by this model/engine combination.")
cli::cli_abort(
"Sparse matrices not supported by this model/engine combination.",
call = call
)
}

if (sparse_ok) {
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
check_inherits(x, c("data.frame", "matrix", "dgCMatrix"), call = call)
} else {
inher(x, c("data.frame", "matrix"), cl)
check_inherits(x, c("data.frame", "matrix"), call = call)
}

if (!is.null(y) && !is.atomic(y))
inher(y, c("data.frame", "matrix"), cl)
if (!is.null(y) && !is.atomic(y)) {
check_inherits(y, c("data.frame", "matrix"), call = call)
}

# rule out spark data sets that don't use the formula interface
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark"))
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark")) {
cli::cli_abort(
"spark objects can only be used with the formula interface via {.fn fit} with a spark data object."
)

"spark objects can only be used with the formula interface via
{.fn fit} with a spark data object.",
call = call
)
}

if (sparse_ok) {
matrix_interface <- !is.null(x) && !is.null(y) && (is.matrix(x) | sparse_x)
Expand All @@ -444,7 +429,7 @@ check_xy_interface <- function(x, y, cl, model) {

check_outcome(y, model)

cli::cli_abort("Error when checking the interface")
cli::cli_abort("Error when checking the interface", call = call)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
}

allow_sparse <- function(x) {
Expand Down
15 changes: 15 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,21 @@ check_case_weights <- function(x, spec, call = rlang::caller_env()) {
invisible(NULL)
}

# ------------------------------------------------------------------------------

check_inherits <- function(x, cls, arg = caller_arg(x), call = caller_env()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decided to write out inher() in favor of this helper. I think we could probably get a good bit of mileage out of this based on the number of if (inherits( in this repo.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100%, and the name is much better too!

if (is.null(x)) {
return(invisible(x))
}

if (!inherits(x, cls)) {
cli::cli_abort(
"{.arg {arg}} should be a {.cls {cls}}, not {.obj_type_friendly {x}}.",
call = call
)
}
}

# -----------------------------------------------------------------------------
check_for_newdata <- function(..., call = rlang::caller_env()) {
if (any(names(list(...)) == "newdata")) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/fit_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
fit(linear_reg(), rec, mtcars)
Condition
Error in `fit()`:
! The `formula` argument must be a formula, but it is a <recipe>.
! The `formula` argument must be a formula.
i To fit a model with a recipe preprocessor, please use a workflow (`?workflows::workflow()`).

---
Expand All @@ -13,7 +13,7 @@
fit(linear_reg(), "boop", mtcars)
Condition
Error in `fit()`:
! The `formula` argument must be a formula, but it is a <character>.
! `formula` must be a formula, not the string "boop".

# No loaded engines

Expand Down
Loading