diff --git a/R/grid_code_paths.R b/R/grid_code_paths.R index 39971fb7..505d6ccc 100644 --- a/R/grid_code_paths.R +++ b/R/grid_code_paths.R @@ -384,7 +384,7 @@ tune_grid_loop_iter <- function(split, assessment_rows <- as.integer(split, data = "assessment") assessment <- vctrs::vec_slice(split$data, assessment_rows) - if (workflows::.should_inner_split(workflow)) { + if (workflows::.workflow_includes_calibration(workflow)) { # if the workflow has a postprocessor that needs training (i.e. calibration), # further split the analysis data into an "inner" analysis and # assessment set. @@ -397,11 +397,6 @@ tune_grid_loop_iter <- function(split, # calibration set # * the model (including the post-processor) generates predictions on the # assessment set and those predictions are assessed with performance metrics - # todo: check if workflow's `method` is incompatible with `class(split)`? - # todo: workflow's `method` is currently ignored in favor of the one - # automatically dispatched to from `split`. consider this is combination - # with above todo. - split_args <- c(split_args, list(prop = workflow$post$actions$tailor$prop)) split <- rsample::inner_split(split, split_args = split_args) analysis <- rsample::analysis(split) diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R index 6635b5f5..3fcde31d 100644 --- a/tests/testthat/test-last-fit.R +++ b/tests/testthat/test-last-fit.R @@ -246,9 +246,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini parsnip::linear_reg() ) %>% workflows::add_tailor( - tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"), - prop = 2/3, - method = class(split) + tailor::tailor() %>% tailor::adjust_numeric_calibration("linear") ) set.seed(1) @@ -261,13 +259,21 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini last_fit_preds <- collect_predictions(last_fit_res) set.seed(1) - wflow_res <- generics::fit(wflow, rsample::analysis(split)) + inner_split <- rsample::inner_split(split, split_args = list()) + + set.seed(1) + wflow_res <- + generics::fit( + wflow, + rsample::analysis(inner_split), + calibration = rsample::assessment(inner_split) + ) wflow_preds <- predict(wflow_res, rsample::assessment(split)) expect_equal(last_fit_preds[".pred"], wflow_preds) }) -test_that("can use `last_fit()` with a workflow - postprocessor (requires training)", { +test_that("can use `last_fit()` with a workflow - postprocessor (does not require training)", { skip_if_not_installed("tailor") y <- seq(0, 7, .001) @@ -284,9 +290,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini parsnip::linear_reg() ) %>% workflows::add_tailor( - tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1), - prop = 2/3, - method = class(split) + tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1) ) set.seed(1) diff --git a/tests/testthat/test-resample.R b/tests/testthat/test-resample.R index a1025159..6c4331eb 100644 --- a/tests/testthat/test-resample.R +++ b/tests/testthat/test-resample.R @@ -151,9 +151,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t parsnip::linear_reg() ) %>% workflows::add_tailor( - tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"), - prop = 2/3, - method = class(folds$splits[[1]]) + tailor::tailor() %>% tailor::adjust_numeric_calibration("linear") ) set.seed(1) @@ -178,8 +176,20 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t seed <- generate_seeds(TRUE, 1)[[1]] old_kind <- RNGkind()[[1]] assign(".Random.seed", seed, envir = globalenv()) + withr::defer(RNGkind(kind = old_kind)) - wflow_res <- generics::fit(wflow, rsample::analysis(folds$splits[[1]])) + inner_split_1 <- + rsample::inner_split( + folds$splits[[1]], + split_args = list(v = 2, repeats = 1, breaks = 4, pool = 0.1) + ) + + wflow_res <- + generics::fit( + wflow, + rsample::analysis(inner_split_1), + calibration = rsample::assessment(inner_split_1) + ) wflow_preds <- predict(wflow_res, rsample::assessment(folds$splits[[1]])) tune_wflow$fit$fit$elapsed$elapsed <- wflow_res$fit$fit$elapsed$elapsed @@ -201,7 +211,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (no trainin parsnip::linear_reg() ) %>% workflows::add_tailor( - tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1) + tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1) ) set.seed(1)