Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

let fit_xy() take dgCMatrix input #1121

Merged
merged 20 commits into from
Aug 28, 2024
Merged

let fit_xy() take dgCMatrix input #1121

merged 20 commits into from
Aug 28, 2024

Conversation

EmilHvitfeldt
Copy link
Member

@EmilHvitfeldt EmilHvitfeldt commented May 23, 2024

Ref: #1125

General idea:

  • intercept dgCMatrix early on, and turn it into a sparse data frame
  • before data is passed to modeling function, turn it into a sparse matrix

TODO:

  • Make sure this only happens in compliant models
library(tidymodels)
library(textrecipes)
library(friends)

preped_rec <- recipe(season ~ text, data = friends) %>%
  step_tokenize(text) %>%
  step_tf(text) %>%
  prep()
#> Warning in asMethod(object): sparse->dense coercion: allocating vector of size
#> 8.7 GiB

term_freq <- bake(preped_rec, new_data = NULL, composition = "dgCMatrix")

dim(term_freq)
#> [1] 67373 17378

lobstr::obj_size(term_freq)
#> 9.86 MB

lm_spec <- linear_reg(penalty = 0) |>
  set_engine("glmnet")

tictoc::tic()
lm_fit <- fit_xy(lm_spec, x = term_freq[, -1], y = term_freq[, 1])
tictoc::toc()
#> 2.006 sec elapsed

lm_fit
#> parsnip model object
#> 
#> 
#> Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian") 
#> 
#>       Df  %Dev   Lambda
#> 1      0  0.00 0.200600
#> 2      1  0.09 0.182800
#> 3      2  0.22 0.166500
#> 4      3  0.35 0.151700
#> 5      3  0.50 0.138300
#> 6      4  0.63 0.126000
#> 7      7  0.83 0.114800
#> 8      8  1.04 0.104600
#> 9     11  1.26 0.095300
#> 10    13  1.49 0.086830
#> 11    19  1.73 0.079120
#> 12    31  2.04 0.072090
#> 13    39  2.40 0.065680
#> 14    52  2.80 0.059850
#> 15    70  3.22 0.054530
#> 16    81  3.66 0.049690
#> 17   101  4.10 0.045270
#> 18   139  4.58 0.041250
#> 19   193  5.14 0.037590
#> 20   273  5.79 0.034250
#> 21   375  6.52 0.031210
#> 22   515  7.34 0.028430
#> 23   677  8.25 0.025910
#> 24   962  9.26 0.023610
#> 25  1208 10.40 0.021510
#> 26  1516 11.58 0.019600
#> 27  2001 12.83 0.017860
#> 28  2946 14.32 0.016270
#> 29  3538 15.93 0.014820
#> 30  4287 17.51 0.013510
#> 31  5048 19.10 0.012310
#> 32  5607 20.61 0.011210
#> 33  6149 22.00 0.010220
#> 34  6755 23.30 0.009311
#> 35  7295 24.50 0.008483
#> 36  7820 25.59 0.007730
#> 37  8359 26.58 0.007043
#> 38  8846 27.48 0.006417
#> 39  9370 28.30 0.005847
#> 40  9814 29.03 0.005328
#> 41 10265 29.71 0.004855
#> 42 10717 30.31 0.004423
#> 43 11068 30.84 0.004030
#> 44 11432 31.34 0.003672
#> 45 11753 31.77 0.003346
#> 46 12103 32.17 0.003049
#> 47 12389 32.51 0.002778
#> 48 12669 32.82 0.002531
#> 49 12956 33.09 0.002306
#> 50 13223 33.34 0.002101
#> 51 13505 33.55 0.001915
#> 52 13731 33.72 0.001745
#> 53 14016 33.92 0.001590
#> 54 14224 34.07 0.001448
#> 55 14406 34.18 0.001320
#> 56 14667 34.31 0.001203
#> 57 14791 34.41 0.001096
#> 58 14998 34.51 0.000998
#> 59 15103 34.59 0.000910
#> 60 15216 34.65 0.000829
#> 61 15329 34.70 0.000755
#> 62 15449 34.75 0.000688
#> 63 15542 34.78 0.000627
#> 64 15668 34.82 0.000571
#> 65 15739 34.87 0.000521
#> 66 15770 34.89 0.000474
#> 67 15843 34.91 0.000432
#> 68 15894 34.93 0.000394
#> 69 15957 34.94 0.000359
#> 70 16004 34.96 0.000327
#> 71 16044 34.96 0.000298
#> 72 16099 34.97 0.000271
#> 73 16141 34.98 0.000247
#> 74 16174 34.99 0.000225
#> 75 16211 35.00 0.000205
#> 76 16251 35.00 0.000187
#> 77 16339 35.01 0.000170
#> 78 16393 35.02 0.000155
#> 79 16396 35.02 0.000141
#> 80 16393 35.02 0.000129
#> 81 16401 35.03 0.000118
#> 82 16424 35.03 0.000107
#> 83 16450 35.03 0.000098
#> 84 16493 35.03 0.000089
#> 85 16509 35.04 0.000081
#> 86 16501 35.04 0.000074
#> 87 16523 35.04 0.000067
#> 88 16522 35.04 0.000061
#> 89 16528 35.04 0.000056
#> 90 16529 35.04 0.000051
#> 91 16542 35.04 0.000046
#> 92 16555 35.04 0.000042
#> 93 16567 35.04 0.000038
#> 94 16579 35.04 0.000035
#> 95 16649 35.05 0.000032
#> 96 16644 35.05 0.000029

@EmilHvitfeldt EmilHvitfeldt changed the title let fit_xy() take dgCMatrix input let fit_xy() take dgCMatrix input May 23, 2024
test_that("to_sparse_data_frame() is used correctly", {
skip_if_not_installed("LiblineaR")

local_mocked_bindings(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main testing strategy follows this template:

  • mock the functions that deals with sparsevctrs
  • see if we can trigger all paths inside those functions

set_engine("LiblineaR")

expect_no_error(
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this didn't work, it would take quite a lot longer to run which we would notice

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we? Does "didn't work" mean a failure or just inefficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inefficient. it should pop up as a "this test is running a little long" from CRAN / CMD R Check

@EmilHvitfeldt EmilHvitfeldt requested a review from topepo June 29, 2024 03:10
@@ -32,6 +32,7 @@ Imports:
prettyunits,
purrr (>= 1.0.0),
rlang (>= 1.1.0),
sparsevctrs (>= 0.1.0.9000),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it uses dev version because of this bug fix: r-lib/sparsevctrs@9c22ca9

sparsevctrs will of course be merged in time for parsnip release

Copy link
Member

@topepo topepo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add something to fit.model_spec() to mention that this can happen? Also, does the description for x need to reflect this?

Also, please add a note to the NEWS file.

set_engine("LiblineaR")

expect_no_error(
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we? Does "didn't work" mean a failure or just inefficient?

@EmilHvitfeldt
Copy link
Member Author

added a little documentation. It will fleshed out more once the other parts of #1125 is added

R/sparsevctrs.R Outdated
if (allow_sparse(object)) {
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
} else {
cli::cli_warn(c(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to update the PR to make this a failure.

@EmilHvitfeldt EmilHvitfeldt merged commit 81d9536 into main Aug 28, 2024
10 checks passed
@EmilHvitfeldt EmilHvitfeldt deleted the sparse-input branch August 28, 2024 19:59
Copy link

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Sep 12, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants