transition from add_tailor(prop) and method to fit.workflow(calibration) #262

Merged
merged 5 commits into from
Oct 1, 2024
39 changes: 34 additions & 5 deletions R/fit.R
@@ -16,10 +16,14 @@
#' @param object A workflow
#'
#' @param data A data frame of predictors and outcomes to use when fitting the
#' workflow
#' preprocessor and model.
#'
#' @param ... Not used
#'
#' @param calibration A data frame of predictors and outcomes to use when
Member

I'm wondering if we should adapt the name slightly to make it more obvious that this is the data for calibration, rather than, say, the method. data_calibration, calibration_data, calibration_set?

Contributor Author

Sure! I don't have strong preferences between any of these options, but agree that calibration by itself could be ambiguous.

#' fitting the postprocessor. See the "Data Usage" section of [add_tailor()]
#' for more information.
#'
#' @param control A [control_workflow()] object
#'
#' @return
@@ -51,23 +55,27 @@
#' add_recipe(recipe)
#'
#' fit(recipe_wf, mtcars)
fit.workflow <- function(object, data, ..., control = control_workflow()) {
fit.workflow <- function(object, data, ..., calibration = NULL, control = control_workflow()) {
check_dots_empty()

if (is_missing(data)) {
cli_abort("{.arg data} must be provided to fit a workflow.")
}

validate_has_calibration(object, calibration)

if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
}

workflow <- object
workflow <- .fit_pre(workflow, data)
workflow <- .fit_model(workflow, control)
# if (has_postprocessor(workflow)) {
# workflow <- .fit_post(workflow, calibration)
# }
if (has_postprocessor(workflow)) {
# if (is.null(calibration)), then the tailor doesn't have a calibrator
# and training the tailor on `data` will not leak data
workflow <- .fit_post(workflow, calibration %||% data)
}
workflow <- .fit_finalize(workflow)

workflow
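
The fallback in the new `.fit_post()` call can be illustrated in isolation. This is a minimal base R sketch, not the package's implementation: `pick_post_data()` is a hypothetical helper that does not exist in workflows, and `%||%` is defined locally to mirror the null-default operator used in the diff.

```r
# Null-default operator, defined locally so the sketch is self-contained.
`%||%` <- function(x, y) if (is.null(x)) y else x

# Hypothetical helper (not part of workflows): choose the data that the
# postprocessor would be trained on.
pick_post_data <- function(data, calibration = NULL) {
  calibration %||% data
}

train <- mtcars[1:20, ]
calib <- mtcars[21:32, ]

# No calibration set: the postprocessor falls back to `data`.
identical(pick_post_data(train), train)

# Calibration set supplied: it takes precedence over `data`.
identical(pick_post_data(train, calib), calib)
```

Per the comment in the diff, the fallback is only reached when the tailor has no calibrator, which is why reusing `data` there does not leak information.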
@@ -218,6 +226,27 @@ validate_has_model <- function(x, ..., call = caller_env()) {
invisible(x)
}

validate_has_calibration <- function(x, calibration,
x_arg = caller_arg(x), call = caller_env()) {
if (.should_inner_split(x) && is.null(calibration)) {
cli::cli_abort(
"{.arg {x_arg}} requires a {.arg calibration} set to train but none
was supplied.",
call = call
)
}

if (!.should_inner_split(x) && !is.null(calibration)) {
cli::cli_abort(
"{.arg {x_arg}} does not require a {.arg calibration} set to train
but one was supplied.",
call = call
)
}

invisible(x)
}
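
The two abort branches above amount to requiring that "the workflow needs calibration data" and "calibration data was supplied" agree. The base R sketch below shows that rule on its own. It makes several assumptions: `check_calibration()` and its `needs_calibration` flag are hypothetical stand-ins for `validate_has_calibration()` and `.should_inner_split(x)`, and `stop()` replaces `cli::cli_abort()` to avoid package dependencies.

```r
# Hypothetical stand-in for validate_has_calibration(); `needs_calibration`
# plays the role of .should_inner_split(x) in the diff.
check_calibration <- function(needs_calibration, calibration = NULL) {
  if (needs_calibration && is.null(calibration)) {
    stop("A `calibration` set is required to train but none was supplied.",
         call. = FALSE)
  }
  if (!needs_calibration && !is.null(calibration)) {
    stop("A `calibration` set is not required to train but one was supplied.",
         call. = FALSE)
  }
  invisible(TRUE)
}

# The two consistent combinations pass silently.
check_calibration(needs_calibration = TRUE, calibration = mtcars)
check_calibration(needs_calibration = FALSE)

# The two inconsistent combinations signal an error; capture the messages.
missing_msg <- tryCatch(
  check_calibration(needs_calibration = TRUE),
  error = function(e) conditionMessage(e)
)
extra_msg <- tryCatch(
  check_calibration(needs_calibration = FALSE, calibration = mtcars),
  error = function(e) conditionMessage(e)
)
```

Erroring on an unneeded `calibration` set, rather than silently ignoring it, keeps a user from believing that held-out rows were used when they were not.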

# ------------------------------------------------------------------------------

finalize_blueprint <- function(workflow) {
68 changes: 19 additions & 49 deletions R/post-action-tailor.R
@@ -17,28 +17,16 @@
#' should not have been trained already with [tailor::fit()]; workflows
#' will handle training internally.
#'
#' @param prop The proportion of the data in [fit.workflow()] that should be
#' held back specifically for estimating the postprocessor. Only relevant for
#' postprocessors that require estimation---see section Data Usage below to
#' learn more. Defaults to 1/3.
#'
#' @param method The method with which to split the data in [fit.workflow()],
#' as a character vector. Only relevant for postprocessors that
#' require estimation and not required when resampling the workflow with
#' tune. If `fit.workflow(data)` arose as `training(split_object)`, this argument can
#' usually be supplied as `class(split_object)`. Defaults to `"mc_split"`, which
#' randomly samples `fit.workflow(data)` into two sets, similarly to
#' [rsample::initial_split()]. See section Data Usage below to learn more.
#'
#' @section Data Usage:
#'
#' While preprocessors and models are trained on data in the usual sense,
#' postprocessors are trained on _predictions_ on data. When a workflow
#' is fitted, the user supplies training data with the `data` argument.
#' is fitted, the user typically supplies training data with the `data` argument.
#' When workflows don't contain a postprocessor that requires training,
#' they can use all of the supplied `data` to train the preprocessor and model.
#' However, in the case where a postprocessor must be trained as well,
#' training the preprocessor and model on all of `data` would leave no data
#' users can pass all of the available data to the `data` argument to train the
#' preprocessor and model. However, in the case where a postprocessor must be
#' trained as well, allotting all of the available data to the `data` argument
#' to train the preprocessor and model would leave no data
#' left to train the postprocessor with---if that were the case, workflows
#' would need to `predict()` from the preprocessor and model on the same `data`
#' that they were trained on, with the postprocessor then training on those
@@ -49,22 +37,15 @@
#' is passed to that trained preprocessor and model to generate predictions,
#' which then form the training data for the postprocessor.
#'
#' The arguments `prop` and `method` parameterize how that data is split up.
#' `prop` determines the proportion of rows in `fit.workflow(data)` that are
#' allotted to training the preprocessor and model, while the rest are used to
#' train the postprocessor. `method` determines how that split occurs; since
#' `fit.workflow()` just takes in a data frame, the function doesn't have
#' any information on how that dataset came to be. For example, `data` could
#' have been created as:
#'
#' ```
#' split <- rsample::initial_split(some_other_data)
#' data <- rsample::training(split)
#' ```
#' When fitting a workflow with a postprocessor that requires training
#' (i.e. one that returns `TRUE` in `.should_inner_split(workflow)`), users
Member

If we are making "inner split" public-facing terminology, we should think about the name one more time. I'll open an rsample issue for that.

Member

It lives here now: tidymodels/rsample#553

#' must pass two data arguments--the usual `fit.workflow(data)` will be used
#' to train the preprocessor and model while `fit.workflow(calibration)` will
#' be used to train the postprocessor.
#'
#' ...in which case it's okay to randomly allot some rows of `data` to train the
#' preprocessor and model and the rest to train the postprocessor. However,
#' `data` could also have arisen as:
#' In some situations, randomly splitting `fit.workflow(data)` (with
#' [rsample::initial_split()], for example) is sufficient to prevent data
#' leakage. However, `fit.workflow(data)` could also have arisen as:
#'
#' ```
#' boots <- rsample::bootstraps(some_other_data)
@@ -78,8 +59,9 @@
#' datasets, resulting in the preprocessor and model generating predictions on
#' rows they've seen before. Similarly problematic situations could arise in the
#' context of other resampling situations, like time-based splits.
#' The `method` argument ensures that data is allotted properly (and is
#' internally handled by the tune package when resampling workflows).
#' In general, use the [rsample::inner_split()] function to prevent data
#' leakage when resampling; when workflows with postprocessors that require
#' training are passed to the tune package, this is handled internally.
#'
#' @param ... Not used.
#'
@@ -102,14 +84,11 @@
#' remove_tailor(workflow)
#'
#' update_tailor(workflow, adjust_probability_threshold(tailor, .2))
add_tailor <- function(x, tailor, prop = NULL, method = NULL, ...) {
add_tailor <- function(x, tailor, ...) {
check_dots_empty()
validate_tailor_available()
action <- new_action_tailor(tailor, prop = prop, method = method)
action <- new_action_tailor(tailor)
res <- add_action(x, action, "tailor")
if (.should_inner_split(res)) {
validate_rsample_available()
}
res
}
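
The Data Usage section's warning about bootstrap resamples can be made concrete with a small base R sketch. Everything here is illustrative and assumed: the toy data, the deterministic stand-in for a bootstrap sample, and the split on an `id` column are not how rsample or tune allot data internally. The sketch only shows why a naive split of resampled rows leaks.

```r
# Ten original rows, identified by `id`.
original <- data.frame(id = 1:10, y = (1:10)^2)

# Stand-in for a bootstrap sample: every original row appears twice.
# (A real bootstrap draws with replacement at random; the duplicates are
# made deterministic here so the result does not depend on a seed.)
boot <- original[rep(1:10, each = 2), ]

# Naive split by position: odd positions would train the preprocessor and
# model, even positions would train the postprocessor.
naive_fit <- boot[seq(1, nrow(boot), by = 2), ]
naive_cal <- boot[seq(2, nrow(boot), by = 2), ]
leaked_naive <- intersect(naive_fit$id, naive_cal$id)

# Split on the original row identifier instead, so copies stay together.
fit_ids <- 1:7
id_fit <- boot[boot$id %in% fit_ids, ]
id_cal <- boot[!boot$id %in% fit_ids, ]
leaked_by_id <- intersect(id_fit$id, id_cal$id)

length(leaked_naive)  # 10: every original row lands in both sets
length(leaked_by_id)  # 0: no original row is shared
```

In the naive split, the model would generate predictions for the postprocessor on rows it has already seen, which is the leakage the documentation describes.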

@@ -185,7 +164,7 @@ mock_trained_workflow <- function(workflow) {

# ------------------------------------------------------------------------------

new_action_tailor <- function(tailor, prop, method, ..., call = caller_env()) {
new_action_tailor <- function(tailor, ..., call = caller_env()) {
check_dots_empty()

if (!is_tailor(tailor)) {
@@ -196,17 +175,8 @@ new_action_tailor <- function(tailor, prop, method, ..., call = caller_env()) {
cli_abort("Can't add a trained tailor to a workflow.", call = call)
}

if (!is.null(prop) &&
(!rlang::is_double(prop, n = 1) || prop <= 0 || prop >= 1)) {
cli_abort("{.arg prop} must be a numeric on (0, 1).", call = call)
}

# todo: test method

new_action_post(
tailor = tailor,
prop = prop,
method = method,
subclass = "action_tailor"
)
}
51 changes: 17 additions & 34 deletions man/add_tailor.Rd


8 changes: 6 additions & 2 deletions man/fit-workflow.Rd


16 changes: 16 additions & 0 deletions tests/testthat/_snaps/fit.md
@@ -32,6 +32,22 @@
! The workflow must have a model.
i Provide one with `add_model()`.

# fit.workflow confirms compatibility of object and calibration

Code
fit(workflow, mtcars, calibration = mtcars)
Condition
Error in `fit()`:
! `object` does not require a `calibration` set to train but one was supplied.

---

Code
fit(workflow, mtcars)
Condition
Error in `fit()`:
! `object` requires a `calibration` set to train but none was supplied.

# can `predict()` from workflow fit from individual pieces

Code
23 changes: 23 additions & 0 deletions tests/testthat/test-fit.R
@@ -85,6 +85,29 @@ test_that("cannot fit without a fit stage", {
})
})

test_that("fit.workflow confirms compatibility of object and calibration", {
skip_if_not_installed("tailor")

mod <- parsnip::linear_reg()
mod <- parsnip::set_engine(mod, "lm")

workflow <- workflow()
workflow <- add_formula(workflow, mpg ~ cyl)
workflow <- add_model(workflow, mod)

expect_snapshot(error = TRUE, {
fit(workflow, mtcars, calibration = mtcars)
})

tailor <- tailor::tailor()
tailor <- tailor::adjust_numeric_calibration(tailor)
workflow <- add_tailor(workflow, tailor)

expect_snapshot(error = TRUE, {
fit(workflow, mtcars)
})
})

# ------------------------------------------------------------------------------
# .fit_pre()
