Skip to content

Commit

Permalink
Merge pull request #1167 from tidymodels/sparse-matrix-predict
Browse files Browse the repository at this point in the history
Make sure sparse matrices can be used with `predict()`
  • Loading branch information
EmilHvitfeldt authored Sep 5, 2024
2 parents bc27e2b + f606403 commit a9aadfb
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).

* `predict()` can now take dgCMatrix and sparse tibble input for the `new_data` argument, and errors informatively when the model doesn't support it (#1167).

* Transitioned package errors and warnings to use cli (#1147 and #1148 by
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
#1161, #1081).
Expand Down
4 changes: 4 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ check_xy_interface <- function(x, y, cl, model) {
}

# Does the model/engine combination accept sparse predictors?
#
# `x` is either a model specification or a `model_fit`; for a fit, the
# specification stored in `x$spec` is used. The encoding table is looked up
# via `get_from_env()` under the name "<first class of x>_encoding", and the
# `allow_sparse_x` entries for rows whose `engine` equals `x$engine` are
# inspected.
#
# Returns a single `TRUE` or `FALSE`. `TRUE` is returned only when at least
# one encoding row matches the engine and every matching row allows sparse
# input.
allow_sparse <- function(x) {
  if (inherits(x, "model_fit")) {
    x <- x$spec
  }

  res <- get_from_env(paste0(class(x)[1], "_encoding"))
  flags <- res$allow_sparse_x[res$engine == x$engine]

  # `all()` of a zero-length vector is TRUE, so an unset engine, an engine
  # with no encoding rows, or a missing encoding table would otherwise be
  # reported as sparse-capable. Require at least one matching row, and use
  # `isTRUE()` so an `NA` flag cannot propagate into callers' `if ()`.
  length(flags) > 0 && isTRUE(all(flags))
}
Expand Down
2 changes: 2 additions & 0 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
}
check_pred_type_dots(object, type, ...)

new_data <- to_sparse_data_frame(new_data, object)

res <- switch(
type,
numeric = predict_numeric(object = object, new_data = new_data, ...),
Expand Down
8 changes: 8 additions & 0 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ to_sparse_data_frame <- function(x, object) {
if (allow_sparse(object)) {
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
} else {
if (inherits(object, "model_fit")) {
object <- object$spec
}

cli::cli_abort(
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that.")
Expand All @@ -19,6 +23,10 @@ is_sparse_tibble <- function(x) {

materialize_sparse_tibble <- function(x, object, input) {
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
if (inherits(object, "model_fit")) {
object <- object$spec
}

cli::cli_warn(
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that. Converting to
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.

# sparse tibble can be passed to `predict()

Code
preds <- predict(lm_fit, sparse_mtcars)
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `predict()

Code
predict(lm_fit, sparse_mtcars)
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.

# to_sparse_data_frame() is used correctly

Code
Expand Down
60 changes: 60 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,66 @@ test_that("sparse matrices can be passed to `fit_xy()", {
)
})

test_that("sparse tibble can be passed to `predict()", {
  skip_if_not_installed("ranger")

  # Engine used here with sparse input: predicting on a sparse tibble must
  # not raise an error.
  sparse_hotel <- sparsevctrs::coerce_to_sparse_tibble(sparse_hotel_rates())

  forest_spec <- set_engine(
    set_mode(rand_forest(trees = 10), "regression"),
    "ranger"
  )
  forest_fit <- fit_xy(
    forest_spec,
    x = sparse_hotel[, -1],
    y = sparse_hotel[, 1]
  )

  expect_no_error(
    predict(forest_fit, sparse_hotel)
  )

  # `lm` engine: the recorded snapshot captures the condition signalled when
  # a sparse tibble is supplied as new data.
  lm_spec <- set_engine(set_mode(linear_reg(), "regression"), "lm")
  lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)

  sparse_mtcars <- sparsevctrs::coerce_to_sparse_tibble(
    sparsevctrs::coerce_to_sparse_matrix(mtcars)
  )

  # The expression below is recorded verbatim in the snapshot file, so its
  # text (including the variable names) must not change.
  expect_snapshot(
    preds <- predict(lm_fit, sparse_mtcars)
  )
})

test_that("sparse matrices can be passed to `predict()", {
  skip_if_not_installed("ranger")

  # Engine used here with sparse input: predicting on a sparse matrix must
  # not raise an error.
  sparse_hotel <- sparse_hotel_rates()

  forest_spec <- set_engine(
    set_mode(rand_forest(trees = 10), "regression"),
    "ranger"
  )
  forest_fit <- fit_xy(
    forest_spec,
    x = sparse_hotel[, -1],
    y = sparse_hotel[, 1]
  )

  expect_no_error(
    predict(forest_fit, sparse_hotel)
  )

  # `lm` engine: supplying a sparse matrix as new data is expected to error,
  # and the error is recorded in the snapshot.
  lm_spec <- set_engine(set_mode(linear_reg(), "regression"), "lm")
  lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)

  sparse_mtcars <- sparsevctrs::coerce_to_sparse_matrix(mtcars)

  # The expression below is recorded verbatim in the snapshot file, so its
  # text (including the variable names) must not change.
  expect_snapshot(
    error = TRUE,
    predict(lm_fit, sparse_mtcars)
  )
})

test_that("to_sparse_data_frame() is used correctly", {
skip_if_not_installed("xgboost")

Expand Down

0 comments on commit a9aadfb

Please sign in to comment.