Skip to content

Commit

Permalink
test parameter set extraction with postprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Oct 16, 2024
1 parent b20ae66 commit d3417d2
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/testthat/_snaps/extract.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,12 @@
Error in `extract_recipe()`:
! The workflow must have a recipe preprocessor.

# extract parameter set from workflow with potentially conflicting ids (#266)

Code
extract_parameter_set_dials(wflow)
Condition
Error in `extract_parameter_set_dials()`:
x Element id should have unique values.
i Duplicates exist for item: threshold

80 changes: 80 additions & 0 deletions tests/testthat/test-extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,23 @@ test_that("extract parameter set from workflow with tunable model", {
expect_true(all(wf_info$source == "model_spec"))
})

test_that("extract parameter set from workflow with tunable postprocessor", {
wflow <- workflow()
wflow <- add_recipe(wflow, recipes::recipe(mpg ~ ., mtcars))
wflow <- add_model(wflow, parsnip::linear_reg())
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = hardhat::tune())
)

wflow_info <- extract_parameter_set_dials(wflow)

check_parameter_set_tibble(wflow_info)
expect_equal(nrow(wflow_info), 1)
expect_true(all(wflow_info$source == "tailor"))
})

test_that("extract parameter set from workflow with tunable recipe and model", {

spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
Expand All @@ -252,6 +269,69 @@ test_that("extract parameter set from workflow with tunable recipe and model", {
)
})

test_that("extract parameter set from workflow with tunable recipe, model, and tailor", {
wflow <- workflow()
wflow <- add_recipe(
wflow,
recipes::recipe(mpg ~ ., mtcars) %>%
recipes::step_impute_knn(
recipes::all_predictors(),
neighbors = hardhat::tune("imputation")
)
)
wflow <- add_model(
wflow,
parsnip::linear_reg(engine = "glmnet", penalty = tune())
)
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = hardhat::tune())
)

wflow_info <- extract_parameter_set_dials(wflow)

check_parameter_set_tibble(wflow_info)
expect_equal(nrow(wflow_info), 3)
expect_true(all(wflow_info$source %in% c("recipe", "model_spec", "tailor")))
})

test_that("extract parameter set from workflow with potentially conflicting ids (#266)", {
# re: https://github.com/tidymodels/workflows/pull/266#issuecomment-2417772184
# specifically concerned that duplicated "threshold" parameters result in
# an informative error
wflow <- workflow()
wflow <- add_recipe(
wflow,
recipes::recipe(mpg ~ ., mtcars) %>%
recipes::step_pca(recipes::all_predictors(), threshold = hardhat::tune())
)
wflow <- add_model(wflow, parsnip::linear_reg())
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_probability_threshold(threshold = hardhat::tune())
)

expect_snapshot(
error = TRUE,
extract_parameter_set_dials(wflow)
)

# ensure that the user can actually do something about it
wflow <- remove_tailor(wflow)
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_probability_threshold(threshold = hardhat::tune("unique id"))
)

wflow_info <- extract_parameter_set_dials(wflow)

check_parameter_set_tibble(wflow_info)
expect_equal(nrow(wflow_info), 2)
expect_true(all(wflow_info$source %in% c("recipe", "tailor")))
})

# ------------------------------------------------------------------------------
# extract_parameter_dials()
Expand Down

0 comments on commit d3417d2

Please sign in to comment.