Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sparse argument to step_dummy() #1392

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

* All steps and checks now require arguments `trained`, `skip`, `role`, and `id` at all times.

* `step_dummy()` gained a `sparse` argument. When set to `TRUE`, `step_dummy()` will produce sparse vectors. (#1392)

# recipes 1.1.0

## Improvements
Expand Down
92 changes: 59 additions & 33 deletions R/dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#' @param levels A list that contains the information needed to create dummy
#' variables for each variable contained in `terms`. This is `NULL` until the
#' step is trained by [prep()].
#' @param sparse A logical. Should the columns produced be sparse vectors?
#' Sparsity is only supported for `"contr.treatment"` contrasts. Defaults to
#' `FALSE`.
#' @template step-return
#' @family dummy variable and encoding steps
#' @seealso [dummy_names()]
Expand Down Expand Up @@ -60,7 +63,8 @@
#' this step.
#'
#' Also, there are a number of contrast methods that return fractional values.
#' The columns returned by this step are doubles (not integers).
#' The columns returned by this step are doubles (not integers) when
#' `sparse = FALSE`. The columns returned when `sparse = TRUE` are integers.
#'
#' The [package vignette for dummy variables](https://recipes.tidymodels.org/articles/Dummies.html)
#' and interactions has more information.
Expand Down Expand Up @@ -121,6 +125,7 @@
preserve = deprecated(),
naming = dummy_names,
levels = NULL,
sparse = FALSE,
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("dummy")) {
Expand All @@ -143,6 +148,7 @@
preserve = keep_original_cols,
naming = naming,
levels = levels,
sparse = sparse,
keep_original_cols = keep_original_cols,
skip = skip,
id = id
Expand All @@ -151,7 +157,7 @@
}

