Skip to content

Commit

Permalink
Refactor choose_metric (#777)
Browse files Browse the repository at this point in the history
* move function to new file

* change function order for docs

* documentation start

* updates to the show/select functions

* updates to select/show functions

* remove commented out code

* bug fix

* metric test cases

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

* Apply suggestions from code review

* update tests

---------

Co-authored-by: Hannah Frick <[email protected]>
  • Loading branch information
topepo and hfrick authored Dec 5, 2023
1 parent 8d71b2c commit 1b67e42
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 239 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ export(check_parameters)
export(check_rset)
export(check_time)
export(check_workflow)
export(choose_metric)
export(collect_extracts)
export(collect_metrics)
export(collect_notes)
Expand Down
61 changes: 58 additions & 3 deletions R/metric-selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,69 @@
#' @param metric A character value for which metric is being used.
#' @param eval_time An optional vector of times to compute dynamic and/or
#' integrated metrics.
#' @description
#' These are developer-facing functions used to compute and validate choices
#' for performance metrics. For survival analysis models, there are similar
#' functions for the evaluation time(s) required for dynamic and/or integrated
#' metrics.
#'
#' `choose_metric()` is used with functions such as [show_best()] or
#' [select_best()] where a single valid metric is required to rank models. If
#' no value is given by the user, the first metric value is used (with a
#' warning).
#'
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()

mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (is.null(metric)) {
metric <- mtr_info$metric[1]
cli::cli_warn("No value of {.arg metric} was given; {.val {metric}} will be used.", call = call)
} else {
metric <- check_mult_metrics(metric, call = call)
check_right_metric(mtr_info, metric, call = call)
}

mtr_info[mtr_info$metric == metric,]
}

check_mult_metrics <- function(metric, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()

num_metrics <- length(metric)
metric <- metric[1]
if (num_metrics > 1) {
cli::cli_warn("{num_metrics} metric{?s} were given; {.val {metric}} will be used.", call = call)
}
metric
}

check_right_metric <- function(mtr_info, metric, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()

if (!any(mtr_info$metric == metric)) {
cli::cli_abort("{.val {metric}} was not in the metric set. Please choose from: {.val {mtr_info$metric}}.", call = call)
}
invisible(NULL)
}

contains_survival_metric <- function(mtr_info) {
any(grepl("_survival", mtr_info$class))
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
#' @export
first_metric <- function(mtr_set) {
tibble::as_tibble(mtr_set)[1,]
}

#' @rdname first_metric
#' @keywords internal
#' @rdname choose_metric
#' @export
first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
num_times <- length(eval_time)
Expand All @@ -25,7 +80,7 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
}

