Refactored functions to select single evaluation times #778

Merged: 32 commits, merged on Dec 6, 2023

Changes from 25 commits

Commits
a6fab99
move function to new file
topepo Dec 4, 2023
8c2aeb1
change function order for docs
topepo Dec 4, 2023
252db2b
documentation start
topepo Dec 4, 2023
f891bf0
updates to the show/select functions
topepo Dec 4, 2023
6790f6d
updates to select/show functions
topepo Dec 4, 2023
16caebe
updates for selecting eval times
topepo Dec 5, 2023
8a813c8
remove commented out code
topepo Dec 5, 2023
a340985
bug fix
topepo Dec 5, 2023
620d048
metric test cases
topepo Dec 5, 2023
a00c951
Merge branch 'new-metric-selections' into new-time-selections
topepo Dec 5, 2023
9f633d4
add a survival model object
topepo Dec 5, 2023
ff06ced
note for next PR
topepo Dec 5, 2023
1aea76b
select/show test cases
topepo Dec 5, 2023
70afeeb
small set of direct tests
topepo Dec 5, 2023
5bcbb7b
update snapshot
topepo Dec 5, 2023
0d5332b
Apply suggestions from code review
topepo Dec 5, 2023
538b343
Apply suggestions from code review
topepo Dec 5, 2023
1dc65ca
Merge branch 'new-metric-selections' into new-time-selections
topepo Dec 5, 2023
5580d7a
updates from previous review
topepo Dec 5, 2023
9afab6c
Merge branch 'main' into new-time-selections
topepo Dec 5, 2023
870ce13
small cli update
topepo Dec 5, 2023
0b366d3
doc update
topepo Dec 5, 2023
fbadcd9
refresh snapshots
topepo Dec 5, 2023
86b37fd
modularize a check
topepo Dec 5, 2023
6247b96
Remake with newest CRAN version of scales for #775
topepo Dec 5, 2023
972ffc2
Apply suggestions from code review
topepo Dec 6, 2023
55fc24e
Apply suggestions from code review
topepo Dec 6, 2023
c0066d3
add dot when function is invoked
topepo Dec 6, 2023
8ee1c98
add a warning for eval times with non-survival models
topepo Dec 6, 2023
f032da4
go back to enquos
topepo Dec 6, 2023
453c23a
rework warning text
topepo Dec 6, 2023
c377d3d
rework warning text pt 2
topepo Dec 6, 2023
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9001
Version: 1.1.2.9002
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -156,6 +156,7 @@ export(check_parameters)
export(check_rset)
export(check_time)
export(check_workflow)
export(choose_eval_time)
export(choose_metric)
export(collect_extracts)
export(collect_metrics)
@@ -184,6 +185,7 @@ export(extract_recipe)
export(extract_spec_parsnip)
export(extract_workflow)
export(filter_parameters)
export(filter_perf_metrics)
export(finalize_model)
export(finalize_recipe)
export(finalize_workflow)
83 changes: 82 additions & 1 deletion R/metric-selection.R
@@ -4,6 +4,7 @@
#' @param metric A character value for which metric is being used.
#' @param eval_time An optional vector of times to compute dynamic and/or
#' integrated metrics.
#' @param call The execution environment of a currently running function.
#' @description
#' These are developer-facing functions used to compute and validate choices
#' for performance metrics. For survival analysis models, there are similar
@@ -15,6 +16,14 @@
#' no value is given by the user, the first metric value is used (with a
#' warning).
#'
#' For evaluation times, a value is required only when the metric type is dynamic
#' (e.g. [yardstick::brier_survival()] or [yardstick::roc_auc_survival()]). For
#' these metrics, we require a single numeric value that was originally given
#' to the function used to produce `x` (such as [tune_grid()]).
#'
#' If a time is required and none is given, the first value in the vector
#' originally given in the `eval_time` argument is used (with a warning).
#'
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
@@ -58,6 +67,47 @@ contains_survival_metric <- function(mtr_info) {
any(grepl("_survival", mtr_info$class))
}

#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, call = rlang::caller_env()) {
mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (!contains_survival_metric(mtr_info)) {
return(NULL)
}

# If we need an eval time, set it to the possible values so that
# we can choose the first value
if (is_dyn(mtr_set, metric) && is.null(eval_time)) {
eval_time <- .get_tune_eval_times(x)
}

eval_time <- first_eval_time(mtr_set, metric = metric, eval_time = eval_time)

check_right_eval_time(x, eval_time, call = call)

eval_time
}

is_dyn <- function(mtr_set, metric) {
mtr_info <- tibble::as_tibble(mtr_set)
mtr_cls <- mtr_info$class[mtr_info$metric == metric]
mtr_cls == "dynamic_survival_metric"
}
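
For reference, `is_dyn()` keys off the tibble representation of a yardstick metric set, which has one row per metric with `metric`, `class`, and `direction` columns. A minimal sketch of that lookup, assuming yardstick's survival metrics (illustrative only, not part of this diff):

library(yardstick)

# One dynamic and one static survival metric
mtr_set <- metric_set(brier_survival, concordance_survival)
mtr_info <- tibble::as_tibble(mtr_set)

# The class column is what drives the decision; expected values shown as comments
mtr_info$class[mtr_info$metric == "brier_survival"]       # "dynamic_survival_metric"
mtr_info$class[mtr_info$metric == "concordance_survival"] # "static_survival_metric"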

check_right_eval_time <- function(x, eval_time, call = rlang::caller_env()) {
given_times <- .get_tune_eval_times(x)
Review comment (Member), suggested change:

[original]
check_right_eval_time <- function(x, eval_time, call = rlang::caller_env()) {
given_times <- .get_tune_eval_times(x)

[suggested]
check_right_eval_time <- function(x, eval_time, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()
given_times <- .get_tune_eval_times(x)

Review comment (Member): Can we rename the check_right_*() functions into something that says what "right" means here? Maybe check_*_in_tune_results()?

if (!is.null(eval_time)) {
if (!any(eval_time == given_times)) {
num_times <- length(given_times)
print_time <- format(eval_time, digits = 3)
cli::cli_abort("Evaluation time {.val {print_time}} is not in the results.", call = call)
Review comment (Member), suggested change:

[original]
cli::cli_abort("Evaluation time {.val {print_time}} is not in the results.", call = call)

[suggested]
cli::cli_abort("Evaluation time {.val {print_time}} is not in the tuning results.", call = call)

Author reply: We might be using this in a slightly broader context than tuning (like simple resampling or last fit).

}
}
invisible(NULL)
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
@@ -99,8 +149,39 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
} else if ( num_times > 1 ) {
eval_time <- eval_time[1]
print_time <- format(eval_time, digits = 3)
cli::cli_warn("{num_times} evaluation times were available; the first ({print_time}) will be used.")
cli::cli_warn("{.val {num_times}} evaluation times were specified; the first ({print_time}) will be used.")
}

eval_time
}
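
A sketch of the selection rule with hypothetical inputs (reusing `mtr_set` from the sketch above; the warning text follows the updated `cli_warn()` call):

# Three candidate times for a dynamic metric: the first is kept, with a warning
first_eval_time(mtr_set, metric = "brier_survival", eval_time = c(5, 10, 15))
#> Warning: 3 evaluation times were specified; the first (5) will be used.
#> [1] 5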

# ------------------------------------------------------------------------------

#' @rdname choose_metric
#' @export
filter_perf_metrics <- function(x, metric, eval_time) {
summary_res <- estimate_tune_results(x)
summary_res <- summary_res[summary_res$.metric == metric, ]
is_missing_mean <- is.na(summary_res$mean)
summary_res <- summary_res[!is_missing_mean, ]

if (!is.null(eval_time) && any(colnames(summary_res) == ".eval_time")) {
summary_res <- summary_res[summary_res$.eval_time == eval_time, ]
}
if (nrow(summary_res) == 0) {
cli::cli_abort("No results are available. Please use {.code collect_metrics()} to see if there were any issues.")
}

summary_res
}

# TODO will be removed shortly

middle_eval_time <- function(x) {
x <- x[!is.na(x)]
times <- unique(x)
med_time <- median(x, na.rm = TRUE)
ind <- which.min(abs(times - med_time))
eval_time <- times[ind]
eval_time
}
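
Taken together, these helpers form the single pipeline that show_best() and the select_by_*() methods now share. A hedged sketch of that flow, where `res` is a hypothetical tune_results object from a resampled survival model:

# 1. Validate the metric choice (falls back to the first metric, with a warning)
metric_info <- choose_metric(res, metric = NULL)

# 2. Resolve a single evaluation time; NULL when no survival metrics are involved
eval_time <- choose_eval_time(res, metric_info$metric, eval_time = NULL)

# 3. Subset the resampling estimates to that metric (and time, if applicable)
summary_res <- filter_perf_metrics(res, metric_info$metric, eval_time)
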
1 change: 1 addition & 0 deletions R/plots.R
@@ -158,6 +158,7 @@ get_param_label <- function(x, id_val) {
res
}

# TODO remove this.
default_eval_time <- function(eval_time, x, call = rlang::caller_env()) {
if (!any(names(x) == ".eval_time")) {
if (!is.null(eval_time)) {
156 changes: 47 additions & 109 deletions R/select_best.R
@@ -79,23 +79,9 @@ show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ..
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

summary_res <- estimate_tune_results(x)
eval_time <- choose_eval_time(x, metric, eval_time = eval_time)

# TODO
metrics <- unique(summary_res$.metric)
if (length(metrics) == 1) {
metric <- metrics
}

# get estimates/summarise
summary_res <- summary_res %>% dplyr::filter(.metric == metric)

# TODO split selecting the req time and seeing if it is in the data
summary_res <- choose_eval_time(summary_res, x, eval_time)

if (nrow(summary_res) == 0) {
rlang::abort("No results are available. Please check the value of `metric`.")
}
summary_res <- filter_perf_metrics(x, metric, eval_time)

if (metric_info$direction == "maximize") {
summary_res <- summary_res %>% dplyr::arrange(dplyr::desc(mean))
@@ -155,30 +141,23 @@ select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, ev

param_names <- .get_tune_parameter_names(x)

dots <- rlang::enquos(...)
if (length(dots) == 0) {
rlang::abort("Please choose at least one tuning parameter to sort in `...`.")
}
check_select_dots(..., call = rlang::caller_env())

res <-
collect_metrics(x) %>%
dplyr::filter(.metric == !!metric & !is.na(mean))
res <- choose_eval_time(res, x, eval_time)
eval_time <- choose_eval_time(x, metric, eval_time = eval_time)

summary_res <- filter_perf_metrics(x, metric, eval_time)

if (nrow(res) == 0) {
rlang::abort("No results are available. Please check the value of `metric`.")
}
if (metric_info$direction == "maximize") {
best_metric <- max(res$mean, na.rm = TRUE)
best_metric <- max(summary_res$mean, na.rm = TRUE)
} else if (metric_info$direction == "minimize") {
best_metric <- min(res$mean, na.rm = TRUE)
best_metric <- min(summary_res$mean, na.rm = TRUE)
} else if (metric_info$direction == "zero") {
which_min <- which.min(abs(res$mean))
best_metric <- res$mean[which_min]
which_min <- which.min(abs(summary_res$mean))
best_metric <- summary_res$mean[which_min]
}

res <-
res %>%
summary_res <-
summary_res %>%
dplyr::rowwise() %>%
dplyr::mutate(
.best = best_metric,
Expand All @@ -188,8 +167,10 @@ select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, ev
) %>%
dplyr::ungroup()

res <- try(dplyr::arrange(res, !!!dots), silent = TRUE)
if (inherits(res, "try-error")) {

dots <- rlang::enquos(...)
summary_res <- try(dplyr::arrange(summary_res, !!!dots), silent = TRUE)
if (inherits(summary_res, "try-error")) {
var_nm <- rlang::eval_tidy(dots)
var_nm <- purrr::map_chr(var_nm, ~ rlang::quo_name(.x))
var_nm <- var_nm[!var_nm %in% colnames(collect_metrics(x))]
@@ -198,8 +179,8 @@ select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, ev

# discard models more complex than the best and
# remove models with greater increase in loss than the limit
best_index <- which(res$.loss == 0)
res %>%
best_index <- which(summary_res$.loss == 0)
summary_res %>%
dplyr::slice(1:best_index) %>%
dplyr::filter(.loss < limit) %>%
dplyr::slice(1) %>%
@@ -226,49 +207,41 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time

param_names <- .get_tune_parameter_names(x)

dots <- rlang::enquos(...)
if (length(dots) == 0) {
rlang::abort("Please choose at least one tuning parameter to sort in `...`.")
}
check_select_dots(..., call = rlang::caller_env())

res <-
collect_metrics(x) %>%
dplyr::filter(.metric == !!metric & !is.na(mean))
res <- choose_eval_time(res, x, eval_time)
eval_time <- choose_eval_time(x, metric, eval_time = eval_time)

if (nrow(res) == 0) {
rlang::abort("No results are available. Please check the value of `metric`.")
}
summary_res <- filter_perf_metrics(x, metric, eval_time)

if (metric_info$direction == "maximize") {
best_index <- which.max(res$mean)
best <- res$mean[best_index]
bound <- best - res$std_err[best_index]
res <-
res %>%
best_index <- which.max(summary_res$mean)
best <- summary_res$mean[best_index]
bound <- best - summary_res$std_err[best_index]
summary_res <-
summary_res %>%
dplyr::mutate(
.best = best,
.bound = bound
) %>%
dplyr::filter(mean >= .bound)
} else if (metric_info$direction == "minimize") {
best_index <- which.min(res$mean)
best <- res$mean[best_index]
bound <- best + res$std_err[best_index]
res <-
res %>%
best_index <- which.min(summary_res$mean)
best <- summary_res$mean[best_index]
bound <- best + summary_res$std_err[best_index]
summary_res <-
summary_res %>%
dplyr::mutate(
.best = best,
.bound = bound
) %>%
dplyr::filter(mean <= .bound)
} else if (metric_info$direction == "zero") {
best_index <- which.min(abs(res$mean))
best <- res$mean[best_index]
bound_lower <- -abs(best) - res$std_err[best_index]
bound_upper <- abs(best) + res$std_err[best_index]
res <-
res %>%
best_index <- which.min(abs(summary_res$mean))
best <- summary_res$mean[best_index]
bound_lower <- -abs(best) - summary_res$std_err[best_index]
bound_upper <- abs(best) + summary_res$std_err[best_index]
summary_res <-
summary_res %>%
dplyr::rowwise() %>%
dplyr::mutate(
.best = best,
@@ -278,59 +251,24 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time
dplyr::ungroup()
}

res <- try(dplyr::arrange(res, !!!dots), silent = TRUE)
if (inherits(res, "try-error")) {
dots <- rlang::enquos(...)
summary_res <- try(dplyr::arrange(summary_res, !!!dots), silent = TRUE)
if (inherits(summary_res, "try-error")) {
var_nm <- rlang::eval_tidy(dots)
var_nm <- purrr::map_chr(var_nm, ~ rlang::quo_name(.x))
var_nm <- var_nm[!var_nm %in% colnames(collect_metrics(x))]
cli::cli_abort("Could not sort results by {.var {var_nm}}.")
}
res %>%
summary_res %>%
dplyr::slice(1) %>%
dplyr::select(dplyr::all_of(param_names), .config)
}

middle_eval_time <- function(x) {
x <- x[!is.na(x)]
times <- unique(x)
med_time <- median(x, na.rm = TRUE)
ind <- which.min(abs(times - med_time))
eval_time <- times[ind]
eval_time
}

# NOTE this chooses the time and subsets the data; break it up to only select
# time
choose_eval_time <- function(x, object, eval_time) {
mtrs <- .get_tune_metrics(object)
mtrs <- tibble::as_tibble(mtrs)
actual_metrics <- unique(x$.metric)
mtrs <- mtrs[mtrs$metric %in% actual_metrics, ]

# Dynamic and integrated metrics need eval time as an input _but_
# only dynamic metrics need them as outputs. So if the metric like
# `brier_survival_integrated()` is used, the evaluation time doesn't need
# to be specified for computations that use the metrics.
if (!any(mtrs$class == "dynamic_survival_metric")) {
return(x)
}
# TODO maybe issue a warning if there are missing values
if (is.null(eval_time)) {
eval_time <- middle_eval_time(x$.eval_time)
msg <- cli::pluralize("No evaluation time was set; a value of {eval_time} was used.")
rlang::warn(msg)
} else {
if (length(eval_time) > 1) {
rlang::abort("Please pick a single evaluation time point.")
}
times <- unique(x$.eval_time)
if (!any(times == eval_time)) {
msg <- cli::pluralize("No evaluation times matched a value of {eval_time}.")
rlang::abort(msg)
}
}
}
x <- x[x$.eval_time == eval_time, ]
x
}

check_select_dots <- function(..., call = rlang::caller_env()) {
dots <- rlang::enquos(...)
if (length(dots) == 0) {
cli::cli_abort("Please choose at least one tuning parameter to sort in {.code ...}.",
call = call)
}
invisible(NULL)
}
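
With the refactor in place, the user-facing calls look like this (a sketch; `res` and the `penalty` tuning parameter are hypothetical):

# A dynamic survival metric takes a single eval_time, which must be one of the
# times originally given to tune_grid(); omitting it uses the first, with a warning
show_best(res, metric = "brier_survival", eval_time = 10)
select_by_one_std_err(res, penalty, metric = "brier_survival", eval_time = 10)

# Omitting the sorting parameter(s) in ... now aborts via check_select_dots()
select_by_one_std_err(res, metric = "brier_survival", eval_time = 10)
#> Error: Please choose at least one tuning parameter to sort in `...`.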

