Skip to content

Commit

Permalink
rename target_observations to oracle_output (#62)
Browse files Browse the repository at this point in the history
* rename target_observations to oracle_output throughout documentation, primary code, examples, and unit tests

* avoid the use of tidyselect::

* Update R/score_model_out.R

Co-authored-by: Zhian N. Kamvar <[email protected]>

* remove unused csv files for tests

* shorter line length for doc string

---------

Co-authored-by: Zhian N. Kamvar <[email protected]>
  • Loading branch information
elray1 and zkamvar authored Dec 2, 2024
1 parent c7d48b6 commit fbd8700
Show file tree
Hide file tree
Showing 24 changed files with 18,404 additions and 624 deletions.
24 changes: 14 additions & 10 deletions R/score_model_out.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#' Scores model outputs with a single `output_type` against observed data.
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param oracle_output Predictions that would have been generated by an oracle
#' model that knew the observed target data values in advance
#' @param metrics Character vector of scoring metrics to compute. If `NULL`
#' (the default), appropriate metrics are chosen automatically. See details
#' for more.
Expand All @@ -18,6 +18,9 @@
#' all other output types, this is ignored.
#'
#' @details
#' See the hubverse documentation for the expected format of the
#' [oracle output data](https://hubverse.io/en/latest/user-guide/target-data.html#oracle-output).
#'
#' Default metrics are provided by the `scoringutils` package. You can select
#' metrics by passing in a character vector of metric names to the `metrics`
#' argument.
Expand Down Expand Up @@ -58,7 +61,7 @@
#' quantile_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "quantile"),
#' target_observations = hubExamples::forecast_target_observations,
#' oracle_output = hubExamples::forecast_oracle_output,
#' metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
#' by = "model_id"
#' )
Expand All @@ -72,7 +75,7 @@
#' pmf_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "pmf"),
#' target_observations = hubExamples::forecast_target_observations,
#' oracle_output = hubExamples::forecast_oracle_output,
#' metrics = "log_score",
#' by = c("model_id", "location", "horizon")
#' )
Expand All @@ -85,7 +88,7 @@
#' American Statistical Association 106 (494): 746–62. \doi{10.1198/jasa.2011.r10138}.
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
score_model_out <- function(model_out_tbl, oracle_output, metrics = NULL,
summarize = TRUE, by = "model_id",
output_type_id_order = NULL) {
# check that model_out_tbl has a single output_type that is supported by this package
Expand All @@ -94,10 +97,10 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,

# assemble data for scoringutils
su_data <- switch(output_type,
quantile = transform_quantile_model_out(model_out_tbl, target_observations),
pmf = transform_pmf_model_out(model_out_tbl, target_observations, output_type_id_order),
mean = transform_point_model_out(model_out_tbl, target_observations, output_type),
median = transform_point_model_out(model_out_tbl, target_observations, output_type),
quantile = transform_quantile_model_out(model_out_tbl, oracle_output),
pmf = transform_pmf_model_out(model_out_tbl, oracle_output, output_type_id_order),
mean = transform_point_model_out(model_out_tbl, oracle_output, output_type),
median = transform_point_model_out(model_out_tbl, oracle_output, output_type),
NULL # default, should not happen because of the validation above
)

Expand All @@ -122,7 +125,8 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
#' Get scoring metrics
#'
#' @param forecast A scoringutils `forecast` object (see
#' [scoringutils::as_forecast()] for details).
#' [scoringutils' general information on creating a `forecast` object][scoringutils::as_forecast_doc_template()]
#' for details).
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
Expand Down
24 changes: 13 additions & 11 deletions R/transform_pmf_model_out.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
#' Transform pmf model output into a forecast object
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared against predictions
#' @param oracle_output Predictions that would have been generated by an oracle
#' model that knew the observed target data values in advance
#' @param output_type_id_order For nominal variables, this should be `NULL` (the default).
#' For ordinal variables, this is a vector of levels for pmf forecasts, in
#' increasing order of the levels.
#'
#' @return forecast_pmf
#' @importFrom rlang .data
transform_pmf_model_out <- function(model_out_tbl, target_observations, output_type_id_order = NULL) {
model_out_tbl <- validate_model_out_target_obs(model_out_tbl, target_observations)
transform_pmf_model_out <- function(model_out_tbl, oracle_output, output_type_id_order = NULL) {
model_out_tbl <- validate_model_oracle_out(model_out_tbl, oracle_output)

# subset both model_out_tbl and target_observations to output_type == "pmf"
# subset both model_out_tbl and oracle_output to output_type == "pmf"
model_out_tbl <- model_out_tbl |>
dplyr::filter(.data[["output_type"]] == "pmf")

if (c("output_type") %in% colnames(target_observations)) {
target_observations <- target_observations |>
if (c("output_type") %in% colnames(oracle_output)) {
oracle_output <- oracle_output |>
dplyr::filter(.data[["output_type"]] == "pmf") |>
dplyr::select(-c("output_type"))
}

# validate or set output_type_id_order
if (!is.null(output_type_id_order)) {
cli::cli_abort(
"ordinal variables are not yet supported (we expect that they will be by the time we release this package)."
"ordinal variables are not yet supported."
)
is_ordinal <- TRUE
} else {
Expand All @@ -39,17 +40,18 @@ transform_pmf_model_out <- function(model_out_tbl, target_observations, output_t
dplyr::rename(model = "model_id")

data <- dplyr::left_join(
model_out_tbl, target_observations,
by = c(task_id_cols[task_id_cols %in% colnames(target_observations)], "output_type_id"),
model_out_tbl, oracle_output,
by = c(task_id_cols[task_id_cols %in% colnames(oracle_output)], "output_type_id"),
relationship = "many-to-one"
) |>
dplyr::group_by(dplyr::across(dplyr::all_of(c(task_id_cols, "model")))) |>
dplyr::mutate(
observation = .data[["output_type_id"]][.data[["observation"]] == 1],
observation = .data[["output_type_id"]][.data[["oracle_value"]] == 1],
observation = factor(.data[["observation"]], levels = output_type_id_order, ordered = is_ordinal),
output_type_id = factor(.data[["output_type_id"]], levels = output_type_id_order, ordered = is_ordinal)
) |>
dplyr::ungroup()
dplyr::ungroup() |>
dplyr::select(-dplyr::all_of("oracle_value"))

forecast_pmf <- scoringutils::as_forecast_nominal(
data,
Expand Down
18 changes: 9 additions & 9 deletions R/transform_point_model_out.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#' Transform either mean or median model output into a point forecast object
#'
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared against predictions
#' @param oracle_output Predictions that would have been generated by an oracle
#' model that knew the observed target data values in advance
#' @param output_type Forecast output type: "mean" or "median"
#'
#' @return forecast_point
#'
#' @details This function transforms a model output tibble in the Hubverse
#' format (with either "mean" or "median" output type) to a scoringutils "point"
#' forecast object
transform_point_model_out <- function(model_out_tbl, target_observations, output_type) {
transform_point_model_out <- function(model_out_tbl, oracle_output, output_type) {
if ((!inherits(output_type, "character")) || (!output_type %in% c("mean", "median"))) {
cli::cli_abort(
"invalid 'output_type': {.val {output_type}} Must be 'mean' or 'median'"
)
}

model_out_tbl <- validate_model_out_target_obs(model_out_tbl, target_observations)
model_out_tbl <- validate_model_oracle_out(model_out_tbl, oracle_output)

task_id_cols <- get_task_id_cols(model_out_tbl)
type <- output_type
Expand All @@ -26,20 +26,20 @@ transform_point_model_out <- function(model_out_tbl, target_observations, output
dplyr::filter(output_type == type) |>
dplyr::rename(model = "model_id")

if (c("output_type") %in% colnames(target_observations)) {
target_observations <- target_observations |>
if (c("output_type") %in% colnames(oracle_output)) {
oracle_output <- oracle_output |>
dplyr::filter(output_type == type) |>
dplyr::select(-c("output_type", "output_type_id"))
}

data <- dplyr::left_join(model_out_tbl, target_observations,
by = task_id_cols[task_id_cols %in% colnames(target_observations)],
data <- dplyr::left_join(model_out_tbl, oracle_output,
by = task_id_cols[task_id_cols %in% colnames(oracle_output)],
relationship = "many-to-one"
)

forecast_point <- scoringutils::as_forecast_point(data,
forecast_unit = c("model", task_id_cols),
observed = "observation",
observed = "oracle_value",
predicted = "value"
)

Expand Down
17 changes: 9 additions & 8 deletions R/transform_quantile_model_out.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#' Transform quantile model output into a forecast object
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared against predictions
#' @param oracle_output Predictions that would have been generated by an oracle
#' model that knew the observed target data values in advance
#'
#' @return forecast_quantile
transform_quantile_model_out <- function(model_out_tbl, target_observations) {
model_out_tbl <- validate_model_out_target_obs(model_out_tbl, target_observations)
transform_quantile_model_out <- function(model_out_tbl, oracle_output) {
model_out_tbl <- validate_model_oracle_out(model_out_tbl, oracle_output)

task_id_cols <- get_task_id_cols(model_out_tbl)

Expand All @@ -14,20 +15,20 @@ transform_quantile_model_out <- function(model_out_tbl, target_observations) {
dplyr::mutate(output_type_id = as.numeric(.data[["output_type_id"]])) |>
dplyr::rename(model = "model_id")

if (c("output_type") %in% colnames(target_observations)) {
target_observations <- target_observations |>
if (c("output_type") %in% colnames(oracle_output)) {
oracle_output <- oracle_output |>
dplyr::filter(.data[["output_type"]] == "quantile") |>
dplyr::select(-c("output_type", "output_type_id"))
}

data <- dplyr::left_join(model_out_tbl, target_observations,
by = task_id_cols[task_id_cols %in% colnames(target_observations)],
data <- dplyr::left_join(model_out_tbl, oracle_output,
by = task_id_cols[task_id_cols %in% colnames(oracle_output)],
relationship = "many-to-one"
)

forecast_quantile <- scoringutils::as_forecast_quantile(data,
forecast_unit = c("model", task_id_cols),
observed = "observation",
observed = "oracle_value",
predicted = "value",
quantile_level = "output_type_id"
)
Expand Down
25 changes: 13 additions & 12 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ get_task_id_cols <- function(model_out_tbl) {
return(task_id_cols)
}

#' Validate model_out_tbl and target_observations arguments to transform functions
#' Validate model_out_tbl and oracle_output arguments to transform functions
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared against predictions
#' @param oracle_output Predictions that would have been generated by an oracle
#' model that knew the observed target data values in advance
#'
#' @return The input `model_out_tbl`, possibly modified to ensure it has S3 class `model_out_tbl`
#'
#' @noRd
validate_model_out_target_obs <- function(model_out_tbl, target_observations) {
validate_model_oracle_out <- function(model_out_tbl, oracle_output) {
# check that: model_out_tbl contains columns: model_id, output_type, output_type_id, value
req_cols <- c("model_id", "output_type", "output_type_id", "value")
if (!all(req_cols %in% colnames(model_out_tbl))) {
Expand All @@ -35,28 +36,28 @@ validate_model_out_target_obs <- function(model_out_tbl, target_observations) {
)
}

# check that model_out_tbl and target_observations have compatible columns
# check that model_out_tbl and oracle_output have compatible columns
task_id_cols <- get_task_id_cols(model_out_tbl)
if (length(task_id_cols[task_id_cols %in% colnames(target_observations)]) == 0) {
if (length(task_id_cols[task_id_cols %in% colnames(oracle_output)]) == 0) {
cli::cli_abort(
"model_out_tbl and target_observations do not have compatible columns"
"model_out_tbl and oracle_output do not have compatible columns"
)
}
t_o_cols <- colnames(target_observations)
expected_cols_superset <- c(task_id_cols, "output_type", "output_type_id", "observation")
t_o_cols <- colnames(oracle_output)
expected_cols_superset <- c(task_id_cols, "output_type", "output_type_id", "oracle_value")
unexpected_cols <- t_o_cols[!t_o_cols %in% expected_cols_superset]
if (length(unexpected_cols) > 0) {
cli::cli_abort(
c(
"`target_observations` had {length(unexpected_cols)} unexpected column{?s} {.val {unexpected_cols}};",
" expected the columns of `target_observations` to be a subset of {.val {expected_cols_superset}}."
"`oracle_output` had {length(unexpected_cols)} unexpected column{?s} {.val {unexpected_cols}};",
" expected the columns of `oracle_output` to be a subset of {.val {expected_cols_superset}}."
)
)
}

if (!c("observation") %in% colnames(target_observations)) {
if (!c("oracle_value") %in% colnames(oracle_output)) {
cli::cli_abort(
"target_observations does not have observation column"
"oracle_output does not have oracle_value column"
)
}

Expand Down
18 changes: 10 additions & 8 deletions man/score_model_out.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/transform_pmf_model_out.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/transform_point_model_out.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fbd8700

Please sign in to comment.