Skip to content

Commit

Permalink
Merge pull request #853 from tidymodels/extract_fit_time
Browse files Browse the repository at this point in the history
add extract_fit_time
EmilHvitfeldt authored Apr 5, 2024
2 parents eb526fa + 6d76a59 commit fd5124c
Showing 11 changed files with 96 additions and 38 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ Imports:
ggplot2,
globals,
glue,
hardhat (>= 1.1.0),
hardhat (>= 1.3.1.9000),
lifecycle,
magrittr,
pillar,
@@ -77,4 +77,6 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Remotes:
tidymodels/hardhat
RoxygenNote: 7.3.1
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ S3method(check_args,svm_linear)
S3method(check_args,svm_poly)
S3method(check_args,svm_rbf)
S3method(extract_fit_engine,model_fit)
S3method(extract_fit_time,model_fit)
S3method(extract_parameter_dials,model_spec)
S3method(extract_parameter_set_dials,model_spec)
S3method(extract_spec_parsnip,model_fit)
@@ -222,6 +223,7 @@ export(discrim_quad)
export(discrim_regularized)
export(eval_args)
export(extract_fit_engine)
export(extract_fit_time)
export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(extract_spec_parsnip)
@@ -376,6 +378,7 @@ importFrom(generics,varying_args)
importFrom(ggplot2,autoplot)
importFrom(glue,glue_collapse)
importFrom(hardhat,extract_fit_engine)
importFrom(hardhat,extract_fit_time)
importFrom(hardhat,extract_parameter_dials)
importFrom(hardhat,extract_parameter_set_dials)
importFrom(hardhat,extract_spec_parsnip)
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

# parsnip 1.2.1

* Added a missing `tidy()` method for survival analysis glmnet models (#1086).
24 changes: 24 additions & 0 deletions R/extract.R
Original file line number Diff line number Diff line change
@@ -14,8 +14,15 @@
#'
#' - `extract_parameter_set_dials()` returns a set of dials parameter objects.
#'
#' - `extract_fit_time()` returns a tibble with fit times. The fit times
#' correspond to the time for the parsnip engine to fit and do not include
#' other portions of the elapsed time in [parsnip::fit.model_spec()].
#'
#' @param x A parsnip `model_fit` object or a parsnip `model_spec` object.
#' @param parameter A single string for the parameter ID.
#' @param summarize A logical for whether the elapsed fit time should be
#' returned as a single row or multiple rows. Doesn't support `FALSE` for
#' parsnip models.
#' @param ... Not currently used.
#' @details
#' Extracting the underlying engine fit can be helpful for describing the
@@ -127,3 +134,20 @@ eval_call_info <- function(x) {
extract_parameter_dials.model_spec <- function(x, parameter, ...) {
  # First collect the full dials parameter set for the specification, then
  # delegate to the parameter-set method to pick out the entry whose ID is
  # given by `parameter`.
  param_set <- extract_parameter_set_dials(x)
  extract_parameter_dials(param_set, parameter)
}

#' @export
#' @rdname extract-parsnip
extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) {
  # `summarize` is part of the method signature but is not consulted here:
  # a single-row tibble is always returned.

  # A current fit stores its timing as
  #   x$elapsed = list(elapsed = <system.time() result>, print = <logical>)
  # Fits created before that layout carry something else in `x$elapsed`
  # (e.g. `list(elapsed = NA_real_)`, a bare `system.time()` result, or
  # nothing at all). Indexing those blindly with `[["elapsed"]]` three times
  # fails with "subscript out of bounds" on the atomic value, so walk down
  # one level at a time and only extract what is really there.
  elapsed <- NULL
  timing <- x[["elapsed"]]
  if (is.list(timing)) {
    fit_timing <- timing[["elapsed"]]
    if (!is.null(fit_timing) && "elapsed" %in% names(fit_timing)) {
      elapsed <- fit_timing[["elapsed"]]
    }
  }

  # Check for NULL first: `is.na(NULL)` is `logical(0)`, which is not a valid
  # scalar condition on its own.
  if (is.null(elapsed) || is.na(elapsed)) {
    rlang::abort(
      "This model was fit before `extract_fit_time()` was added."
    )
  }

  dplyr::tibble(
    stage_id = class(x$spec)[1],
    elapsed = elapsed
  )
}
11 changes: 9 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
@@ -453,8 +453,15 @@ allow_sparse <- function(x) {
#' @export
print.model_fit <- function(x, ...) {
cat("parsnip model object\n\n")
if (!is.na(x$elapsed[["elapsed"]])) {
cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n")

if (is.null(x$elapsed$print) && !is.na(x$elapsed[["elapsed"]])) {
elapsed <- x$elapsed[["elapsed"]]
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
}

if (isTRUE(x$elapsed$print)) {
elapsed <- x$elapsed$elapsed[["elapsed"]]
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
}

if (inherits(x$fit, "try-error")) {
47 changes: 13 additions & 34 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
@@ -44,29 +44,19 @@ form_form <-
spec = object
)

if (control$verbosity > 1L) {
elapsed <- system.time(
res$fit <- eval_mod(
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
envir = env,
...
),
gcFirst = FALSE
)
} else {
elapsed <- system.time(
res$fit <- eval_mod(
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
envir = env,
...
)
elapsed <- list(elapsed = NA_real_)
}
),
gcFirst = FALSE
)
res$preproc <- list(y_var = all.vars(rlang::f_lhs(env$formula)))
res$elapsed <- elapsed
res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L)

res
}

@@ -102,35 +92,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {

res <- list(lvl = levels(env$y), spec = object)

if (control$verbosity > 1L) {
elapsed <- system.time(
res$fit <- eval_mod(
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
envir = env,
...
),
gcFirst = FALSE
)
} else {
elapsed <- system.time(
res$fit <- eval_mod(
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
envir = env,
...
)
elapsed <- list(elapsed = NA_real_)
}
),
gcFirst = FALSE
)

if (is.atomic(env$y)) {
y_name <- character(0)
} else {
y_name <- colnames(env$y)
}
res$preproc <- list(y_var = y_name)
res$elapsed <- elapsed
res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L)
res
}

