diff --git a/DESCRIPTION b/DESCRIPTION index be3ef41bf..dc7494571 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ Imports: ggplot2, globals, glue, - hardhat (>= 1.1.0), + hardhat (>= 1.3.1.9000), lifecycle, magrittr, pillar, @@ -77,4 +77,6 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) +Remotes: + tidymodels/hardhat RoxygenNote: 7.3.1 diff --git a/NAMESPACE b/NAMESPACE index fa9511099..e37fbcbac 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -29,6 +29,7 @@ S3method(check_args,svm_linear) S3method(check_args,svm_poly) S3method(check_args,svm_rbf) S3method(extract_fit_engine,model_fit) +S3method(extract_fit_time,model_fit) S3method(extract_parameter_dials,model_spec) S3method(extract_parameter_set_dials,model_spec) S3method(extract_spec_parsnip,model_fit) @@ -222,6 +223,7 @@ export(discrim_quad) export(discrim_regularized) export(eval_args) export(extract_fit_engine) +export(extract_fit_time) export(extract_parameter_dials) export(extract_parameter_set_dials) export(extract_spec_parsnip) @@ -375,6 +377,7 @@ importFrom(generics,varying_args) importFrom(ggplot2,autoplot) importFrom(glue,glue_collapse) importFrom(hardhat,extract_fit_engine) +importFrom(hardhat,extract_fit_time) importFrom(hardhat,extract_parameter_dials) importFrom(hardhat,extract_parameter_set_dials) importFrom(hardhat,extract_spec_parsnip) diff --git a/NEWS.md b/NEWS.md index 00f246bdf..1761217f9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). + # parsnip 1.2.1 * Added a missing `tidy()` method for survival analysis glmnet models (#1086). diff --git a/R/extract.R b/R/extract.R index 4dbe4b0b7..d85f05d3f 100644 --- a/R/extract.R +++ b/R/extract.R @@ -14,8 +14,15 @@ #' #' - `extract_parameter_set_dials()` returns a set of dials parameter objects. #' +#' - `extract_fit_time()` returns a tibble with fit times. The fit times +#' correspond to the time for the parsnip engine to fit and do not include +#' other portions of the elapsed time in [parsnip::fit.model_spec()]. +#' #' @param x A parsnip `model_fit` object or a parsnip `model_spec` object. #' @param parameter A single string for the parameter ID. +#' @param summarize A logical for whether the elapsed fit time should be +#' returned as a single row or multiple rows. Doesn't support `FALSE` for +#' parsnip models. #' @param ... Not currently used. #' @details #' Extracting the underlying engine fit can be helpful for describing the @@ -127,3 +134,20 @@ eval_call_info <- function(x) { extract_parameter_dials.model_spec <- function(x, parameter, ...) { extract_parameter_dials(extract_parameter_set_dials(x), parameter) } + +#' @export +#' @rdname extract-parsnip +extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { + elapsed <- x[["elapsed"]][["elapsed"]][["elapsed"]] + + if (is.na(elapsed) || is.null(elapsed)) { + rlang::abort( + "This model was fit before `extract_fit_time()` was added." + ) + } + + dplyr::tibble( + stage_id = class(x$spec)[1], + elapsed = elapsed + ) +} diff --git a/R/fit.R b/R/fit.R index 21dc461b7..288fc64a2 100644 --- a/R/fit.R +++ b/R/fit.R @@ -453,8 +453,15 @@ allow_sparse <- function(x) { #' @export print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if (!is.na(x$elapsed[["elapsed"]])) { - cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n") + + if (is.null(x$elapsed$print) && !is.na(x$elapsed[["elapsed"]])) { + elapsed <- x$elapsed[["elapsed"]] + cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n") + } + + if (isTRUE(x$elapsed$print)) { + elapsed <- x$elapsed$elapsed[["elapsed"]] + cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n") } if (inherits(x$fit, "try-error")) { diff --git a/R/fit_helpers.R b/R/fit_helpers.R index ec4ddf426..29002a41d 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -44,29 +44,19 @@ form_form <- spec = object ) - if (control$verbosity > 1L) { - elapsed <- system.time( - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - envir = env, - ... - ), - gcFirst = FALSE - ) - } else { + elapsed <- system.time( res$fit <- eval_mod( fit_call, capture = control$verbosity == 0, catch = control$catch, envir = env, ... - ) - elapsed <- list(elapsed = NA_real_) - } + ), + gcFirst = FALSE + ) res$preproc <- list(y_var = all.vars(rlang::f_lhs(env$formula))) - res$elapsed <- elapsed + res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L) + res } @@ -102,27 +92,16 @@ xy_xy <- function(object, env, control, target = "none", ...) { res <- list(lvl = levels(env$y), spec = object) - if (control$verbosity > 1L) { - elapsed <- system.time( - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - envir = env, - ... - ), - gcFirst = FALSE - ) - } else { + elapsed <- system.time( res$fit <- eval_mod( fit_call, capture = control$verbosity == 0, catch = control$catch, envir = env, ... - ) - elapsed <- list(elapsed = NA_real_) - } + ), + gcFirst = FALSE + ) if (is.atomic(env$y)) { y_name <- character(0) @@ -130,7 +109,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { y_name <- colnames(env$y) } res$preproc <- list(y_var = y_name) - res$elapsed <- elapsed + res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L) res } @@ -176,9 +155,9 @@ xy_form <- function(object, env, control, ...) { check_outcome(env$y, object) encoding_info <- get_encoding(class(object)[1]) - encoding_info <- + encoding_info <- vctrs::vec_slice( - encoding_info, + encoding_info, encoding_info$mode == object$mode & encoding_info$engine == object$engine ) diff --git a/R/reexports.R b/R/reexports.R index a7d7264d9..2760c2a14 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -58,3 +58,6 @@ hardhat::frequency_weights #' @export hardhat::importance_weights +#' @importFrom hardhat extract_fit_time +#' @export +hardhat::extract_fit_time diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index f926c7848..f6544fbc9 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -6,6 +6,7 @@ \alias{extract_fit_engine.model_fit} \alias{extract_parameter_set_dials.model_spec} \alias{extract_parameter_dials.model_spec} +\alias{extract_fit_time.model_fit} \title{Extract elements of a parsnip model object} \usage{ \method{extract_spec_parsnip}{model_fit}(x, ...) @@ -15,6 +16,8 @@ \method{extract_parameter_set_dials}{model_spec}(x, ...) \method{extract_parameter_dials}{model_spec}(x, parameter, ...) + +\method{extract_fit_time}{model_fit}(x, summarize = TRUE, ...) } \arguments{ \item{x}{A parsnip \code{model_fit} object or a parsnip \code{model_spec} object.} @@ -22,6 +25,10 @@ \item{...}{Not currently used.} \item{parameter}{A single string for the parameter ID.} + +\item{summarize}{A logical for whether the elapsed fit time should be +returned as a single row or multiple rows. Doesn't support \code{FALSE} for +parsnip models.} } \value{ The extracted value from the parsnip object, \code{x}, as described in the description @@ -37,6 +44,9 @@ a parsnip model fit. For example, when using \code{\link[=linear_reg]{linear_reg with the \code{"lm"} engine, this returns the underlying \code{lm} object. \item \code{extract_parameter_dials()} returns a single dials parameter object. \item \code{extract_parameter_set_dials()} returns a set of dials parameter objects. +\item \code{extract_fit_time()} returns a tibble with fit times. The fit times +correspond to the time for the parsnip engine to fit and do not include +other portions of the elapsed time in \code{\link[=fit.model_spec]{fit.model_spec()}}. } } \details{ diff --git a/man/reexports.Rd b/man/reexports.Rd index b498b5b7e..f87bde459 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -18,6 +18,7 @@ \alias{tune} \alias{frequency_weights} \alias{importance_weights} +\alias{extract_fit_time} \alias{varying_args} \title{Objects exported from other packages} \keyword{internal} @@ -30,7 +31,7 @@ below to see their documentation. \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} - \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_time}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} }} diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index ce27245af..d4363fb70 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -18,3 +18,11 @@ i The parsnip extension package baguette implements support for this specification. i Please install (if needed) and load to continue. +# extract_fit_time() works + + Code + extract_fit_time(lm_fit) + Condition + Error in `extract_fit_time()`: + ! This model was fit before `extract_fit_time()` was added. + diff --git a/tests/testthat/test_extract.R b/tests/testthat/test_extract.R index ebd5dbc8b..e48727ee9 100644 --- a/tests/testthat/test_extract.R +++ b/tests/testthat/test_extract.R @@ -95,3 +95,22 @@ test_that("extract_parameter_dials doesn't error if namespaced args are used", { NA ) }) + +test_that("extract_fit_time() works", { + lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) + + res <- extract_fit_time(lm_fit) + + expect_true(is_tibble(res)) + expect_identical(names(res), c("stage_id", "elapsed")) + expect_identical(res$stage_id, "linear_reg") + expect_true(is.double(res$elapsed)) + expect_true(res$elapsed >= 0) + + lm_fit$elapsed$elapsed <- NULL + + expect_snapshot( + error = TRUE, + extract_fit_time(lm_fit) + ) +})