Skip to content

Commit

Permalink
use type checkers in fit.R
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Sep 9, 2024
1 parent 8af5ddf commit 280b4cd
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 55 deletions.
2 changes: 1 addition & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ check_dup_names <- function(x, y, call = rlang::caller_env()) {
#' @return A data frame, matrix, or sparse matrix.
#' @export
maybe_matrix <- function(x) {
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl = match.call())
check_inherits(x, c("data.frame", "matrix", "dgCMatrix"))
if (is.data.frame(x)) {
non_num_cols <- vapply(x, function(x) !is.numeric(x), logical(1))
if (any(non_num_cols)) {
Expand Down
89 changes: 37 additions & 52 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@
#' a "reverse Kaplan-Meier" curve that models the probability of censoring. This
#' may be used later to compute inverse probability censoring weights for
#' performance measures.
#'
#'
#' Sparse data is supported, with the use of the `x` argument in `fit_xy()`. See
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
#' compatibility.
#'
#'
#' @examplesIf !parsnip:::is_cran_check()
#' # Although `glm()` only has a formula interface, different
#' # methods for specifying the model can be used
Expand Down Expand Up @@ -121,21 +121,17 @@ fit.model_spec <-
control <- condense_control(control, control_parsnip())
check_case_weights(case_weights, object)

if (!inherits(formula, "formula")) {
msg <- "The {.arg formula} argument must be a formula, but it is a \\
{.cls {class(formula)[1]}}."

if (inherits(formula, "recipe")) {
msg <-
c(
msg,
"i" = "To fit a model with a recipe preprocessor, please use a \\
if (inherits(formula, "recipe")) {
cli::cli_abort(
c(
"The {.arg formula} argument must be a formula.",
"i" = "To fit a model with a recipe preprocessor, please use a \\
{.help [workflow](workflows::workflow)}."
)
}

cli::cli_abort(msg)
)
)
}
check_formula(formula)


if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
Expand Down Expand Up @@ -179,7 +175,7 @@ fit.model_spec <-
eval_env$weights <- wts

data <- materialize_sparse_tibble(data, object, "data")

fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)

Expand Down Expand Up @@ -297,10 +293,11 @@ fit_xy.model_spec <-
# TODO case weights: pass in eval_env not individual elements
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)

if (object$engine == "spark")
if (object$engine == "spark") {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
)
}

# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)
Expand Down Expand Up @@ -373,59 +370,47 @@ eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {

# ------------------------------------------------------------------------------

# Check that `x`, when supplied, inherits from at least one class in `cls`.
#
# `x`:   object to check; `NULL` is accepted without any check.
# `cls`: character vector of acceptable class names.
# `cl`:  accepted for the callers' benefit but not used in this function.
#
# Returns `x` invisibly when the check passes; otherwise aborts with a message
# naming the expression that was passed as `x`.
inher <- function(x, cls, cl) {
  if (is.null(x) || inherits(x, cls)) {
    return(invisible(x))
  }

  # Recover the expression supplied for `x` so the message can refer to it.
  obj <- deparse(match.call()[["x"]])

  # The wording differs depending on whether one or several classes are allowed.
  if (length(cls) > 1) {
    cli::cli_abort(
      "{.arg {obj}} should be one of the following classes: {.cls {cls}}."
    )
  }

  cli::cli_abort("{.arg {obj}} should be a {.cls {cls}} object")
}

# ------------------------------------------------------------------------------

check_interface <- function(formula, data, cl, model) {
inher(formula, "formula", cl)
inher(data, c("data.frame", "dgCMatrix", "tbl_spark"), cl)
check_interface <- function(formula, data, cl, model, call = caller_env()) {
check_inherits(formula, "formula", call = call)
check_inherits(data, c("data.frame", "dgCMatrix", "tbl_spark"), call = call)

# Determine the `fit()` interface
form_interface <- !is.null(formula) & !is.null(data)

if (form_interface)
return("formula")
cli::cli_abort("Error when checking the interface.")
cli::cli_abort("Error when checking the interface.", call = call)
}

check_xy_interface <- function(x, y, cl, model) {
check_xy_interface <- function(x, y, cl, model, call = caller_env()) {

sparse_ok <- allow_sparse(model)
sparse_x <- inherits(x, "dgCMatrix")
if (!sparse_ok & sparse_x) {
cli::cli_abort("Sparse matrices not supported by this model/engine combination.")
cli::cli_abort(
"Sparse matrices not supported by this model/engine combination.",
call = call
)
}

if (sparse_ok) {
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
check_inherits(x, c("data.frame", "matrix", "dgCMatrix"), call = call)
} else {
inher(x, c("data.frame", "matrix"), cl)
check_inherits(x, c("data.frame", "matrix"), call = call)
}

if (!is.null(y) && !is.atomic(y))
inher(y, c("data.frame", "matrix"), cl)
if (!is.null(y) && !is.atomic(y)) {
check_inherits(y, c("data.frame", "matrix"), call = call)
}

# rule out spark data sets that don't use the formula interface
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark"))
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark")) {
cli::cli_abort(
"spark objects can only be used with the formula interface via {.fn fit} with a spark data object."
)

"spark objects can only be used with the formula interface via
{.fn fit} with a spark data object.",
call = call
)
}

if (sparse_ok) {
matrix_interface <- !is.null(x) && !is.null(y) && (is.matrix(x) | sparse_x)
Expand All @@ -444,7 +429,7 @@ check_xy_interface <- function(x, y, cl, model) {

check_outcome(y, model)

cli::cli_abort("Error when checking the interface")
cli::cli_abort("Error when checking the interface", call = call)
}

allow_sparse <- function(x) {
Expand Down
15 changes: 15 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,21 @@ check_case_weights <- function(x, spec, call = rlang::caller_env()) {
invisible(NULL)
}

# ------------------------------------------------------------------------------

# Abort unless `x` inherits from at least one of the classes in `cls`.
#
# `x`:    object to check; `NULL` always passes.
# `cls`:  character vector of acceptable class names.
# `arg`:  label used for `x` in the error message; defaults to `caller_arg(x)`.
# `call`: forwarded to `cli::cli_abort()` as its `call` argument; defaults to
#         `caller_env()`.
#
# Returns `NULL` invisibly when the check passes.
check_inherits <- function(x, cls, arg = caller_arg(x), call = caller_env()) {
  # Nothing to report when `x` is absent or already of an accepted class.
  if (is.null(x) || inherits(x, cls)) {
    return(invisible(NULL))
  }

  cli::cli_abort(
    "{.arg {arg}} should be a {.cls {cls}}, not {.obj_type_friendly {x}}.",
    call = call
  )
}

# -----------------------------------------------------------------------------
check_for_newdata <- function(..., call = rlang::caller_env()) {
if (any(names(list(...)) == "newdata")) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/fit_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
fit(linear_reg(), rec, mtcars)
Condition
Error in `fit()`:
! The `formula` argument must be a formula, but it is a <recipe>.
! The `formula` argument must be a formula.
i To fit a model with a recipe preprocessor, please use a workflow (`?workflows::workflow()`).

---
Expand All @@ -13,7 +13,7 @@
fit(linear_reg(), "boop", mtcars)
Condition
Error in `fit()`:
! The `formula` argument must be a formula, but it is a <character>.
! `formula` must be a formula, not the string "boop".

# No loaded engines

Expand Down

0 comments on commit 280b4cd

Please sign in to comment.