@@ -176,9 +155,9 @@ xy_form <- function(object, env, control, ...) {
check_outcome(env$y, object)

encoding_info <- get_encoding(class(object)[1])
encoding_info <-
encoding_info <-
vctrs::vec_slice(
encoding_info,
encoding_info,
encoding_info$mode == object$mode & encoding_info$engine == object$engine
)

3 changes: 3 additions & 0 deletions R/reexports.R
Original file line number Diff line number Diff line change
@@ -58,3 +58,6 @@ hardhat::frequency_weights
#' @export
hardhat::importance_weights

#' @importFrom hardhat extract_fit_time
#' @export
hardhat::extract_fit_time
10 changes: 10 additions & 0 deletions man/extract-parsnip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/extract.md
Original file line number Diff line number Diff line change
@@ -18,3 +18,11 @@
i The parsnip extension package baguette implements support for this specification.
i Please install (if needed) and load to continue.

# extract_fit_time() works

Code
extract_fit_time(lm_fit)
Condition
Error in `extract_fit_time()`:
! This model was fit before `extract_fit_time()` was added.

19 changes: 19 additions & 0 deletions tests/testthat/test_extract.R
Original file line number Diff line number Diff line change
@@ -95,3 +95,22 @@ test_that("extract_parameter_dials doesn't error if namespaced args are used", {
NA
)
})

test_that("extract_fit_time() works", {
  lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)

  # A fresh fit yields a one-row tibble naming the model type and its
  # non-negative elapsed fit time.
  fit_time <- extract_fit_time(lm_fit)

  expect_s3_class(fit_time, "tbl_df")
  expect_named(fit_time, c("stage_id", "elapsed"))
  expect_identical(fit_time$stage_id, "linear_reg")
  expect_type(fit_time$elapsed, "double")
  expect_gte(fit_time$elapsed, 0)

  # Dropping the stored timing mimics a fit without timing information,
  # which must produce the informative error captured in the snapshot.
  lm_fit$elapsed$elapsed <- NULL

  expect_snapshot(
    error = TRUE,
    extract_fit_time(lm_fit)
  )
})

0 comments on commit fd5124c

Please sign in to comment.