Commit
* remove snapshots of hubExamples data for tests, instead directly using hubExamples data objects
* add functionality for relative metrics
* update docs
* don't use tidyr in test
* fix typo
* add some whitespace
* Apply suggestions from code review
* add comment about validations done by add_relative_skill
* change place where we document validation of baseline
* Apply suggestions from code review: compare data frames directly
* move get_pairwise_scores_by_loc to helper file
* refactor expected outputs for pairwise score tests into a test fixture

Co-authored-by: Zhian N. Kamvar <[email protected]>
Showing 6 changed files with 298 additions and 3 deletions. Some generated files are not rendered by default.
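For orientation, here is a minimal sketch of the relative-metrics interface this commit adds, assembled from the tests below (the forecast data objects come from hubExamples; the argument names are taken directly from the diff):

# Score quantile forecasts and also compute pairwise relative skill for
# ae_median and wis, scaled against the Flusight-baseline model.
scores <- score_model_out(
  model_out_tbl = hubExamples::forecast_outputs |>
    dplyr::filter(output_type == "quantile"),
  oracle_output = hubExamples::forecast_oracle_output,
  metrics = c("ae_median", "wis", "interval_coverage_80", "interval_coverage_90"),
  relative_metrics = c("ae_median", "wis"),
  baseline = "Flusight-baseline",
  by = c("model_id", "location")
)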
tests/testthat/test-score_model_out_rel_metrics.R
@@ -0,0 +1,81 @@
test_that("score_model_out succeeds with valid inputs: quantile output_type, relative wis and ae, no baseline", { | ||
# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html> | ||
forecast_outputs <- hubExamples::forecast_outputs | ||
forecast_oracle_output <- hubExamples::forecast_oracle_output | ||
|
||
act_scores <- score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("ae_median", "wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = c("ae_median", "wis"), | ||
by = c("model_id", "location") | ||
) | ||
|
||
exp_scores <- read.csv(test_path("testdata", "exp_pairwise_scores.csv")) |> | ||
dplyr::mutate(location = as.character(location)) |> | ||
dplyr::select(-ae_median_scaled_relative_skill, -wis_scaled_relative_skill) | ||
|
||
expect_equal(act_scores, exp_scores, ignore_attr = TRUE) | ||
}) | ||
|
||
|
||
test_that("score_model_out succeeds with valid inputs: quantile output_type, relative wis and ae, Flusight-baseline", { | ||
# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html> | ||
forecast_outputs <- hubExamples::forecast_outputs | ||
forecast_oracle_output <- hubExamples::forecast_oracle_output | ||
|
||
act_scores <- score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("ae_median", "wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = c("ae_median", "wis"), | ||
baseline = "Flusight-baseline", | ||
by = c("model_id", "location") | ||
) | ||
|
||
exp_scores <- read.csv(test_path("testdata", "exp_pairwise_scores.csv")) |> | ||
dplyr::mutate(location = as.character(location)) | ||
|
||
expect_equal(act_scores, exp_scores, ignore_attr = TRUE) | ||
}) | ||
|
||
|
||
test_that("score_model_out errors when invalid relative metrics are requested", { | ||
# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html> | ||
forecast_outputs <- hubExamples::forecast_outputs | ||
forecast_oracle_output <- hubExamples::forecast_oracle_output | ||
|
||
# not allowed to compute relative skill for interval coverage | ||
expect_error( | ||
score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = c("interval_coverage_90", "wis"), | ||
), | ||
regexp = "Interval coverage metrics are not supported for relative skill scores." | ||
) | ||
|
||
# relative_metrics must be a subset of metrics | ||
expect_error( | ||
score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = c("ae_median", "wis"), | ||
), | ||
regexp = "Relative metrics must be a subset of the metrics." | ||
) | ||
|
||
# can't ask for relative metrics without breaking down by model_id | ||
expect_error( | ||
score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = "wis", | ||
by = "location" | ||
), | ||
regexp = "Relative metrics require 'model_id' to be included in `by`." | ||
) | ||
}) |
tests/testthat/testdata/exp_pairwise_scores.csv
@@ -0,0 +1,7 @@
"model_id","location","ae_median","wis","interval_coverage_80","interval_coverage_90","ae_median_relative_skill","ae_median_scaled_relative_skill","wis_relative_skill","wis_scaled_relative_skill" | ||
"Flusight-baseline","25",210.75,179.944642857143,0,0.25,1.09843550132001,1,1.12612125846186,1 | ||
"Flusight-baseline","48",593,478.964285714286,0,0,1.12969637511783,1,1.15711681669639,1 | ||
"MOBS-GLEAM_FLUH","25",196.125,162.6875,0.5,0.5,1.02220955016079,0.930604982206406,1.01812340354839,0.904097490299596 | ||
"MOBS-GLEAM_FLUH","48",636.625,467.791071428571,0.5,0.625,1.2128043082789,1.07356661045531,1.13012375159286,0.976672134814704 | ||
"PSI-DICE","25",170.875,139.369642857143,0.625,0.625,0.890605771236332,0.81079478054567,0.872196666228429,0.774513987436612 | ||
"PSI-DICE","48",383.125,316.535714285714,0.375,0.375,0.729873395812849,0.646079258010118,0.764710039995535,0.660875400790396 |
127 changes: 127 additions & 0 deletions
tests/testthat/testdata/make_pairwise_score_test_data.R
@@ -0,0 +1,127 @@
#' Geometric mean
#' (x_1 \times x_2 \times \ldots \times x_n)^{1/n}
#' = exp[1/n \sum_{i=1}^{n} log(x_i)]
geometric_mean <- function(x) {
  exp(mean(log(x)))
}
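# Illustrative check (not part of the committed script): the geometric mean of
# 2 and 8 is exp((log(2) + log(8)) / 2) = exp(log(16) / 2) = 4.
stopifnot(all.equal(geometric_mean(c(2, 8)), 4))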

#' Helper function that manually computes pairwise relative skill scores by location.
#' Called from tests in test-score_model_out_rel_metrics.R
get_pairwise_scores_by_loc <- function(scores_per_task, metric, baseline) {
  mean_scores_by_loc <- scores_per_task |>
    dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", "location")))) |>
    dplyr::summarize(
      mean_score = mean(.data[[metric]], na.rm = TRUE), # nolint: object_usage_linter
      .groups = "drop"
    )

  pairwise_score_ratios <- expand.grid(
    model_id = unique(mean_scores_by_loc$model_id),
    model_id_compare = unique(mean_scores_by_loc$model_id),
    location = unique(mean_scores_by_loc[["location"]])
  ) |>
    dplyr::left_join(mean_scores_by_loc, by = c("model_id" = "model_id", "location")) |>
    dplyr::left_join(mean_scores_by_loc, by = c("model_id_compare" = "model_id", "location")) |>
    dplyr::mutate(
      pairwise_score_ratio = .data[["mean_score.x"]] / .data[["mean_score.y"]] # nolint: object_usage_linter
    )

  result <- pairwise_score_ratios |>
    dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", "location")))) |>
    dplyr::summarize(
      relative_skill = geometric_mean(.data[["pairwise_score_ratio"]]), # nolint: object_usage_linter
      .groups = "drop"
    ) |>
    dplyr::group_by(dplyr::across(dplyr::all_of("location"))) |>
    dplyr::mutate(
      scaled_relative_skill = .data[["relative_skill"]] /
        .data[["relative_skill"]][.data[["model_id"]] == baseline]
    )

  colnames(result) <- c("model_id", "location",
                        paste0(metric, "_relative_skill"),
                        paste0(metric, "_scaled_relative_skill"))

  return(result)
}
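# Illustrative toy check (not part of the committed script): with two models in
# one location, each model's relative skill is the geometric mean of the ratios
# of its mean score to every model's mean score (including its own), and the
# scaled column divides by the baseline's relative skill.
toy_scores <- data.frame(
  model_id = c("A", "B"),
  location = c("25", "25"),
  wis = c(2, 8)
)
get_pairwise_scores_by_loc(toy_scores, metric = "wis", baseline = "A")
# "A": sqrt((2 / 2) * (2 / 8)) = 0.5; "B": sqrt((8 / 2) * (8 / 8)) = 2.
# Scaled by baseline "A": 1 and 4.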

# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
forecast_outputs <- hubExamples::forecast_outputs
forecast_oracle_output <- hubExamples::forecast_oracle_output

# expected scores
exp_scores_unsummarized <- forecast_outputs |>
  dplyr::filter(.data[["output_type"]] == "quantile") |>
  dplyr::left_join(
    forecast_oracle_output |>
      dplyr::filter(.data[["output_type"]] == "quantile") |>
      dplyr::select(-dplyr::all_of(c("output_type", "output_type_id"))),
    by = c("location", "target_end_date", "target")
  ) |>
  dplyr::mutate(
    output_type_id = as.numeric(.data[["output_type_id"]]),
    # quantile (pinball) score for each predictive quantile
    qs = ifelse(
      .data[["oracle_value"]] >= .data[["value"]],
      .data[["output_type_id"]] * (.data[["oracle_value"]] - .data[["value"]]),
      (1 - .data[["output_type_id"]]) * (.data[["value"]] - .data[["oracle_value"]])
    ),
    q_coverage_80_lower = ifelse(
      .data[["output_type_id"]] == 0.1,
      .data[["oracle_value"]] >= .data[["value"]],
      NA_real_
    ),
    q_coverage_80_upper = ifelse(
      .data[["output_type_id"]] == 0.9,
      .data[["oracle_value"]] <= .data[["value"]],
      NA_real_
    ),
    q_coverage_90_lower = ifelse(
      .data[["output_type_id"]] == 0.05,
      .data[["oracle_value"]] >= .data[["value"]],
      NA_real_
    ),
    q_coverage_90_upper = ifelse(
      .data[["output_type_id"]] == 0.95,
      .data[["oracle_value"]] <= .data[["value"]],
      NA_real_
    )
  ) |>
  dplyr::group_by(dplyr::across(dplyr::all_of(
    c("model_id", "location", "reference_date", "horizon", "target_end_date", "target")
  ))) |>
  dplyr::summarize(
    ae_median = sum(ifelse(
      .data[["output_type_id"]] == 0.5,
      abs(.data[["oracle_value"]] - .data[["value"]]),
      0
    )),
    # WIS is (approximately) twice the mean of the quantile scores
    wis = 2 * mean(.data[["qs"]]),
    # interval coverage is 1 only when both the lower and upper quantile cover the oracle value
    interval_coverage_80 = (sum(.data[["q_coverage_80_lower"]], na.rm = TRUE) == 1) *
      (sum(.data[["q_coverage_80_upper"]], na.rm = TRUE) == 1),
    interval_coverage_90 = (sum(.data[["q_coverage_90_lower"]], na.rm = TRUE) == 1) *
      (sum(.data[["q_coverage_90_upper"]], na.rm = TRUE) == 1)
  )

exp_scores_standard <- exp_scores_unsummarized |>
  dplyr::group_by(dplyr::across(dplyr::all_of(
    c("model_id", "location")
  ))) |>
  dplyr::summarize(
    ae_median = mean(.data[["ae_median"]]),
    wis = mean(.data[["wis"]]),
    interval_coverage_80 = mean(.data[["interval_coverage_80"]], na.rm = TRUE),
    interval_coverage_90 = mean(.data[["interval_coverage_90"]], na.rm = TRUE),
    .groups = "drop"
  )

# add pairwise relative scores for ae_median and wis
exp_scores_relative_ae_median <- get_pairwise_scores_by_loc(exp_scores_unsummarized, "ae_median", "Flusight-baseline")
exp_scores_relative_wis <- get_pairwise_scores_by_loc(exp_scores_unsummarized, "wis", "Flusight-baseline")
exp_scores <- exp_scores_standard |>
  dplyr::full_join(exp_scores_relative_ae_median, by = c("model_id", "location")) |>
  dplyr::full_join(exp_scores_relative_wis, by = c("model_id", "location"))

# save
write.csv(exp_scores, testthat::test_path("testdata", "exp_pairwise_scores.csv"), row.names = FALSE)