Skip to content

Commit

Permalink
fixes for #975
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Dec 19, 2024
1 parent 2fb8701 commit f6da85f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 64 deletions.
15 changes: 13 additions & 2 deletions R/schedule.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,22 @@ get_tune_schedule <- function(wflow, param, grid) {
if (has_submodels) {
sched <- grid %>%
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage")
# Note: multi_predict() should only be triggered for a submodel parameter if
# Note 1: multi_predict() should only be triggered for a submodel parameter if
# there are multiple rows in the `predict_stage` list column. i.e. the submodel
# column will always be there but we only multipredict when there are 2+
# values to predict.
first_loop_info <- min_grid(model_spec, grid)

# Note 2: The purpose of min_grid() is to determine the minimum grid for
# preprocessing and model parameters to fit. We compute it here and ignore
# any postprocessing tuning parameters (if any). The postprocessing parameters
# will still be in the schedule since we schedule those before the results
# that use min_grid() are merged in. See issue #975 for an example and
# discussion.
first_loop_info <-
min_grid(model_spec,
grid %>%
dplyr::select(-dplyr::any_of(post_id)) %>%
dplyr::distinct())
} else {
sched <- grid %>%
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage")
Expand Down
117 changes: 55 additions & 62 deletions tests/testthat/test-schedule.R
Original file line number Diff line number Diff line change
Expand Up @@ -645,75 +645,68 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu
prm_used_pre_model_post,
grid_pre_model_post)

# TODO trees seems to have an extra row:

# # A tibble: 4 × 3
# min_n predict_stage trees
# <int> <list> <int>
# 1 2 <tibble [1 × 2]> 1
# 2 21 <tibble [1 × 2]> 1
# 3 40 <tibble [2 × 2]> 1000 #<- shouldn't this row and the one below be combined?
# 4 40 <tibble [2 × 2]> 1

# sched_pre_model_post$model_stage[[1]] %>%
# select(-trees) %>%
# unnest(predict_stage) %>%
# unnest(post_stage) %>%
# arrange(min_n, trees, lower_limit)

# tibble::tribble(
# ~min_n, ~trees, ~lower_limit, ~trees0,
# 2L, 1L, 0, 1L,
# 21L, 1L, 0.5, 1L,
# 40L, 1L, 1, 1000L,
# 40L, 1L, 1, 1L,
# 40L, 1000L, 0, 1000L,
# 40L, 1000L, 0, 1L
# )

expect_named(sched_pre_model_post, c("threshold", "disp_df", "model_stage"))
expect_equal(
sched_pre_model_post %>% select(-model_stage) %>% as_tibble(),
grid_pre %>% arrange(threshold, disp_df)
)

# for (i in seq_along(sched_pre_model_post$model_stage)) {
# model_i <- sched_pre_model_post$model_stage[[i]]
# expect_named(model_i, c("min_n", "predict_stage", "trees"))
# expect_equal(
# model_i %>% select(min_n, trees) %>% arrange(min_n),
# grid_model$data[[i]]
# )
#
# for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) {
# predict_j <- model_i$predict_stage[[j]]
#
# # We need to figure out the trees that need predicting for the current
# # set of other parameters.
#
# # Get the settings that have already been resolved:
# other_ij <-
# model_i %>%
# select(-predict_stage, -trees) %>%
# slice(j) %>%
# vctrs::vec_cbind(
# sched_pre_model_post %>%
# select(threshold, disp_df) %>%
# slice(i)
# )
# # What are the matching values from the grid?
# trees_ij <-
# grid_pre_model_post %>%
# inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) %>%
# select(trees)
#
#
# expect_equal(
# predict_j %>% select(trees) %>% arrange(trees),
# trees_ij %>% arrange(trees)
# )
# }
# }
for (i in seq_along(sched_pre_model_post$model_stage)) {
model_i <- sched_pre_model_post$model_stage[[i]]

# Get the current set of preproc parameters to remove
other_i <-
sched_pre_model_post[i,] %>%
dplyr::select(-model_stage)

# We expect to evaluate these specific models for this set of preprocessors
exp_i <-
grid_pre_model_post %>%
inner_join(other_i, by = c("threshold", "disp_df")) %>%
arrange(trees, min_n, lower_limit) %>%
select(trees, min_n, lower_limit)

# What we will evaluate:
subgrid_i <-
model_i %>%
select(-trees) %>%
unnest(predict_stage) %>%
unnest(post_stage) %>%
arrange(trees, min_n, lower_limit) %>%
select(trees, min_n, lower_limit)

expect_equal(subgrid_i, exp_i)

# for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) {
# predict_j <- model_i$predict_stage[[j]]
#
# # We need to figure out the trees that need predicting for the current
# # set of other parameters.
#
# # Get the settings that have already been resolved:
# other_ij <-
# model_i %>%
# select(-predict_stage, -trees) %>%
# slice(j) %>%
# vctrs::vec_cbind(
# sched_pre_model_post %>%
# select(threshold, disp_df) %>%
# slice(i)
# )
# # What are the matching values from the grid?
# trees_ij <-
# grid_pre_model_post %>%
# inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) %>%
# select(trees)
#
#
# expect_equal(
# predict_j %>% select(trees) %>% arrange(trees),
# trees_ij %>% arrange(trees)
# )
# }
}

expect_s3_class(
sched_pre_model_post,
Expand Down

0 comments on commit f6da85f

Please sign in to comment.