# Not a survival metric
if (!any(grepl("_survival_", mtr_info$class))) {
if (!contains_survival_metric(mtr_info)) {
return(NULL)
}

Expand Down
104 changes: 24 additions & 80 deletions R/select_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,17 @@ show_best.default <- function(x, ...) {
#' @export
#' @rdname show_best
show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ...) {
# TODO should return the as_tibble(metric_set) results to get the class etc.
# TODO new function start
metric <- choose_metric(metric, x)
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

dots <- rlang::enquos(...)
if (!is.null(dots$maximize)) {
rlang::warn(paste(
"The `maximize` argument is no longer needed.",
"This value was ignored."
))
}
direction <- get_metric_direction(x, metric)
summary_res <- estimate_tune_results(x)

# TODO
metrics <- unique(summary_res$.metric)
if (length(metrics) == 1) {
metric <- metrics
}

# TODO new function stop

# get estimates/summarise
summary_res <- summary_res %>% dplyr::filter(.metric == metric)

Expand All @@ -106,32 +97,18 @@ show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ..
rlang::abort("No results are available. Please check the value of `metric`.")
}

if (direction == "maximize") {
if (metric_info$direction == "maximize") {
summary_res <- summary_res %>% dplyr::arrange(dplyr::desc(mean))
} else if (direction == "minimize") {
} else if (metric_info$direction == "minimize") {
summary_res <- summary_res %>% dplyr::arrange(mean)
} else if (direction == "zero") {
} else if (metric_info$direction == "zero") {
summary_res <- summary_res %>% dplyr::arrange(abs(mean))
}
show_ind <- 1:min(nrow(summary_res), n)
summary_res %>%
dplyr::slice(show_ind)
}

choose_metric <- function(metric, x) {
if (is.null(metric)) {
metric_vals <- .get_tune_metric_names(x)
metric <- metric_vals[1]
if (length(metric_vals) > 1) {
msg <- paste0(
"No value of `metric` was given; metric '", metric, "' ",
"will be used."
)
rlang::warn(msg)
}
}
metric
}

#' @export
#' @rdname show_best
Expand All @@ -148,16 +125,11 @@ select_best.default <- function(x, ...) {
#' @export
#' @rdname show_best
select_best.tune_results <- function(x, metric = NULL, eval_time = NULL, ...) {
metric <- choose_metric(metric, x)
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

param_names <- .get_tune_parameter_names(x)

dots <- rlang::enquos(...)
if (!is.null(dots$maximize)) {
rlang::warn(paste(
"The `maximize` argument is no longer needed.",
"This value was ignored."
))
}
res <- show_best(x, metric = metric, n = 1, eval_time = eval_time)
res %>% dplyr::select(dplyr::all_of(param_names), .config)

Expand All @@ -178,22 +150,16 @@ select_by_pct_loss.default <- function(x, ...) {
#' @export
#' @rdname show_best
select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, eval_time = NULL) {
metric <- choose_metric(metric, x)
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

param_names <- .get_tune_parameter_names(x)

dots <- rlang::enquos(...)
if (!is.null(dots$maximize)) {
rlang::warn(paste(
"The `maximize` argument is no longer needed.",
"This value was ignored."
))
dots[["maximize"]] <- NULL
}

if (length(dots) == 0) {
rlang::abort("Please choose at least one tuning parameter to sort in `...`.")
}
direction <- get_metric_direction(x, metric)

res <-
collect_metrics(x) %>%
dplyr::filter(.metric == !!metric & !is.na(mean))
Expand All @@ -202,11 +168,11 @@ select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, ev
if (nrow(res) == 0) {
rlang::abort("No results are available. Please check the value of `metric`.")
}
if (direction == "maximize") {
if (metric_info$direction == "maximize") {
best_metric <- max(res$mean, na.rm = TRUE)
} else if (direction == "minimize") {
} else if (metric_info$direction == "minimize") {
best_metric <- min(res$mean, na.rm = TRUE)
} else if (direction == "zero") {
} else if (metric_info$direction == "zero") {
which_min <- which.min(abs(res$mean))
best_metric <- res$mean[which_min]
}
Expand Down Expand Up @@ -255,21 +221,16 @@ select_by_one_std_err.default <- function(x, ...) {
#' @export
#' @rdname show_best
select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time = NULL) {
metric <- choose_metric(metric, x)
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

param_names <- .get_tune_parameter_names(x)

dots <- rlang::enquos(...)
if (!is.null(dots$maximize)) {
rlang::warn(paste(
"The `maximize` argument is no longer needed.",
"This value was ignored."
))
dots[["maximize"]] <- NULL
}
if (length(dots) == 0) {
rlang::abort("Please choose at least one tuning parameter to sort in `...`.")
}
direction <- get_metric_direction(x, metric)

res <-
collect_metrics(x) %>%
dplyr::filter(.metric == !!metric & !is.na(mean))
Expand All @@ -279,7 +240,7 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time
rlang::abort("No results are available. Please check the value of `metric`.")
}

if (direction == "maximize") {
if (metric_info$direction == "maximize") {
best_index <- which.max(res$mean)
best <- res$mean[best_index]
bound <- best - res$std_err[best_index]
Expand All @@ -290,7 +251,7 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time
.bound = bound
) %>%
dplyr::filter(mean >= .bound)
} else if (direction == "minimize") {
} else if (metric_info$direction == "minimize") {
best_index <- which.min(res$mean)
best <- res$mean[best_index]
bound <- best + res$std_err[best_index]
Expand All @@ -301,7 +262,7 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time
.bound = bound
) %>%
dplyr::filter(mean <= .bound)
} else if (direction == "zero") {
} else if (metric_info$direction == "zero") {
best_index <- which.min(abs(res$mean))
best <- res$mean[best_index]
bound_lower <- -abs(best) - res$std_err[best_index]
Expand Down Expand Up @@ -329,23 +290,6 @@ select_by_one_std_err.tune_results <- function(x, ..., metric = NULL, eval_time
dplyr::select(dplyr::all_of(param_names), .config)
}

get_metric_direction <- function(x, metric) {
if (rlang::is_missing(metric) | length(metric) > 1) {
rlang::abort("Please specify a single character value for `metric`.")
}
attr_x <- attr(x, "metrics") %>%
attr("metrics")
if (!metric %in% names(attr_x)) {
rlang::abort("Please check the value of `metric`.")
}
directions <-
attr(x, "metrics") %>%
attr("metrics") %>%
purrr::map(~ attr(.x, "direction"))

directions[[metric]]
}

middle_eval_time <- function(x) {
x <- x[!is.na(x)]
times <- unique(x)
Expand Down
2 changes: 1 addition & 1 deletion inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ggplot
iteratively
misspecified
notag
parallelizes
parallelize
pre
preprocessed
preprocessor
Expand Down
34 changes: 34 additions & 0 deletions man/choose_metric.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 0 additions & 23 deletions man/first_metric.Rd

This file was deleted.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/fit_best.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

---

Please check the value of `metric`.
"WAT" was not in the metric set. Please choose from: "rmse" and "rsq".

---

Expand Down
Loading

0 comments on commit 1b67e42

Please sign in to comment.