Skip to content

Commit

Permalink
Merge pull request #1071 from tidymodels/extract_fit_time
Browse files Browse the repository at this point in the history
add extract_fit_time
  • Loading branch information
EmilHvitfeldt authored Apr 5, 2024
2 parents 707ed6c + 128b75d commit 63ced27
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 13 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Imports:
generics (>= 0.1.2),
glue,
gower,
hardhat (>= 1.3.0),
hardhat (>= 1.3.1.9000),
ipred (>= 0.9-12),
lifecycle (>= 1.0.3),
lubridate (>= 1.8.0),
Expand Down Expand Up @@ -71,3 +71,5 @@ Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Remotes:
tidymodels/hardhat
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ S3method(bake,step_zv)
S3method(conditionMessage,recipes_error)
S3method(discretize,default)
S3method(discretize,numeric)
S3method(extract_fit_time,recipe)
S3method(extract_parameter_dials,recipe)
S3method(extract_parameter_set_dials,recipe)
S3method(fixed,Date)
Expand Down Expand Up @@ -567,6 +568,7 @@ export(dummy_extract_names)
export(dummy_names)
export(ellipse_check)
export(estimate_yj)
export(extract_fit_time)
export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(fixed)
Expand Down Expand Up @@ -729,6 +731,7 @@ importFrom(generics,tunable)
importFrom(generics,tune_args)
importFrom(glue,glue)
importFrom(gower,gower_topn)
importFrom(hardhat,extract_fit_time)
importFrom(hardhat,extract_parameter_dials)
importFrom(hardhat,extract_parameter_set_dials)
importFrom(hardhat,frequency_weights)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# recipes (development version)

* New `extract_fit_time()` method has been added that returns the time it took to train the recipe. (#1071)

# recipes 1.0.10

## Bug Fixes
Expand Down
20 changes: 20 additions & 0 deletions R/extract_fit_time.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Extract the time it took to train (prep) a recipe.
#
# `x` is a recipe; its `fit_times` element is read. `prep()` stores there one
# row per stage, labelled "prep.<step id>" / "bake.<step id>", with the
# elapsed time of that stage.
#
# `summarize = TRUE` (default) collapses the stages into a single row with
# `stage_id = "recipe"` and the summed elapsed time. `summarize = FALSE`
# returns the stored per-stage rows unchanged.
#
# `...` is accepted for compatibility with the generic and is not used.
#
# Returns a tibble with columns `stage_id` and `elapsed`. Errors when the
# recipe carries no `fit_times` element at all.
#' @export
extract_fit_time.recipe <- function(x, summarize = TRUE, ...) {
  res <- x$fit_times

  if (is.null(res)) {
    # The continuation line of the message is kept flush-left on purpose so
    # the message text itself is untouched.
    cli::cli_abort(
      "This recipe was created before {.fn recipes::extract_fit_time} was \\
added. Fit time cannot be extracted."
    )
  }

  # A recipe prepped without any steps records no stages, so the stored
  # timings have no `elapsed` column. Normalise that case to an empty,
  # correctly typed tibble: `$elapsed` would otherwise warn about an unknown
  # column and the summed time would come back as an integer.
  if (!"elapsed" %in% names(res)) {
    res <- tibble(stage_id = character(), elapsed = double())
  }

  if (summarize) {
    res <- tibble(
      stage_id = "recipe",
      elapsed = sum(res$elapsed)
    )
  }

  res
}
14 changes: 14 additions & 0 deletions R/recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,8 @@ prep.recipe <-
x$term_info <- x$var_info
}

fit_times <- list()

running_info <- x$term_info %>% mutate(number = 0, skip = FALSE)

