Skip to content

Commit

Permalink
Merge branch 'main' into doc-sparse-data
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Sep 9, 2024
2 parents 44403fc + 8af5ddf commit 93463c2
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 11 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).

* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).

* Transitioned package errors and warnings to use cli (#1147 and #1148 by
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
#1161, #1081).
Expand Down
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ fit.model_spec <-
cli::cli_abort(msg)
}

if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
}

dots <- quos(...)

if (length(possible_engines(object)) == 0) {
Expand Down Expand Up @@ -444,6 +448,10 @@ check_xy_interface <- function(x, y, cl, model) {
}

allow_sparse <- function(x) {
if (inherits(x, "model_fit")) {
x <- x$spec
}

res <- get_from_env(paste0(class(x)[1], "_encoding"))
all(res$allow_sparse_x[res$engine == x$engine])
}
Expand Down
4 changes: 3 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
}
check_pred_type_dots(object, type, ...)

new_data <- to_sparse_data_frame(new_data, object)

res <- switch(
type,
numeric = predict_numeric(object = object, new_data = new_data, ...),
Expand Down Expand Up @@ -450,7 +452,7 @@ prepare_data <- function(object, new_data) {
} else if (translate_from_xy_to_formula) {
new_data <- .convert_xy_to_form_new(object$preproc, new_data)
} else if (translate_from_xy_to_xy) {
new_data <- new_data[, object$preproc$x_names]
new_data <- new_data[, object$preproc$x_names, drop = FALSE]
}

encodings <- get_encoding(class(object$spec)[1])
Expand Down
26 changes: 20 additions & 6 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
#' @name sparse_data
NULL

to_sparse_data_frame <- function(x, object) {
if (methods::is(x, "sparseMatrix")) {
to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
if (is_sparse_matrix(x)) {
if (allow_sparse(object)) {
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
} else {
if (inherits(object, "model_fit")) {
object <- object$spec
}

cli::cli_abort(
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that.")
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.val {object$engine}} doesn't accept that.",
call = call
)
}
} else if (is.data.frame(x)) {
x <- materialize_sparse_tibble(x, object, "x")
Expand All @@ -35,11 +41,19 @@ is_sparse_tibble <- function(x) {
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
}

is_sparse_matrix <- function(x) {
methods::is(x, "sparseMatrix")
}

materialize_sparse_tibble <- function(x, object, input) {
if ((!allow_sparse(object)) && is_sparse_tibble(x)) {
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
if (inherits(object, "model_fit")) {
object <- object$spec
}

cli::cli_warn(
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that. Converting to
engine {.val {object$engine}} doesn't accept that. Converting to
non-sparse."
)
for (i in seq_along(ncol(x))) {
Expand Down
32 changes: 28 additions & 4 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,47 @@
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrix can be passed to `fit()

Code
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse tibble can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
Error in `fit_xy()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# sparse tibble can be passed to `predict()

Code
preds <- predict(lm_fit, sparse_mtcars)
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `predict()

Code
predict(lm_fit, sparse_mtcars)
Condition
Error in `predict()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# to_sparse_data_frame() is used correctly

Expand Down
82 changes: 82 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ test_that("sparse tibble can be passed to `fit()", {
)
})

test_that("sparse matrix can be passed to `fit()", {
skip_if_not_installed("xgboost")

hotel_data <- sparse_hotel_rates()

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

expect_no_error(
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
)

spec <- linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")

expect_snapshot(
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
)
})

test_that("sparse tibble can be passed to `fit_xy()", {
skip_if_not_installed("xgboost")

Expand Down Expand Up @@ -67,6 +89,66 @@ test_that("sparse matrices can be passed to `fit_xy()", {
)
})

test_that("sparse tibble can be passed to `predict()", {
skip_if_not_installed("ranger")

hotel_data <- sparse_hotel_rates()
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)

spec <- rand_forest(trees = 10) %>%
set_mode("regression") %>%
set_engine("ranger")

tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])

expect_no_error(
predict(tree_fit, hotel_data)
)

spec <- linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")

lm_fit <- fit(spec, mpg ~ ., data = mtcars)

sparse_mtcars <- mtcars %>%
sparsevctrs::coerce_to_sparse_matrix() %>%
sparsevctrs::coerce_to_sparse_tibble()

expect_snapshot(
preds <- predict(lm_fit, sparse_mtcars)
)
})

test_that("sparse matrices can be passed to `predict()", {
skip_if_not_installed("ranger")

hotel_data <- sparse_hotel_rates()

spec <- rand_forest(trees = 10) %>%
set_mode("regression") %>%
set_engine("ranger")

tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])

expect_no_error(
predict(tree_fit, hotel_data)
)

spec <- linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")

lm_fit <- fit(spec, mpg ~ ., data = mtcars)

sparse_mtcars <- sparsevctrs::coerce_to_sparse_matrix(mtcars)

expect_snapshot(
error = TRUE,
predict(lm_fit, sparse_mtcars)
)
})

test_that("to_sparse_data_frame() is used correctly", {
skip_if_not_installed("xgboost")

Expand Down

0 comments on commit 93463c2

Please sign in to comment.