step_dummy_new <-
function(terms, role, trained, one_hot, preserve, naming, levels,
function(terms, role, trained, one_hot, preserve, naming, levels, sparse,
keep_original_cols, skip, id) {
step(
subclass = "dummy",
Expand All @@ -162,6 +168,7 @@
preserve = preserve,
naming = naming,
levels = levels,
sparse = sparse,
keep_original_cols = keep_original_cols,
skip = skip,
id = id
Expand All @@ -174,6 +181,7 @@
check_type(training[, col_names], types = c("factor", "ordered"))
check_bool(x$one_hot, arg = "one_hot")
check_function(x$naming, arg = "naming", allow_empty = FALSE)
check_bool(x$sparse, arg = "sparse")

if (length(col_names) > 0) {
## I hate doing this but currently we are going to have
Expand Down Expand Up @@ -218,6 +226,7 @@
preserve = x$preserve,
naming = x$naming,
levels = levels,
sparse = x$sparse,
keep_original_cols = get_keep_original_cols(x),
skip = x$skip,
id = x$id
Expand Down Expand Up @@ -285,43 +294,60 @@
col_name,
step = "step_dummy"
)

new_data[, col_name] <- factor(
new_data[[col_name]],
levels = levels_values,
ordered = is_ordered
)

new_data[, col_name] <-
factor(
new_data[[col_name]],
levels = levels_values,
ordered = is_ordered
)
if (object$sparse) {
current_contrast <- getOption("contrasts")[is_ordered + 1]
if (current_contrast != "contr.treatment") {
cli::cli_abort(
"When {.code sparse = TRUE}, only {.val contr.treatment} contrasts are
supported, not {.val {current_contrast}}."
)
}

indicators <-
model.frame(
rlang::new_formula(lhs = NULL, rhs = rlang::sym(col_name)),
data = new_data[, col_name],
xlev = levels_values,
na.action = na.pass
indicators <- sparsevctrs::sparse_dummy(
x = new_data[[col_name]],
one_hot = object$one_hot
)

indicators <- tryCatch(
model.matrix(object = levels, data = indicators),
error = function(cnd) {
if (grepl("(vector memory|cannot allocate)", cnd$message)) {
n_levels <- length(attr(levels, "values"))
cli::cli_abort(
"{.var {col_name}} contains too many levels ({n_levels}), \\
which would result in a data.frame too large to fit in memory.",
call = NULL
)
indicators <- tibble::new_tibble(indicators)
used_lvl <- colnames(indicators)
} else {
indicators <-
model.frame(
rlang::new_formula(lhs = NULL, rhs = rlang::sym(col_name)),
data = new_data[, col_name],
xlev = levels_values,
na.action = na.pass
)

indicators <- tryCatch(
model.matrix(object = levels, data = indicators),
error = function(cnd) {
if (grepl("(vector memory|cannot allocate)", cnd$message)) {
n_levels <- length(attr(levels, "values"))
cli::cli_abort(
"{.var {col_name}} contains too many levels ({n_levels}), \\
which would result in a data.frame too large to fit in memory.",
call = NULL
)
}
stop(cnd)

Check warning on line 339 in R/dummy.R

View check run for this annotation

Codecov / codecov/patch

R/dummy.R#L339

Added line #L339 was not covered by tests
}
stop(cnd)
)

if (!object$one_hot) {
indicators <- indicators[, colnames(indicators) != "(Intercept)", drop = FALSE]
}
)

if (!object$one_hot) {
indicators <- indicators[, colnames(indicators) != "(Intercept)", drop = FALSE]

## use backticks for nonstandard factor levels here
used_lvl <- gsub(paste0("^\\`?", col_name, "\\`?"), "", colnames(indicators))
}

## use backticks for nonstandard factor levels here
used_lvl <- gsub(paste0("^\\`?", col_name, "\\`?"), "", colnames(indicators))

new_names <- object$naming(col_name, used_lvl, is_ordered)
colnames(indicators) <- new_names
indicators <- check_name(indicators, new_data, object, new_names)
Expand Down
8 changes: 7 additions & 1 deletion man/step_dummy.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/_snaps/dummy.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@
Caused by error in `bake()`:
! Only one factor level in `x`: "only-level".

# sparse = TRUE errors on unsupported contrasts

Code
recipe(~., data = tibble(x = letters)) %>% step_dummy(x, sparse = TRUE) %>%
prep()
Condition
Error in `step_dummy()`:
Caused by error in `bake()`:
! When `sparse = TRUE`, only "contr.treatment" contrasts are supported, not "contr.helmert".

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,34 @@ test_that("throws an informative error for single level", {
)
})

test_that("sparse = TRUE works", {
  # Include a missing value alongside the 26 letters so both the dense and
  # the sparse code paths are exercised with `NA` input.
  rec <- recipe(~ ., data = tibble(x = c(NA, letters)))

  # NOTE(review): presumably this silences warnings triggered by the `NA` in
  # `x` -- confirm which warning is raised and narrow the suppression if
  # possible.
  suppressWarnings({
    dense <- rec %>% step_dummy(x, sparse = FALSE) %>% prep() %>% bake(NULL)
    # The dense path returns doubles while the sparse path returns integers,
    # so convert the dense columns before comparing the two results.
    dense <- purrr::map(dense, as.integer) %>% tibble::new_tibble()
    sparse <- rec %>% step_dummy(x, sparse = TRUE) %>% prep() %>% bake(NULL)
  })

  # Same values regardless of storage ...
  expect_identical(dense, sparse)

  # ... but only the `sparse = TRUE` result is made of sparse vectors.
  expect_false(any(vapply(dense, sparsevctrs::is_sparse_vector, logical(1))))
  expect_true(all(vapply(sparse, sparsevctrs::is_sparse_vector, logical(1))))
})

test_that("sparse = TRUE errors on unsupported contrasts", {
  # Swap only the unordered-factor contrast to Helmert, leaving the ordered
  # one as configured; withr scopes the option change to this test.
  helmert_contrasts <- replace(
    getOption("contrasts"),
    "unordered",
    "contr.helmert"
  )
  withr::local_options(contrasts = helmert_contrasts)

  # The snapshot records both the code and the error, so the pipeline below
  # must stay structurally unchanged.
  expect_snapshot(
    recipe(~ ., data = tibble(x = letters)) %>%
      step_dummy(x, sparse = TRUE) %>%
      prep(),
    error = TRUE
  )
})

# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
Expand Down
Loading