use grid_space_filling() instead of grid_latin_hypercube() #919

Merged · 4 commits · Aug 2, 2024
8 changes: 4 additions & 4 deletions DESCRIPTION
@@ -18,7 +18,7 @@ Depends:
R (>= 4.0)
Imports:
cli (>= 3.3.0),
- dials (>= 1.0.0),
+ dials (>= 1.3.0),
doFuture (>= 1.0.0),
dplyr (>= 1.1.0),
foreach,
@@ -33,13 +33,13 @@ Imports:
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
- rsample (>= 1.2.0),
+ rsample (>= 1.2.1.9000),
tibble (>= 3.1.0),
tidyr (>= 1.2.0),
tidyselect (>= 1.1.2),
vctrs (>= 0.6.1),
withr,
- workflows (>= 1.1.4),
+ workflows (>= 1.1.4.9000),
yardstick (>= 1.3.0)
Suggests:
C50,
@@ -66,4 +66,4 @@ Encoding: UTF-8
Language: en-US
LazyData: true
Roxygen: list(markdown = TRUE)
- RoxygenNote: 7.3.1
+ RoxygenNote: 7.3.2
2 changes: 2 additions & 0 deletions NEWS.md
@@ -4,6 +4,8 @@

* The package will now log a backtrace for errors and warnings that occur during tuning. When a tuning process encounters issues, see the new `trace` column in the `collect_notes(.Last.tune.result)` output to find precisely where the error occurred (#873).

+ * When automatic grids are used, `dials::grid_space_filling()` is now used instead of `dials::grid_latin_hypercube()`. The new function produces pre-optimized designs that do not depend on random numbers. For Bayesian search, a Latin hypercube is still used: we generate 5,000 candidate points, which would be too slow with pre-optimized designs.

# tune 1.2.1

* Addressed issue in `int_pctl()` where the function would error when parallelized using `makePSOCKcluster()` (#885).
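The behavioral difference described in the NEWS entry can be sketched as follows — a minimal illustration, assuming dials >= 1.3.0; the `penalty()`/`mixture()` parameters are placeholders for illustration, not taken from this PR:

```r
library(dials)

# Two hypothetical tuning parameters (placeholders)
params <- parameters(penalty(), mixture())

# Old default: a random Latin hypercube, so the grid depends on the RNG seed
set.seed(1)
old_grid <- grid_latin_hypercube(params, size = 10)

# New default: a pre-optimized space-filling design; no RNG involved,
# so repeated calls return the same candidate points
new_grid <- grid_space_filling(params, size = 10)
```

Because the pre-optimized designs are deterministic, automatic grids built this way are reproducible without setting a seed.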
2 changes: 1 addition & 1 deletion R/checks.R
@@ -116,7 +116,7 @@ check_grid <- function(grid, workflow, pset = NULL, call = caller_env()) {
}
check_workflow(workflow, pset = pset, check_dials = TRUE, call = call)

- grid <- dials::grid_latin_hypercube(pset, size = grid)
+ grid <- dials::grid_space_filling(pset, size = grid)
grid <- dplyr::distinct(grid)
}

4 changes: 2 additions & 2 deletions R/tune_bayes.R
@@ -531,7 +531,7 @@ create_initial_set <- function(param, n = NULL, checks) {
if (any(checks == "bayes")) {
check_bayes_initial_size(nrow(param), n)
}
- dials::grid_latin_hypercube(param, size = n)
+ dials::grid_space_filling(param, size = n)
}

check_iter <- function(iter, call) {
@@ -632,7 +632,7 @@ fit_gp <- function(dat, pset, metric, eval_time = NULL, control, ...) {

pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
pred_grid <-
- dials::grid_latin_hypercube(pset, size = size) %>%
+ dials::grid_space_filling(pset, size = size, type = "latin_hypercube") %>%
Member Author commented:
I stuck with an LHD here because the new optimized designs take a really long time to compute for a really large grid (5,000 points).

dplyr::distinct()

if (!is.null(current)) {
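The comment above explains why `pred_gp()` keeps a Latin hypercube: `grid_space_filling()` exposes a `type` argument, so the fast RNG-based design can still be requested for very large candidate sets. A minimal sketch, with placeholder parameters:

```r
library(dials)

# Hypothetical parameter set (placeholders)
params <- parameters(penalty(), mixture())

# Request the Latin hypercube explicitly through the new function:
# pre-optimized designs would be too slow at this size (5,000 points)
candidates <- grid_space_filling(
  params,
  size = 5000,
  type = "latin_hypercube"
)
```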
5 changes: 2 additions & 3 deletions R/tune_grid.R
@@ -58,9 +58,8 @@
#'
#' @section Parameter Grids:
#'
- #' If no tuning grid is provided, a semi-random grid (via
- #' [dials::grid_latin_hypercube()]) is created with 10 candidate parameter
- #' combinations.
+ #' If no tuning grid is provided, a grid (via [dials::grid_space_filling()]) is
+ #' created with 10 candidate parameter combinations.
#'
#' When provided, the grid should have column names for each parameter and
#' these should be named by the parameter name or `id`. For example, if a
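For context on the documentation change above: passing an integer `grid` to `tune_grid()` is what triggers the automatic grid. A sketch under standard tidymodels assumptions (the SVM spec mirrors the test code in this PR, not this exact docs example):

```r
library(tune)
library(parsnip)
library(rsample)

spec <- parsnip::svm_rbf(cost = tune()) %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("regression")

# grid = 10 -> tune builds the 10-point grid internally,
# now via dials::grid_space_filling()
res <- tune_grid(
  spec,
  mpg ~ .,
  resamples = rsample::vfold_cv(mtcars, v = 5),
  grid = 10
)
```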
3 changes: 2 additions & 1 deletion inst/WORDLIST
@@ -6,7 +6,6 @@ Codecov
Davison
Disambiguates
EI
- foreach
Hinkley
Isomap
Lifecycle
@@ -15,10 +14,12 @@ Olshen
PSOCK
RNGkind
Wadsworth
+ backtrace
cdot
doi
el
finetune
+ foreach
frac
geo
ggplot
5 changes: 2 additions & 3 deletions man/tune_grid.Rd

Some generated files are not rendered by default.

40 changes: 22 additions & 18 deletions tests/testthat/_snaps/bayes.md
@@ -170,29 +170,29 @@

-- Iteration 1 -----------------------------------------------------------------

- i Current best: rmse=2.453 (@iter 0)
+ i Current best: rmse=2.461 (@iter 0)
i Gaussian process model
! The Gaussian process model is being fit using 1 features but only has 2
data points to do so. This may cause errors or a poor model fit.
v Gaussian process model
i Generating 3 candidates
i Predicted candidates
- i num_comp=4
+ i num_comp=5
i Estimating performance
v Estimating performance
- (x) Newest results: rmse=2.461 (+/-0.37)
+ <3 Newest results: rmse=2.453 (+/-0.381)

-- Iteration 2 -----------------------------------------------------------------

- i Current best: rmse=2.453 (@iter 0)
+ i Current best: rmse=2.453 (@iter 1)
i Gaussian process model
v Gaussian process model
i Generating 2 candidates
i Predicted candidates
- i num_comp=3
+ i num_comp=1
i Estimating performance
v Estimating performance
- <3 Newest results: rmse=2.418 (+/-0.357)
+ (x) Newest results: rmse=2.646 (+/-0.286)
Output
# Tuning results
# 10-fold cross-validation
@@ -225,14 +225,14 @@

-- Iteration 1 -----------------------------------------------------------------

- i Current best: rmse=2.453 (@iter 0)
+ i Current best: rmse=2.461 (@iter 0)
i Gaussian process model
! The Gaussian process model is being fit using 1 features but only has 2
data points to do so. This may cause errors or a poor model fit.
v Gaussian process model
i Generating 3 candidates
i Predicted candidates
- i num_comp=4
+ i num_comp=5
i Estimating performance
i Fold01: preprocessor 1/1
v Fold01: preprocessor 1/1
@@ -295,16 +295,16 @@
i Fold10: preprocessor 1/1, model 1/1 (extracts)
i Fold10: preprocessor 1/1, model 1/1 (predictions)
v Estimating performance
- (x) Newest results: rmse=2.461 (+/-0.37)
+ <3 Newest results: rmse=2.453 (+/-0.381)

-- Iteration 2 -----------------------------------------------------------------

- i Current best: rmse=2.453 (@iter 0)
+ i Current best: rmse=2.453 (@iter 1)
i Gaussian process model
v Gaussian process model
i Generating 2 candidates
i Predicted candidates
- i num_comp=3
+ i num_comp=1
i Estimating performance
i Fold01: preprocessor 1/1
v Fold01: preprocessor 1/1
@@ -367,7 +367,7 @@
i Fold10: preprocessor 1/1, model 1/1 (extracts)
i Fold10: preprocessor 1/1, model 1/1 (predictions)
v Estimating performance
- <3 Newest results: rmse=2.418 (+/-0.357)
+ (x) Newest results: rmse=2.646 (+/-0.286)
Output
# Tuning results
# 10-fold cross-validation
@@ -523,12 +523,6 @@
data points to do so. This may cause errors or a poor model fit.
! For the rsq estimates, 1 missing value was found and removed before fitting
the Gaussian process model.
- ! For the rsq estimates, 1 missing value was found and removed before fitting
- the Gaussian process model.
- ! For the rsq estimates, 1 missing value was found and removed before fitting
- the Gaussian process model.
- ! For the rsq estimates, 1 missing value was found and removed before fitting
- the Gaussian process model.
! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
! For the rsq estimates, 2 missing values were found and removed before
fitting the Gaussian process model.
@@ -545,6 +539,16 @@
! For the rsq estimates, 6 missing values were found and removed before
fitting the Gaussian process model.
! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+ ! For the rsq estimates, 7 missing values were found and removed before
+ fitting the Gaussian process model.
+ ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+ ! For the rsq estimates, 8 missing values were found and removed before
+ fitting the Gaussian process model.
+ ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+ ! For the rsq estimates, 9 missing values were found and removed before
+ fitting the Gaussian process model.
+ ! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
! No improvement for 10 iterations; returning current results.

---

Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/fit_best.md
@@ -4,8 +4,8 @@
fit_best(knn_pca_res, verbose = TRUE)
Output
Using rmse as the metric, the optimal parameters were:
- neighbors: 10
- num_comp: 3
+ neighbors: 1
+ num_comp: 4

Message
i Fitting using 161 data points...
@@ -23,13 +23,13 @@
-- Model -----------------------------------------------------------------------

Call:
- kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(10L, data, 5))
+ kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(1L, data, 5))

Type of response variable: continuous
- minimal mean absolute error: 1.690086
- Minimal mean squared error: 4.571625
+ minimal mean absolute error: 1.015528
+ Minimal mean squared error: 2.448261
Best kernel: optimal
- Best k: 10
+ Best k: 1

---

20 changes: 10 additions & 10 deletions tests/testthat/_snaps/grid.md
@@ -66,14 +66,14 @@
# A tibble: 10 x 4
splits id .metrics .notes
<list> <chr> <list> <list>
- 1 <split [28/4]> Fold01 <tibble [4 x 5]> <tibble [0 x 4]>
- 2 <split [28/4]> Fold02 <tibble [4 x 5]> <tibble [0 x 4]>
- 3 <split [29/3]> Fold03 <tibble [4 x 5]> <tibble [0 x 4]>
- 4 <split [29/3]> Fold04 <tibble [4 x 5]> <tibble [0 x 4]>
- 5 <split [29/3]> Fold05 <tibble [4 x 5]> <tibble [0 x 4]>
- 6 <split [29/3]> Fold06 <tibble [4 x 5]> <tibble [0 x 4]>
- 7 <split [29/3]> Fold07 <tibble [4 x 5]> <tibble [0 x 4]>
- 8 <split [29/3]> Fold08 <tibble [4 x 5]> <tibble [0 x 4]>
- 9 <split [29/3]> Fold09 <tibble [4 x 5]> <tibble [0 x 4]>
- 10 <split [29/3]> Fold10 <tibble [4 x 5]> <tibble [0 x 4]>
+ 1 <split [28/4]> Fold01 <tibble [6 x 5]> <tibble [0 x 4]>
+ 2 <split [28/4]> Fold02 <tibble [6 x 5]> <tibble [0 x 4]>
+ 3 <split [29/3]> Fold03 <tibble [6 x 5]> <tibble [0 x 4]>
+ 4 <split [29/3]> Fold04 <tibble [6 x 5]> <tibble [0 x 4]>
+ 5 <split [29/3]> Fold05 <tibble [6 x 5]> <tibble [0 x 4]>
+ 6 <split [29/3]> Fold06 <tibble [6 x 5]> <tibble [0 x 4]>
+ 7 <split [29/3]> Fold07 <tibble [6 x 5]> <tibble [0 x 4]>
+ 8 <split [29/3]> Fold08 <tibble [6 x 5]> <tibble [0 x 4]>
+ 9 <split [29/3]> Fold09 <tibble [6 x 5]> <tibble [0 x 4]>
+ 10 <split [29/3]> Fold10 <tibble [6 x 5]> <tibble [0 x 4]>
Contributor commented:

I'm surprised that this snap is changing. What's the explanation here?

Member Author replied:

The "6" is because we now have three standard metrics (brier was added), and the values also differ because the initial grids changed.

Contributor replied:

I understand that we have 3 default classification metrics now, but why wouldn't that snapshot have failed before this PR? What do the grids have to do with this change?


17 changes: 13 additions & 4 deletions tests/testthat/test-autoplot.R
@@ -329,12 +329,21 @@ test_that("plot_perf_vs_iter with fairness metrics (#773)", {

test_that("regular grid plot", {
skip_if_not_installed("ggplot2", minimum_version = "3.5.0")
topepo marked this conversation as resolved.
set.seed(1)
res <-

svm_spec <-
parsnip::svm_rbf(cost = tune()) %>%
parsnip::set_engine("kernlab") %>%
parsnip::set_mode("regression") %>%
tune_grid(mpg ~ ., resamples = rsample::vfold_cv(mtcars, v = 5), grid = 1)
parsnip::set_mode("regression")

svm_grid <-
svm_spec %>%
extract_parameter_set_dials() %>%
dials::grid_regular(levels = 1)

set.seed(1)
res <-
svm_spec %>%
tune_grid(mpg ~ ., resamples = rsample::vfold_cv(mtcars, v = 5), grid = svm_grid)

expect_snapshot(
error = TRUE,
2 changes: 1 addition & 1 deletion tests/testthat/test-extract.R
@@ -170,7 +170,7 @@ test_that("tune model and recipe", {
grid_3 <-
extract_parameter_set_dials(wflow_3) %>%
update(num_comp = dials::num_comp(c(2, 5))) %>%
- dials::grid_latin_hypercube(size = 4)
+ dials::grid_space_filling(size = 4)

expect_error(
res_3_1 <- tune_grid(