Skip to content

Commit

Permalink
Refactor target samples calc into calc_samples_per_combo() helper
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Dec 17, 2024
1 parent 6149832 commit 764a63e
Showing 1 changed file with 71 additions and 38 deletions.
109 changes: 71 additions & 38 deletions R/linear_pool_sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,51 +35,21 @@ linear_pool_sample <- function(model_out_tbl, weights = NULL,
)
weights_col_name <- "weight"
}
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]
unique_weights <- unique(weights[[weights_col_name]])

if (length(unique_weights) != 1) {
cli::cli_abort("{.arg weights} must be {.val NULL} or equal for every model")
}

if (!is.null(n_output_samples)) {
samples_per_combo <- model_out_tbl |>
dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", compound_taskid_set)))) |>
dplyr::summarize(provided_samples = length(unique(.data[["output_type_id"]]))) |>
dplyr::ungroup() |>
tidyr::complete(
!!!rlang::syms(c("model_id", compound_taskid_set)),
fill = list(provided_samples = 0)
) |>
dplyr::left_join(weights, weight_by_cols) |>
dplyr::group_by(dplyr::across(dplyr::all_of(compound_taskid_set))) |>
dplyr::mutate(
effective_weight = as.integer(.data[["provided_samples"]] > 0) * .data[[weights_col_name]],
effective_weight = .data[["effective_weight"]] / sum(.data[["effective_weight"]]),
target_samples = floor(.data[["effective_weight"]] * n_output_samples)
) |>
dplyr::ungroup()


if (!is.null(compound_taskid_set)) {
samples_per_combo <- split(samples_per_combo, f = samples_per_combo[, compound_taskid_set])
} else {
samples_per_combo <- list(samples_per_combo)
}
# deal with n_output_samples not divisible evenly among component models
samples_per_combo <- samples_per_combo |>
purrr::map(.f = function(split_per_combo) {
actual_output_samples <- sum(split_per_combo$target_samples)
remainder_samples <- n_output_samples - actual_output_samples
valid_models <- split_per_combo$model_id[split_per_combo$provided_samples > 0]
models_to_resample <- sample(x = valid_models, size = remainder_samples)

split_per_combo |>
dplyr::mutate(target_samples = ifelse(
model_id %in% models_to_resample, .data[["target_samples"]] + 1, .data[["target_samples"]]
))
}) |>
purrr::list_rbind()
samples_per_combo <- calc_samples_per_combo(
model_out_tbl,
weights,
weights_col_name,
task_id_cols,
compound_taskid_set,
n_output_samples
)

if (any(samples_per_combo$provided_samples < samples_per_combo$target_samples)) {
cli::cli_abort("Requested output samples per compound unit cannot exceed the provided samples per compound unit.")
Expand Down Expand Up @@ -139,6 +109,69 @@ make_sample_indices_unique <- function(model_out_tbl) {
}


#' Compute how many samples to draw from each model for every unique
#' combination of compound task ID set variables when requesting a linear
#' pool of the `sample` output type.
#'
#' @inheritParams linear_pool
#' @param weights_col_name `character` string naming the column of `weights`
#'   that holds the model weights. Every other column of `weights` is used as
#'   a join key when attaching weights to the per-model sample counts.
#' @param task_id_cols `character` vector of task ID column names. Accepted
#'   for interface consistency; not used in the calculation below.
#' @param compound_taskid_set `character` vector of the column names that
#'   define a compound unit, or `NULL` to treat all rows as a single unit.
#' @param n_output_samples number of samples requested per compound unit.
#'
#' @details For each compound unit, the number of distinct `output_type_id`
#'   values provided by each model is counted (0 for model/unit combinations
#'   that are absent from `model_out_tbl`). Weights of models that provide no
#'   samples for a unit are zeroed and the remaining weights are renormalized
#'   to sum to 1 within the unit. Each model's target is the floor of its
#'   effective weight times `n_output_samples`; any samples left unallocated
#'   by the flooring are handed out one each to models drawn at random
#'   (without replacement) from those that provided samples, so the result
#'   depends on the state of the random number generator.
#' @noRd
#'
#' @return a `data.frame` giving the number of provided and target samples for
#'   every unique combination of compound task ID set variables for each
#'   model, with columns: "model_id", the compound task ID set variables,
#'   "provided_samples", the weights column (named by `weights_col_name`),
#'   "effective_weight", and "target_samples".
#'
#' @importFrom rlang .data
calc_samples_per_combo <- function(model_out_tbl, weights,
                                   weights_col_name,
                                   task_id_cols,
                                   compound_taskid_set,
                                   n_output_samples) {
  # All columns of `weights` other than the weight values act as join keys.
  weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

  samples_per_combo <- model_out_tbl |>
    dplyr::group_by(
      dplyr::across(dplyr::all_of(c("model_id", compound_taskid_set)))
    ) |>
    # `.groups = "drop"` returns an ungrouped result directly, so no separate
    # ungroup() step is needed.
    dplyr::summarize(
      provided_samples = length(unique(.data[["output_type_id"]])),
      .groups = "drop"
    ) |>
    # Make model/unit combinations with no rows explicit, with zero samples.
    tidyr::complete(
      !!!rlang::syms(c("model_id", compound_taskid_set)),
      fill = list(provided_samples = 0)
    ) |>
    dplyr::left_join(weights, by = weight_by_cols) |>
    dplyr::group_by(dplyr::across(dplyr::all_of(compound_taskid_set))) |>
    dplyr::mutate(
      # Zero out models with no samples for this unit, then renormalize.
      # NOTE(review): if no model provides samples for a completed
      # combination, the weights sum to 0 and this yields NaN -- confirm
      # that such combinations cannot reach this helper.
      effective_weight =
        as.integer(.data[["provided_samples"]] > 0) * .data[[weights_col_name]],
      effective_weight =
        .data[["effective_weight"]] / sum(.data[["effective_weight"]]),
      target_samples = floor(.data[["effective_weight"]] * n_output_samples)
    ) |>
    dplyr::ungroup()

  # One data frame per compound unit (or a single one when there is no
  # compound task ID set).
  if (!is.null(compound_taskid_set)) {
    samples_per_combo <- split(
      samples_per_combo,
      f = samples_per_combo[, compound_taskid_set]
    )
  } else {
    samples_per_combo <- list(samples_per_combo)
  }

  # deal with n_output_samples not divisible evenly among component models:
  # flooring can leave some samples unallocated, so give one extra sample
  # each to randomly chosen models that actually provided samples.
  # The pipeline is the last expression, so its value is returned visibly.
  samples_per_combo |>
    purrr::map(.f = function(split_per_combo) {
      actual_output_samples <- sum(split_per_combo$target_samples)
      remainder_samples <- n_output_samples - actual_output_samples
      valid_models <-
        split_per_combo$model_id[split_per_combo$provided_samples > 0]
      models_to_resample <- sample(x = valid_models, size = remainder_samples)

      split_per_combo |>
        dplyr::mutate(target_samples = ifelse(
          .data[["model_id"]] %in% models_to_resample,
          .data[["target_samples"]] + 1,
          .data[["target_samples"]]
        ))
    }) |>
    purrr::list_rbind()
}

#' Perform simple validations on the inputs used to calculate a linear pool
#' of samples
#'
Expand Down

0 comments on commit 764a63e

Please sign in to comment.