get_needs_tuning <- function(x) {
Expand Down Expand Up @@ -470,17 +472,28 @@ prep.recipe <-

# Compute anything needed for the preprocessing steps
# then apply it to the current training set
time <- proc.time()
x$steps[[i]] <- recipes_error_context(
prep(x$steps[[i]],
training = training,
info = x$term_info
),
step_name = step_name
)
prep_time <- proc.time() - time

time <- proc.time()
training <- recipes_error_context(
bake(x$steps[[i]], new_data = training),
step_name = step_name
)
bake_time <- proc.time() - time

fit_times[[i]] <- list(
stage_id = paste(c("prep", "bake"), x$steps[[i]]$id, sep = "."),
elapsed = c(prep_time[["elapsed"]], bake_time[["elapsed"]])
)

if (!is_tibble(training)) {
cli::cli_abort(c(
"x" = "{.fun bake} methods should always return tibbles.",
Expand Down Expand Up @@ -539,6 +552,7 @@ prep.recipe <-
x$levels <- lvls
x$orig_lvls <- orig_lvls
x$retained <- retain
x$fit_times <- dplyr::bind_rows(fit_times)
# In case a variable was removed, and that removal step used
# `skip = TRUE`, we need to retain its record so that
# selectors can be properly used with `bake`. This tibble
Expand Down
4 changes: 4 additions & 0 deletions R/reexports.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ hardhat::frequency_weights
#' @importFrom magrittr %>%
#' @export
magrittr::`%>%`

#' @importFrom hardhat extract_fit_time
#' @export
hardhat::extract_fit_time
3 changes: 2 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions man/selections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/extract_fit_time.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# extract_fit_time() works

Code
extract_fit_time(rec)
Condition
Error in `extract_fit_time()`:
! This recipe was created before `recipes::extract_fit_time()` was added. Fit time cannot be extracted.

3 changes: 3 additions & 0 deletions tests/testthat/test-basics.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ test_that("prep with fresh = TRUE", {

new_rec <- prep(rec, training = test_data, fresh = TRUE)

rec$fit_times$elapsed <- 0
new_rec$fit_times$elapsed <- 0

expect_identical(rec, new_rec)

expect_equal(
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-extract_fit_time.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
test_that("extract_fit_time() works", {
  timing_cols <- c("stage_id", "elapsed")

  # Two steps with fixed ids so the per-stage labels are predictable.
  rec <- recipe(mpg ~ ., data = mtcars)
  rec <- step_scale(rec, all_numeric_predictors(), id = "scale")
  rec <- step_center(rec, all_numeric_predictors(), id = "center")
  rec <- prep(rec)

  # Default: one summary row for the whole recipe.
  total <- extract_fit_time(rec)

  expect_s3_class(total, "tbl_df")
  expect_named(total, timing_cols)
  expect_identical(total$stage_id, "recipe")
  expect_type(total$elapsed, "double")
  expect_true(total$elapsed >= 0)

  # Unsummarized: a prep and a bake row per step, in step order.
  per_stage <- extract_fit_time(rec, summarize = FALSE)

  expect_s3_class(per_stage, "tbl_df")
  expect_named(per_stage, timing_cols)
  expect_identical(
    per_stage$stage_id,
    c("prep.scale", "bake.scale", "prep.center", "bake.center")
  )
  expect_type(per_stage$elapsed, "double")
  expect_true(all(per_stage$elapsed >= 0))

  # Dropping the stored timings mimics a recipe that has none recorded.
  rec$fit_times <- NULL

  expect_snapshot(
    extract_fit_time(rec),
    error = TRUE
  )
})
3 changes: 3 additions & 0 deletions tests/testthat/test-interact.R
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,16 @@ test_that("works when formula is passed in as an object", {
step_interact(terms = cars_formula, id = "") %>%
prep()

rec1$fit_times$elapsed <- 0
rec2$fit_times$elapsed <- 0
expect_identical(rec1, rec2)

cars_formula <- ~ vs:am
rec3 <- recipe(~., data = mtcars) %>%
step_interact(terms = !!cars_formula, id = "") %>%
prep()

rec3$fit_times$elapsed <- 0
expect_identical(rec1, rec3)
})

Expand Down

0 comments on commit 63ced27

Please sign in to comment.