Skip to content

Commit

Permalink
Merge branch 'main' into refactor-eval-time-default
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Dec 4, 2023
2 parents 90d5d8a + 32f739a commit 8448353
Show file tree
Hide file tree
Showing 29 changed files with 105 additions and 38 deletions.
7 changes: 5 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9000
Version: 1.1.2.9001
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
Expand Down Expand Up @@ -38,7 +38,7 @@ Imports:
vctrs (>= 0.6.1),
withr,
workflows (>= 1.0.0),
yardstick (>= 1.2.0)
yardstick (>= 1.2.0.9001)
Suggests:
C50,
censored,
Expand All @@ -47,10 +47,13 @@ Suggests:
kknn,
knitr,
modeldata,
scales,
spelling,
testthat (>= 3.0.0),
xgboost,
xml2
Remotes:
tidymodels/yardstick
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
Config/testthat/edition: 3
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ export(.config_key_from_metrics)
export(.estimate_metrics)
export(.get_extra_col_names)
export(.get_fingerprint)
export(.get_tune_eval_times)
export(.get_tune_metric_names)
export(.get_tune_metrics)
export(.get_tune_outcome_names)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

* Improves documentation related to the hyperparameters associated with extracted objects that are generated from submodels. See the "Extracting with submodels" section of `?collect_extracts` to learn more.

* An `eval_time` attribute was added to tune objects. There is also a `.get_tune_eval_times()` function.

* `augment()` methods to `tune_results`, `resample_results`, and `last_fit` objects now always returns tibbles (#759).

# tune 1.1.2
Expand Down
3 changes: 3 additions & 0 deletions R/compat-vctrs-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ new_tune_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info
)
Expand Down Expand Up @@ -105,6 +106,7 @@ new_resample_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info
)
Expand Down Expand Up @@ -134,6 +136,7 @@ new_iteration_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info,
workflow = attrs$workflow
Expand Down
33 changes: 21 additions & 12 deletions R/iteration_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,27 @@
#' @rdname empty_ellipses
#' @param parameters A `parameters` object.
#' @param metrics A metric set.
#' @param eval_time A numeric vector of time points where dynamic event time
#' metrics should be computed (e.g. the time-dependent ROC curve, etc).
#' @param outcomes A character vector of outcome names.
#' @param rset_info Attributes from an `rset` object.
#' @param workflow The workflow used to fit the iteration results.
new_iteration_results <- function(x, parameters, metrics, outcomes = character(0),
rset_info, workflow) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "iteration_results"
)
}
new_iteration_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
workflow) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "iteration_results"
)
}
1 change: 1 addition & 0 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ resample_workflow <- function(workflow, resamples, metrics, control,
x = out,
parameters = attributes$parameters,
metrics = attributes$metrics,
eval_time = attributes$eval_time,
outcomes = attributes$outcomes,
rset_info = attributes$rset_info,
workflow = attributes$workflow
Expand Down
30 changes: 19 additions & 11 deletions R/resample_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@

# ------------------------------------------------------------------------------

new_resample_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, workflow = NULL) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "resample_results"
)
}
new_resample_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
workflow = NULL) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "resample_results"
)
}
2 changes: 2 additions & 0 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ tune_bayes_workflow <-
x = unsummarized,
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = NULL
Expand Down Expand Up @@ -476,6 +477,7 @@ tune_bayes_workflow <-
x = unsummarized,
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow_output
Expand Down
1 change: 1 addition & 0 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ tune_grid_workflow <- function(workflow,
x = resamples,
parameters = pset,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow
Expand Down
31 changes: 20 additions & 11 deletions R/tune_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,26 @@ summarize_notes <- function(x) {

# ------------------------------------------------------------------------------

new_tune_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, ..., class = character()) {
new_bare_tibble(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
...,
class = c(class, "tune_results")
)
}
new_tune_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
...,
class = character()) {
new_bare_tibble(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
...,
class = c(class, "tune_results")
)
}

is_tune_results <- function(x) {
inherits(x, "tune_results")
Expand Down
14 changes: 14 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ new_bare_tibble <- function(x, ..., class = character()) {
res
}


#' @export
#' @rdname tune_accessor
.get_tune_eval_times <- function(x) {
x <- attributes(x)
if (any(names(x) == "eval_time")) {
res <- x$eval_time
} else {
res <- NULL
}
res
}


#' @export
#' @rdname tune_accessor
.get_tune_outcome_names <- function(x) {
Expand Down
1 change: 1 addition & 0 deletions inst/test_objects.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
library(tidymodels)
library(scales)
library(sessioninfo)
library(testthat)

Expand Down
4 changes: 4 additions & 0 deletions man/empty_ellipses.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/tune_accessor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified tests/testthat/data/knn_gp.rds
Binary file not shown.
Binary file modified tests/testthat/data/knn_grid.rds
Binary file not shown.
Binary file modified tests/testthat/data/knn_results.rds
Binary file not shown.
Binary file modified tests/testthat/data/knn_set.rds
Binary file not shown.
Binary file modified tests/testthat/data/lm_bayes.rds
Binary file not shown.
Binary file modified tests/testthat/data/lm_resamples.rds
Binary file not shown.
Binary file modified tests/testthat/data/rcv_results.rds
Binary file not shown.
Binary file modified tests/testthat/data/svm_reg_results.rds
Binary file not shown.
Binary file modified tests/testthat/data/svm_results.rds
Binary file not shown.
Binary file modified tests/testthat/data/test_objects.RData
Binary file not shown.
4 changes: 3 additions & 1 deletion tests/testthat/test-autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ test_that("parameter plot for iterative search", {


test_that("regular grid plot", {
skip_if_not_installed("scales", "1.3.0")
rcv_results <- readRDS(test_path("data", "rcv_results.rds"))
svm_reg_results <- readRDS(test_path("data", "svm_reg_results.rds"))

Expand Down Expand Up @@ -166,6 +167,7 @@ test_that("regular grid plot", {
expect_equal(p$labels$y, "rmse")
expect_equal(p$labels$x, "deg_free")


p <- autoplot(svm_reg_results)
expect_s3_class(p, "ggplot")
expect_equal(
Expand All @@ -183,7 +185,7 @@ test_that("regular grid plot", {
expect_equal(p$labels$x, "Cost")
expect_equal(p$labels$group, "%^*#")

expect_equal(class(p$scales$scales[[1]]$trans), "trans")
expect_true(grepl("^trans", class(p$scales$scales[[1]]$trans)))
expect_equal(p$scales$scales[[1]]$trans$name, "log-2")
expect_equal(unique(p$data$name), "Cost")
})
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ test_that("tune recipe only", {
expect_equal(res_est$n, rep(10, iterT * 2))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
expect_null(.get_tune_eval_times(res))

set.seed(1)
expect_error(
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ test_that("tune recipe only", {
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
expect_null(.get_tune_eval_times(res))
})

# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that("formula method", {
nrow(predict(res$.workflow[[1]], rsample::testing(split))),
nrow(rsample::testing(split))
)

expect_null(.get_tune_eval_times(res))

})

Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test-resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ test_that("`fit_resamples()` returns a `resample_result` object", {
expect_s3_class(result, "resample_results")

expect_equal(result, .Last.tune.result)

expect_null(.get_tune_eval_times(result))
})

test_that("can use `fit_resamples()` with a formula", {
Expand Down

0 comments on commit 8448353

Please sign in to comment.