Compute grid info dplyr (#961)
* refactored compute_grid_info() using dplyr, purrr, and tidyr

* remove padding in .config

* sort values for tests

* update test specification for different sorting

* fix bug in the messages

* update snapshots with new remotes

* added padding back
topepo authored Nov 14, 2024
1 parent f85eac9 commit e16bb0d
Showing 7 changed files with 110 additions and 57 deletions.
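
The refactor replaces the old row-wise bookkeeping with a nest/mutate/unnest idiom from dplyr, purrr, and tidyr. A minimal sketch of that pattern, using a made-up toy grid and label names rather than the package's real columns:

library(dplyr)
library(purrr)
library(tidyr)
library(tibble)

# Toy grid: two "preprocessor" settings with three "model" candidates each
grid <- tibble(pre = rep(1:2, each = 3), mod = rep(1:3, times = 2))

grid %>%
  group_nest(pre, keep = TRUE) %>%   # one row per group; rows stored in the `data` list-column
  mutate(
    labels = map(data, ~ tibble(label = paste0("Model", seq_len(nrow(.x)))))
  ) %>%
  select(-pre) %>%                   # `pre` is still inside `data` because keep = TRUE
  unnest(cols = c(data, labels))     # back to one row per candidate, now carrying its label

Each group computes its labels against only its own rows, which is the same shape of computation compute_grid_info() now performs per preprocessor.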
2 changes: 1 addition & 1 deletion R/0_imports.R
@@ -47,7 +47,7 @@ utils::globalVariables(
"rowwise", ".best", "location", "msg", "..object", ".eval_time",
".pred_survival", ".pred_time", ".weight_censored", "nice_time",
"time_metric", ".lower", ".upper", "i", "results", "term", ".alpha",
".method", "old_term"
".method", "old_term", ".lab_pre", ".model", ".num_models"
)
)

79 changes: 50 additions & 29 deletions R/grid_helpers.R
@@ -318,50 +318,71 @@ compute_grid_info <- function(workflow, grid) {

res <- min_grid(extract_spec_parsnip(workflow), grid)

syms_pre <- rlang::syms(parameters_preprocessor$id)
syms_mod <- rlang::syms(parameters_model$id)

# ----------------------------------------------------------------------------
# Create an order of execution to train the preprocessor (if any). This will
# define a loop over any preprocessing tuning parameter combinations.
if (any_parameters_preprocessor) {
res$.iter_preprocessor <- seq_len(nrow(res))
pp_df <-
dplyr::distinct(res, !!!syms_pre) %>%
dplyr::arrange(!!!syms_pre) %>%
dplyr::mutate(
.iter_preprocessor = dplyr::row_number(),
.lab_pre = recipes::names0(max(dplyr::n()), "Preprocessor")
)
res <-
dplyr::full_join(res, pp_df, by = parameters_preprocessor$id) %>%
dplyr::arrange(.iter_preprocessor)
} else {
res$.iter_preprocessor <- 1L
res$.lab_pre <- "Preprocessor1"
}

# Make the label shown in the grid and in logging
res$.msg_preprocessor <-
new_msgs_preprocessor(
seq_len(max(res$.iter_preprocessor)),
res$.iter_preprocessor,
max(res$.iter_preprocessor)
)

if (nrow(res) != nrow(grid) ||
(any_parameters_model && !any_parameters_preprocessor)) {
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id]))
} else {
res$.iter_model <- 1L
}

res$.iter_config <- list(list())
for (row in seq_len(nrow(res))) {
res$.iter_config[row] <- list(iter_config(res[row, ]))
}
# ----------------------------------------------------------------------------
# Now make a similar iterator across models. Conditioning on each unique
# preprocessing candidate set, make an iterator for the model candidate sets
# (if any)

res <-
res %>%
dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>%
dplyr::mutate(
.iter_config = purrr::map(data, make_iter_config),
.model = purrr::map(data, ~ tibble::tibble(.iter_model = seq_len(nrow(.x)))),
.num_models = purrr::map_int(.model, nrow)
) %>%
dplyr::select(-.iter_preprocessor) %>%
tidyr::unnest(cols = c(data, .model, .iter_config)) %>%
dplyr::select(-.lab_pre) %>%
dplyr::relocate(dplyr::starts_with(".iter"))

res$.msg_model <-
new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor)
new_msgs_model(i = res$.iter_model,
n = res$.num_models,
res$.msg_preprocessor)

res
res %>%
dplyr::select(-.num_models) %>%
dplyr::relocate(dplyr::starts_with(".msg"))
}

iter_config <- function(res_row) {
submodels <- res_row$.submodels[[1]]
if (identical(submodels, list())) {
models <- res_row$.iter_model
} else {
models <- seq_len(length(submodels[[1]]) + 1)
}

paste0(
"Preprocessor",
res_row$.iter_preprocessor,
"_Model",
format_with_padding(models)
)
make_iter_config <- function(dat) {
# Compute labels for the models *within* each preprocessing loop.
num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x)))
num_models <- sum(num_submodels + 1) # +1 for the model being trained
.mod_label <- recipes::names0(num_models, "Model")
.iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_")
.iter_config <- vctrs::vec_chop(.iter_config, sizes = num_submodels + 1)
tibble::tibble(.iter_config = .iter_config)
}

# This generates a "dummy" grid_info object that has the same
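The new make_iter_config() numbers every candidate within one preprocessor, counting each trained model plus the submodels it carries, then splits the labels back into one group per trained model. A standalone sketch with invented submodel counts (three trained models carrying 3, 0, and 1 submodels):

num_submodels <- c(3L, 0L, 1L)        # submodels carried by each trained model (invented)
num_models <- sum(num_submodels + 1)  # 7 candidates in total: each trained model plus its submodels
mod_labels <- recipes::names0(num_models, "Model")       # "Model1" .. "Model7"; zero-padded once there are 10+
labels <- paste("Preprocessor1", mod_labels, sep = "_")  # "Preprocessor1_Model1" ..
vctrs::vec_chop(labels, sizes = num_submodels + 1)       # list of 3 groups: 4 labels, then 1, then 2

This is why the updated test below expects labels such as "Preprocessor1_Model01" through "Preprocessor1_Model10" when a single preprocessor owns ten candidates.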
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/bayes.md
@@ -393,12 +393,12 @@
Message
x Fold1: preprocessor 1/1:
Error in `step_spline_b()`:
Caused by error in `spline_msg()`:
! Error in if (df < 0) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
x Fold2: preprocessor 1/1:
Error in `step_spline_b()`:
Caused by error in `spline_msg()`:
! Error in if (df < 0) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
Condition
Warning:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -415,10 +415,10 @@
Message
x Fold1: preprocessor 1/1:
Error in `get_all_predictors()`:
! The following predictors were not found in `data`: 'z'.
! The following predictor was not found in `data`: "z".
x Fold2: preprocessor 1/1:
Error in `get_all_predictors()`:
! The following predictors were not found in `data`: 'z'.
! The following predictor was not found in `data`: "z".
Condition
Warning:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/checks.md
@@ -92,6 +92,22 @@
Error in `tune:::check_workflow()`:
! A parsnip model is required.

# errors informatively when needed package isn't installed

Code
check_workflow(stan_wflow)
Condition
Error:
! Package install is required for rstanarm.

---

Code
fit_resamples(stan_wflow, rsample::bootstraps(mtcars))
Condition
Error in `fit_resamples()`:
! Package install is required for rstanarm.

# workflow objects (will not tune, tidymodels/tune#548)

Code
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/grid.md
@@ -8,12 +8,12 @@
Message
x Fold1: preprocessor 1/1:
Error in `step_spline_b()`:
Caused by error in `spline_msg()`:
! Error in if (df < 0) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
x Fold2: preprocessor 1/1:
Error in `step_spline_b()`:
Caused by error in `spline_msg()`:
! Error in if (df < 0) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
Condition
Warning:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -28,10 +28,10 @@
Message
x Fold1: preprocessor 1/1:
Error in `get_all_predictors()`:
! The following predictors were not found in `data`: 'z'.
! The following predictor was not found in `data`: "z".
x Fold2: preprocessor 1/1:
Error in `get_all_predictors()`:
! The following predictors were not found in `data`: 'z'.
! The following predictor was not found in `data`: "z".
Condition
Warning:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
10 changes: 5 additions & 5 deletions tests/testthat/_snaps/resample.md
@@ -5,12 +5,12 @@
Message
x Fold1: preprocessor 1/1:
Error in `step_spline_natural()`:
Caused by error in `spline_msg()`:
! Error in if (df < 2) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
x Fold2: preprocessor 1/1:
Error in `step_spline_natural()`:
Caused by error in `spline_msg()`:
! Error in if (df < 2) : missing value where TRUE/FALSE needed
Caused by error in `prep()`:
! `deg_free` must be a whole number, not a numeric `NA`.
Condition
Warning:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -20,7 +20,7 @@
Code
note
Output
[1] "Error in `step_spline_natural()`:\nCaused by error in `spline_msg()`:\n! Error in if (df < 2) { : missing value where TRUE/FALSE needed"
[1] "Error in `step_spline_natural()`:\nCaused by error in `prep()`:\n! `deg_free` must be a whole number, not a numeric `NA`."

# failure in variables tidyselect specification is caught elegantly

36 changes: 26 additions & 10 deletions tests/testthat/test-grid_helpers.R
@@ -15,7 +15,7 @@ test_that("compute_grid_info - recipe only", {

expect_equal(res$.iter_preprocessor, 1:5)
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
expect_equal(res$deg_free, grid$deg_free)
expect_equal(sort(res$deg_free), sort(grid$deg_free))
expect_equal(res$.iter_model, rep(1, 5))
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
@@ -27,6 +27,7 @@
ignore.order = TRUE
)
expect_equal(nrow(res), 5)
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
})

test_that("compute_grid_info - model only (no submodels)", {
@@ -57,6 +58,7 @@ test_that("compute_grid_info - model only (no submodels)", {
ignore.order = TRUE
)
expect_equal(nrow(res), 5)
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
})

test_that("compute_grid_info - model only (with submodels)", {
@@ -107,8 +109,8 @@ test_that("compute_grid_info - recipe and model (no submodels)", {

expect_equal(res$.iter_preprocessor, 1:5)
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
expect_equal(res$learn_rate, grid$learn_rate)
expect_equal(res$deg_free, grid$deg_free)
expect_equal(sort(res$learn_rate), sort(grid$learn_rate))
expect_equal(sort(res$deg_free), sort(grid$deg_free))
expect_equal(res$.iter_model, rep(1, 5))
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
@@ -120,6 +122,7 @@
ignore.order = TRUE
)
expect_equal(nrow(res), 5)
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
})

test_that("compute_grid_info - recipe and model (with submodels)", {
@@ -169,6 +172,7 @@ test_that("compute_grid_info - recipe and model (with submodels)", {
)
expect_equal(nrow(res), 3)
})

test_that("compute_grid_info - recipe and model (with and without submodels)", {
library(workflows)
library(parsnip)
@@ -185,25 +189,30 @@
# use grid_regular to (partially) trigger submodel trick
set.seed(1)
param_set <- extract_parameter_set_dials(wflow)
grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set))
grid <-
bind_rows(grid_regular(param_set), grid_space_filling(param_set)) %>%
arrange(deg_free, loss_reduction, trees)
res <- compute_grid_info(wflow, grid)

expect_equal(length(unique(res$.iter_preprocessor)), 5)
expect_equal(
unique(res$.msg_preprocessor),
paste0("preprocessor ", 1:5, "/5")
)
expect_equal(res$trees, c(rep(max(grid$trees), 10), 1))
expect_equal(sort(res$trees), sort(c(rep(max(grid$trees), 10), 1)))
expect_equal(unique(res$.iter_model), 1:3)
expect_equal(
res$.iter_config[1:3],
res$.iter_config[res$.iter_preprocessor == 1],
list(
c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"),
c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"),
c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3")
c("Preprocessor1_Model01", "Preprocessor1_Model02", "Preprocessor1_Model03", "Preprocessor1_Model04"),
c("Preprocessor1_Model05", "Preprocessor1_Model06", "Preprocessor1_Model07"),
c("Preprocessor1_Model08", "Preprocessor1_Model09", "Preprocessor1_Model10")
)
)
expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/3"))
expect_equal(
res$.msg_model[res$.iter_preprocessor == 1],
paste0("preprocessor 1/5, model ", 1:3, "/3")
)
expect_equal(
res$.submodels[1:3],
list(
@@ -212,6 +221,12 @@ test_that("compute_grid_info - recipe and model (with and without submodels)", {
list(trees = c(1L, 1000L))
)
)
expect_equal(
res %>%
mutate(num_models = purrr::map_int(.iter_config, length)) %>%
summarize(n = sum(num_models), .by = c(deg_free)),
grid %>% count(deg_free)
)
expect_named(
res,
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees",
@@ -325,4 +340,5 @@ test_that("compute_grid_info - recipe and model (no submodels but has inner grid
ignore.order = TRUE
)
expect_equal(nrow(res), 9)
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
})
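
A rough way to sanity-check a compute_grid_info() result against its grid, mirroring the check added in the "with and without submodels" test above (column names are taken from these tests; `res` and `grid` stand for a result and the grid it came from):

library(dplyr)
library(purrr)

res %>%
  mutate(n_configs = map_int(.iter_config, length)) %>%     # labels per row: the trained model plus its submodels
  summarize(n = sum(n_configs), .by = .iter_preprocessor)   # total candidates claimed by each preprocessor

The per-preprocessor totals should add up to the number of rows in the original grid.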
