Merge pull request #9 from hubverse-org/elr/compute_scores
add score computation
Showing 9 changed files with 435 additions and 3 deletions.
NAMESPACE
@@ -1,2 +1,3 @@
# Generated by roxygen2: do not edit by hand

export(generate_eval_data)
@@ -0,0 +1,99 @@
#' Generate evaluation data for a hub
#'
#' @param hub_path A path to the hub.
#' @param config_path A path to a yaml file that specifies the configuration
#' options for the evaluation.
#' @param out_path The directory to write the evaluation data to.
#' @param oracle_output A data frame of oracle output to use for the evaluation.
#'
#' @export
generate_eval_data <- function(hub_path,
                               config_path,
                               out_path,
                               oracle_output) {
  config <- read_config(hub_path, config_path)
  for (target in config$targets) {
    generate_target_eval_data(hub_path, config, out_path, oracle_output, target)
  }
}
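The shape of the config file can be read off the fields accessed here and in generate_target_eval_data() below: a list of targets, each with "target_id", "metrics", and "disaggregate_by", plus a list of eval_windows, each with at least a "window_name". A minimal sketch of the parsed config as an R list; field names come from the code, example values from the test data later in this diff, and any window fields beyond the name are assumptions, since load_model_out_in_window() is not shown here:

config <- list(
  targets = list(
    list(
      target_id = "wk inc flu hosp",
      metrics = c("se_point", "ae_point", "wis"),
      disaggregate_by = c("location", "horizon")
    )
  ),
  eval_windows = list(
    # any fields describing the window's date bounds are consumed by
    # load_model_out_in_window() and are not visible in this diff
    list(window_name = "Full season")
  )
)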
#' Generate evaluation data for a target
#'
#' @inheritParams generate_eval_data
#' @param config The configuration options for the evaluation.
#' @param target The target to generate evaluation data for. This is one object
#' from the list of targets in the config, with properties "target_id",
#' "metrics", and "disaggregate_by".
#'
#' @noRd
generate_target_eval_data <- function(hub_path,
                                      config,
                                      out_path,
                                      oracle_output,
                                      target) {
  target_id <- target$target_id
  metrics <- target$metrics
  # adding `NULL` at the beginning will calculate overall scores
  disaggregate_by <- c(list(NULL), as.list(target$disaggregate_by))
  eval_windows <- config$eval_windows

  task_groups_w_target <- get_task_groups_w_target(hub_path, target_id)
  metric_name_to_output_type <- get_metric_name_to_output_type(task_groups_w_target, metrics)

  for (eval_window in eval_windows) {
    model_out_tbl <- load_model_out_in_window(hub_path, target$target_id, eval_window)

    # calculate overall scores followed by scores disaggregated by a task ID variable
    for (by in disaggregate_by) {
      get_and_save_scores(
        model_out_tbl = model_out_tbl,
        oracle_output = oracle_output,
        metric_name_to_output_type = metric_name_to_output_type,
        target_id = target_id,
        window_name = eval_window$window_name,
        by = by,
        out_path = out_path
      )
    }
  }
}
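Two details here are easy to miss. First, get_metric_name_to_output_type() is defined elsewhere in this PR, but from how get_and_save_scores() indexes its `metric` and `output_type` columns below, it evidently returns a data frame mapping each metric name to the output type it is computed from; the pairings sketched here are taken from the test file at the end of this diff. Second, prepending NULL to disaggregate_by yields overall scores because NULL simply vanishes when the grouping vector is assembled:

# hypothetical shape of metric_name_to_output_type (pairings from the tests):
#   metric      output_type
#   se_point    mean
#   ae_point    median
#   wis         quantile
#   ae_median   quantile

# the NULL element disappears when combined into the `by` grouping:
c("model_id", NULL)        #> "model_id"             (overall scores)
c("model_id", "location")  #> "model_id" "location"  (scores per model and location)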
#' Get and save scores for a target in a given evaluation window,
#' collecting across different output types as necessary.
#' Scores are saved in .csv files in subdirectories under out_path with one of
#' two structures:
#' - If by is NULL, the scores are saved in
#'   out_path/target_id/window_name/scores.csv
#' - If by is not NULL, the scores are saved in
#'   out_path/target_id/window_name/by/scores.csv
#' @noRd
get_and_save_scores <- function(model_out_tbl, oracle_output, metric_name_to_output_type,
                                target_id, window_name, by,
                                out_path) {
  # Iterate over the output types, calculate scores for each, and join the
  # per-output-type results into one row per grouping
  scores <- purrr::map(
    unique(metric_name_to_output_type$output_type),
    ~ hubEvals::score_model_out(
      model_out_tbl = model_out_tbl |> dplyr::filter(output_type == !!.x),
      oracle_output = oracle_output,
      metrics = metric_name_to_output_type$metric[
        metric_name_to_output_type$output_type == .x
      ],
      by = c("model_id", by)
    )
  ) |>
    purrr::reduce(dplyr::left_join, by = c("model_id", by))

  target_window_by_out_path <- file.path(out_path, target_id, window_name)
  if (!is.null(by)) {
    target_window_by_out_path <- file.path(target_window_by_out_path, by)
  }
  if (!dir.exists(target_window_by_out_path)) {
    dir.create(target_window_by_out_path, recursive = TRUE)
  }
  utils::write.csv(scores,
                   file = file.path(target_window_by_out_path, "scores.csv"),
                   row.names = FALSE)
}
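Given the target, windows, and disaggregation variables exercised by the test below, the output tree under out_path would look like this (paths reconstructed from the test's expectations, not from a run of the code):

out_path/
└── wk inc flu hosp/
    ├── Full season/
    │   ├── scores.csv
    │   ├── location/scores.csv
    │   ├── reference_date/scores.csv
    │   ├── horizon/scores.csv
    │   └── target_end_date/scores.csv
    └── Last 5 weeks/
        ├── scores.csv
        └── ...  (the same four per-variable subdirectories)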
Some generated files are not rendered by default.
@@ -0,0 +1,101 @@
#' Helper function to check that the output files were created and have the expected contents
#' for one evaluation window.
#' @param out_path The path to the output directory where scores were saved.
#' @param window_name The name of the evaluation window.
#' @param model_out_tbl The model output table, filtered to data for the evaluation window.
#' @param oracle_output The oracle output.
check_exp_scores_for_window <- function(out_path, window_name, model_out_tbl, oracle_output) {
  # check that the output files were created and have the expected contents
  # no disaggregation
  scores_path <- file.path(out_path, "wk inc flu hosp", window_name, "scores.csv")
  testthat::expect_true(file.exists(scores_path))

  actual_scores <- read.csv(scores_path)
  expected_mean_scores <- hubEvals::score_model_out(
    model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "mean"),
    oracle_output = oracle_output,
    metrics = "se_point",
    by = "model_id"
  )
  expected_median_scores <- hubEvals::score_model_out(
    model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "median"),
    oracle_output = oracle_output,
    metrics = "ae_point",
    by = "model_id"
  )
  expected_quantile_scores <- hubEvals::score_model_out(
    model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "quantile"),
    oracle_output = oracle_output,
    metrics = c("wis", "ae_median", "interval_coverage_50", "interval_coverage_95"),
    by = "model_id"
  )
  expected_scores <- expected_mean_scores |>
    dplyr::left_join(expected_median_scores, by = "model_id") |>
    dplyr::left_join(expected_quantile_scores, by = "model_id")
  expect_df_equal_up_to_order(actual_scores, expected_scores, ignore_attr = TRUE) # nolint: object_usage_linter

  for (by in c("location", "reference_date", "horizon", "target_end_date")) {
    # check that the output files were created and have the expected contents
    # disaggregated by `by`
    scores_path <- file.path(out_path, "wk inc flu hosp", window_name, by, "scores.csv")
    testthat::expect_true(file.exists(scores_path))

    actual_scores <- read.csv(scores_path)
    if (by %in% c("reference_date", "target_end_date")) {
      actual_scores[[by]] <- as.Date(actual_scores[[by]])
    }

    expected_mean_scores <- hubEvals::score_model_out(
      model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "mean"),
      oracle_output = oracle_output,
      metrics = "se_point",
      by = c("model_id", by)
    )
    expected_median_scores <- hubEvals::score_model_out(
      model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "median"),
      oracle_output = oracle_output,
      metrics = "ae_point",
      by = c("model_id", by)
    )
    expected_quantile_scores <- hubEvals::score_model_out(
      model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "quantile"),
      oracle_output = oracle_output,
      metrics = c("wis", "ae_median", "interval_coverage_50", "interval_coverage_95"),
      by = c("model_id", by)
    )
    expected_scores <- expected_mean_scores |>
      dplyr::left_join(expected_median_scores, by = c("model_id", by)) |>
      dplyr::left_join(expected_quantile_scores, by = c("model_id", by))
    expect_df_equal_up_to_order(actual_scores, expected_scores, ignore_attr = TRUE) # nolint: object_usage_linter
  }
}
test_that(
  "generate_eval_data works, integration test",
  {
    out_path <- withr::local_tempdir()
    hub_path <- test_path("testdata", "ecfh")
    model_out_tbl <- hubData::connect_hub(hub_path) |>
      dplyr::collect()
    oracle_output <- read.csv(
      test_path("testdata", "ecfh", "target-data", "oracle-output.csv")
    )
    oracle_output[["target_end_date"]] <- as.Date(oracle_output[["target_end_date"]])

    generate_eval_data(
      hub_path = hub_path,
      config_path = test_path("testdata", "test_configs", "config_valid_mean_median_quantile.yaml"),
      out_path = out_path,
      oracle_output = oracle_output
    )

    check_exp_scores_for_window(out_path,
                                "Full season",
                                model_out_tbl,
                                oracle_output)
    check_exp_scores_for_window(out_path,
                                "Last 5 weeks",
                                model_out_tbl |> dplyr::filter(reference_date >= "2022-12-17"),
                                oracle_output)
  }
)