
Commit

Merge pull request #9 from hubverse-org/elr/compute_scores
add score computation
elray1 authored Jan 3, 2025
2 parents 4ed84f9 + 6b07281 commit 9448733
Showing 9 changed files with 435 additions and 3 deletions.
5 changes: 4 additions & 1 deletion DESCRIPTION
@@ -11,6 +11,7 @@ Imports:
cli,
dplyr (>= 1.1.0),
hubData,
hubEvals,
hubUtils,
jsonlite,
jsonvalidate,
@@ -20,13 +21,15 @@ Imports:
yaml
Remotes:
hubverse-org/hubData,
hubverse-org/hubEvals,
hubverse-org/hubUtils,
epiforecasts/scoringutils
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Suggests:
-testthat (>= 3.0.0)
+testthat (>= 3.0.0),
+withr
Config/testthat/edition: 3
URL: https://github.com/hubverse-org/hubPredEvalsData
BugReports: https://github.com/hubverse-org/hubPredEvalsData/issues
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,2 +1,3 @@
# Generated by roxygen2: do not edit by hand

export(generate_eval_data)
99 changes: 99 additions & 0 deletions R/generate_eval_data.R
@@ -0,0 +1,99 @@
#' Generate evaluation data for a hub
#'
#' @param hub_path A path to the hub.
#' @param config_path A path to a yaml file that specifies the configuration
#' options for the evaluation.
#' @param out_path The directory to write the evaluation data to.
#' @param oracle_output A data frame of oracle output to use for the evaluation.
#'
#' @export
generate_eval_data <- function(hub_path,
config_path,
out_path,
oracle_output) {
config <- read_config(hub_path, config_path)
for (target in config$targets) {
generate_target_eval_data(hub_path, config, out_path, oracle_output, target)
}
}
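
As a quick orientation, here is a minimal sketch of invoking the exported function; the hub path, config file name, and output directory are hypothetical, and the oracle-output preparation mirrors the integration test later in this diff.

```r
# Hypothetical hub layout -- adjust paths to your hub
oracle_output <- read.csv("my-hub/target-data/oracle-output.csv")
oracle_output$target_end_date <- as.Date(oracle_output$target_end_date)

generate_eval_data(
  hub_path = "my-hub",
  config_path = "my-hub/predevals-config.yaml", # evaluation config (illustrative name)
  out_path = "my-hub/evals",
  oracle_output = oracle_output
)
```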


#' Generate evaluation data for a target
#'
#' @inheritParams generate_eval_data
#' @param config The configuration options for the evaluation.
#' @param target The target to generate evaluation data for. This is one object
#' from the list of targets in the config, with properties "target_id",
#' "metrics", and "disaggregate_by".
#'
#' @noRd
generate_target_eval_data <- function(hub_path,
config,
out_path,
oracle_output,
target) {
target_id <- target$target_id
metrics <- target$metrics
# prepending `NULL` so the first iteration computes overall (non-disaggregated) scores
disaggregate_by <- c(list(NULL), as.list(target$disaggregate_by))
eval_windows <- config$eval_windows

task_groups_w_target <- get_task_groups_w_target(hub_path, target_id)
metric_name_to_output_type <- get_metric_name_to_output_type(task_groups_w_target, metrics)

for (eval_window in eval_windows) {
model_out_tbl <- load_model_out_in_window(hub_path, target$target_id, eval_window)

# calculate overall scores followed by scores disaggregated by a task ID variable.
for (by in disaggregate_by) {
get_and_save_scores(
model_out_tbl = model_out_tbl,
oracle_output = oracle_output,
metric_name_to_output_type = metric_name_to_output_type,
target_id = target_id,
window_name = eval_window$window_name,
by = by,
out_path = out_path
)
}
}
}
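
Worth noting: the `c(list(NULL), ...)` construction above relies on `c("model_id", NULL)` silently dropping the `NULL`, so the first loop iteration scores by `model_id` alone. A small sketch of the `by` values the loop iterates over:

```r
disaggregate_by <- c(list(NULL), as.list(c("location", "horizon")))
str(disaggregate_by)
#> List of 3
#>  $ : NULL
#>  $ : chr "location"
#>  $ : chr "horizon"

# On the first pass, by = NULL, so the grouping collapses to "model_id" alone:
c("model_id", NULL)
#> [1] "model_id"
```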


#' Get and save scores for a target in a given evaluation window,
#' collecting across different output types as necessary.
#' Scores are saved in .csv files in subdirectories under out_path with one of
#' two structures:
#' - If by is NULL, the scores are saved in
#' out_path/target_id/window_name/scores.csv
#' - If by is not NULL, the scores are saved in
#' out_path/target_id/window_name/by/scores.csv
#' @noRd
get_and_save_scores <- function(model_out_tbl, oracle_output, metric_name_to_output_type,
target_id, window_name, by,
out_path) {
# Iterate over the output types and calculate scores for each
scores <- purrr::map(
unique(metric_name_to_output_type$output_type),
~ hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(output_type == !!.x),
oracle_output = oracle_output,
metrics = metric_name_to_output_type$metric[
metric_name_to_output_type$output_type == .x
],
by = c("model_id", by)
)
) |>
purrr::reduce(dplyr::left_join, by = c("model_id", by))

target_window_by_out_path <- file.path(out_path, target_id, window_name)
if (!is.null(by)) {
target_window_by_out_path <- file.path(target_window_by_out_path, by)
}
if (!dir.exists(target_window_by_out_path)) {
dir.create(target_window_by_out_path, recursive = TRUE)
}
utils::write.csv(scores,
file = file.path(target_window_by_out_path, "scores.csv"),
row.names = FALSE)
}
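
To see the map-then-reduce pattern above in isolation, a toy sketch with made-up score values (not hubverse output): each output type yields one score table keyed by the grouping columns, and successive left joins produce a single wide table.

```r
# Toy per-output-type score tables (illustrative values only)
scores_by_type <- list(
  data.frame(model_id = c("A", "B"), se_point = c(1.2, 3.4)),
  data.frame(model_id = c("A", "B"), wis = c(0.5, 0.9))
)

purrr::reduce(scores_by_type, dplyr::left_join, by = "model_id")
#>   model_id se_point wis
#> 1        A      1.2 0.5
#> 2        B      3.4 0.9
```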
62 changes: 62 additions & 0 deletions R/utils-hub_tasks_config.R
@@ -1,3 +1,19 @@
#' Get the task groups from a hub's tasks config that contain a target_id, and
#' subset the target_metadata to just the entry for the target_id.
#'
#' @param hub_path A path to the hub.
#' @param target_id The target_id to filter to.
#'
#' @noRd
get_task_groups_w_target <- function(hub_path, target_id) {
hub_tasks_config <- hubUtils::read_config(hub_path, config = "tasks")
round_ids <- hubUtils::get_round_ids(hub_tasks_config)
task_groups <- hubUtils::get_round_model_tasks(hub_tasks_config, round_ids[1])
task_groups_w_target <- filter_task_groups_to_target(task_groups, target_id)
return(task_groups_w_target)
}


#' Filter task groups from a hub's tasks config to those that contain a target_id.
#' Additionally, subset the target_metadata to just the entry for the target_id.
#'
@@ -42,3 +58,49 @@ is_target_ordinal <- function(task_groups_w_target) {
target_type <- task_groups_w_target[[1]]$target_metadata[[1]]$target_type
return(target_type == "ordinal")
}


#' Get the output type id values for a given output type. The output type may
#' appear in multiple task groups, and the output type id values in those groups
#' may differ as long as there is one group that has all of the output type id
#' values.
#' @noRd
get_output_type_ids_for_type <- function(task_groups, output_type) {
output_type_ids_by_group <- purrr::map(
task_groups,
function(task_group) {
task_group$output_type[[output_type]]$output_type_id
}
)

# Small groups should contain subsets of the largest group, so this is our reference.
output_type_ids <- output_type_ids_by_group[[which.max(lengths(output_type_ids_by_group))]]

# check that the output type id values in each group are a (possibly improper)
# subset of the output type id values in the group with the most values, in
# the same order
for (i in seq_along(output_type_ids_by_group)) {
output_type_ids_group <- output_type_ids_by_group[[i]]

# if output_type_ids_group has any entries that are not in output_type_ids,
# raise an error
if (any(!output_type_ids_group %in% output_type_ids)) {
cli::cli_abort(
"In hub's tasks.json, output type ids for output type {.val {output_type}}
have different values across task groups."
)
}

# if output_type_ids and output_type_ids_group have some entries in common
# but order differs, raise an error
output_type_ids_subset <- output_type_ids[output_type_ids %in% output_type_ids_group]
if (!identical(output_type_ids_subset, output_type_ids_group)) {
cli::cli_abort(
"In hub's tasks.json, output type ids for output type {.val {output_type}}
have different order across task groups."
)
}
}

return(output_type_ids)
}
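
A toy sketch of the behavior, assuming task groups shaped like entries in a hubverse tasks config: the group with the most output type ids is taken as the reference, and every other group must be an ordered subset of it.

```r
task_groups <- list(
  list(output_type = list(quantile = list(output_type_id = c(0.25, 0.5, 0.75)))),
  list(output_type = list(quantile = list(output_type_id = c(0.1, 0.25, 0.5, 0.75, 0.9))))
)
get_output_type_ids_for_type(task_groups, "quantile")
#> [1] 0.10 0.25 0.50 0.75 0.90

# An id missing from the reference group, or ids in a different order,
# would hit one of the cli::cli_abort() calls instead.
```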
21 changes: 21 additions & 0 deletions man/generate_eval_data.Rd

(Generated documentation file; diff not rendered.)

5 changes: 3 additions & 2 deletions tests/testthat/helper-expect_df_equal_up_to_order.R
@@ -2,12 +2,13 @@
#'
#' @param df_act The actual data frame
#' @param df_exp The expected data frame
-expect_df_equal_up_to_order <- function(df_act, df_exp) {
+expect_df_equal_up_to_order <- function(df_act, df_exp, ignore_attr = FALSE, ...) {
cols <- colnames(df_act)
testthat::expect_equal(cols, colnames(df_exp))
testthat::expect_equal(
dplyr::arrange(df_act, dplyr::across(dplyr::all_of(cols))),
dplyr::arrange(df_exp, dplyr::across(dplyr::all_of(cols))),
-    ignore_attr = FALSE
+    ignore_attr = ignore_attr,
+    ...
)
}
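
A toy use of the helper: the rows match but their order differs, so the expectation passes.

```r
df_exp <- data.frame(x = c(1, 2), y = c("a", "b"))
df_act <- df_exp[c(2, 1), ] # same rows, reversed order

expect_df_equal_up_to_order(df_act, df_exp, ignore_attr = TRUE) # passes
```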
101 changes: 101 additions & 0 deletions tests/testthat/test-generate_eval_data.R
@@ -0,0 +1,101 @@
#' Helper function to check that the output files were created and have the expected contents
#' for one evaluation window.
#' @param out_path The path to the output directory where scores were saved.
#' @param window_name The name of the evaluation window.
#' @param model_out_tbl The model output table, filtered to data for the evaluation window.
#' @param oracle_output The oracle output.
check_exp_scores_for_window <- function(out_path, window_name, model_out_tbl, oracle_output) {
# check that the output files were created and have the expected contents
# no disaggregation
scores_path <- file.path(out_path, "wk inc flu hosp", window_name, "scores.csv")
testthat::expect_true(file.exists(scores_path))

actual_scores <- read.csv(scores_path)
expected_mean_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "mean"),
oracle_output = oracle_output,
metrics = "se_point",
by = "model_id"
)
expected_median_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "median"),
oracle_output = oracle_output,
metrics = "ae_point",
by = "model_id"
)
expected_quantile_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "quantile"),
oracle_output = oracle_output,
metrics = c("wis", "ae_median", "interval_coverage_50", "interval_coverage_95"),
by = "model_id"
)
expected_scores <- expected_mean_scores |>
dplyr::left_join(expected_median_scores, by = "model_id") |>
dplyr::left_join(expected_quantile_scores, by = "model_id")
expect_df_equal_up_to_order(actual_scores, expected_scores, ignore_attr = TRUE) # nolint: object_usage_linter

for (by in c("location", "reference_date", "horizon", "target_end_date")) {
# check that the output files were created and have the expected contents
# disaggregated by `by`
scores_path <- file.path(out_path, "wk inc flu hosp", window_name, by, "scores.csv")
testthat::expect_true(file.exists(scores_path))

actual_scores <- read.csv(scores_path)
if (by %in% c("reference_date", "target_end_date")) {
actual_scores[[by]] <- as.Date(actual_scores[[by]])
}

expected_mean_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "mean"),
oracle_output = oracle_output,
metrics = "se_point",
by = c("model_id", by)
)
expected_median_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "median"),
oracle_output = oracle_output,
metrics = "ae_point",
by = c("model_id", by)
)
expected_quantile_scores <- hubEvals::score_model_out(
model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == "quantile"),
oracle_output = oracle_output,
metrics = c("wis", "ae_median", "interval_coverage_50", "interval_coverage_95"),
by = c("model_id", by)
)
expected_scores <- expected_mean_scores |>
dplyr::left_join(expected_median_scores, by = c("model_id", by)) |>
dplyr::left_join(expected_quantile_scores, by = c("model_id", by))
expect_df_equal_up_to_order(actual_scores, expected_scores, ignore_attr = TRUE) # nolint: object_usage_linter
}
}

test_that(
"generate_eval_data works, integration test",
{
out_path <- withr::local_tempdir()
hub_path <- test_path("testdata", "ecfh")
model_out_tbl <- hubData::connect_hub(hub_path) |>
dplyr::collect()
oracle_output <- read.csv(
test_path("testdata", "ecfh", "target-data", "oracle-output.csv")
)
oracle_output[["target_end_date"]] <- as.Date(oracle_output[["target_end_date"]])

generate_eval_data(
hub_path = hub_path,
config_path = test_path("testdata", "test_configs", "config_valid_mean_median_quantile.yaml"),
out_path = out_path,
oracle_output = oracle_output
)

check_exp_scores_for_window(out_path,
"Full season",
model_out_tbl,
oracle_output)
check_exp_scores_for_window(out_path,
"Last 5 weeks",
model_out_tbl |> dplyr::filter(reference_date >= "2022-12-17"),
oracle_output)
}
)