From 40524aab2cd6b4605fa4b7f65c91c3b363fc16ea Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Dec 2022 12:16:39 -0800 Subject: [PATCH 01/21] ad extract_fit_time.model_fit --- NAMESPACE | 3 +++ R/extract.R | 12 ++++++++++++ R/fit.R | 10 ++++++++-- R/fit_helpers.R | 42 ++++++++++-------------------------------- R/reexports.R | 3 +++ 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 3bd487e2a..af24791fa 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(augment,model_fit) S3method(extract_fit_engine,model_fit) +S3method(extract_fit_time,model_fit) S3method(extract_parameter_dials,model_spec) S3method(extract_parameter_set_dials,model_spec) S3method(extract_spec_parsnip,model_fit) @@ -171,6 +172,7 @@ export(discrim_quad) export(discrim_regularized) export(eval_args) export(extract_fit_engine) +export(extract_fit_time) export(extract_parameter_dials) export(extract_parameter_set_dials) export(extract_spec_parsnip) @@ -320,6 +322,7 @@ importFrom(generics,varying_args) importFrom(ggplot2,autoplot) importFrom(glue,glue_collapse) importFrom(hardhat,extract_fit_engine) +importFrom(hardhat,extract_fit_time) importFrom(hardhat,extract_parameter_dials) importFrom(hardhat,extract_parameter_set_dials) importFrom(hardhat,extract_spec_parsnip) diff --git a/R/extract.R b/R/extract.R index 4a3b71c8b..19f8af3f6 100644 --- a/R/extract.R +++ b/R/extract.R @@ -14,6 +14,8 @@ #' #' - `extract_parameter_set_dials()` returns a set of dials parameter objects. #' +#' - `extract_fit_time()` returns a tibble with fit times. +#' #' @param x A parsnip `model_fit` object or a parsnip `model_spec` object. #' @param parameter A single string for the parameter ID. #' @param ... Not currently used. @@ -127,3 +129,13 @@ eval_call_info <- function(x) { extract_parameter_dials.model_spec <- function(x, parameter, ...) { extract_parameter_dials(extract_parameter_set_dials(x), parameter) } + +#' @export +#' @rdname extract-parsnip +extract_fit_time.model_fit <- function(x, ...) { + dplyr::tibble( + id = class(x$spec)[1], + part = "fit", + time = x$elapsed$elapsed[["elapsed"]] + ) +} diff --git a/R/fit.R b/R/fit.R index 6ed0ea036..9f21553dc 100644 --- a/R/fit.R +++ b/R/fit.R @@ -445,8 +445,14 @@ allow_sparse <- function(x) { #' @export print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if (!is.na(x$elapsed[["elapsed"]])) { - cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n") + if (is.null(x$elapsed$print)) { + if (!is.na(x$elapsed[["elapsed"]])) { + cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n") + } + } else { + if (x$elapsed$print) { + cat("Fit time: ", prettyunits::pretty_sec(x$elapsed$elapsed[["elapsed"]]), "\n") + } } if (inherits(x$fit, "try-error")) { diff --git a/R/fit_helpers.R b/R/fit_helpers.R index f11d47e26..d54a91da3 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -32,29 +32,18 @@ form_form <- spec = object ) - if (control$verbosity > 1L) { - elapsed <- system.time( - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - envir = env, - ... - ), - gcFirst = FALSE - ) - } else { + elapsed <- system.time( res$fit <- eval_mod( fit_call, capture = control$verbosity == 0, catch = control$catch, envir = env, ... - ) - elapsed <- list(elapsed = NA_real_) - } + ), + gcFirst = FALSE + ) res$preproc <- list(y_var = all.vars(env$formula[[2]])) - res$elapsed <- elapsed + res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L) res } @@ -90,27 +79,16 @@ xy_xy <- function(object, env, control, target = "none", ...) { res <- list(lvl = levels(env$y), spec = object) - if (control$verbosity > 1L) { - elapsed <- system.time( - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - envir = env, - ... - ), - gcFirst = FALSE - ) - } else { + elapsed <- system.time( res$fit <- eval_mod( fit_call, capture = control$verbosity == 0, catch = control$catch, envir = env, ... - ) - elapsed <- list(elapsed = NA_real_) - } + ), + gcFirst = FALSE + ) if (is.vector(env$y)) { y_name <- character(0) @@ -118,7 +96,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { y_name <- colnames(env$y) } res$preproc <- list(y_var = y_name) - res$elapsed <- elapsed + res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L) res } diff --git a/R/reexports.R b/R/reexports.R index a7d7264d9..2760c2a14 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -58,3 +58,6 @@ hardhat::frequency_weights #' @export hardhat::importance_weights +#' @importFrom hardhat extract_fit_time +#' @export +hardhat::extract_fit_time From 3dc756c8f8a99f03abceed479b931f127f1162be Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Dec 2022 13:14:47 -0800 Subject: [PATCH 02/21] use dev hardhat --- DESCRIPTION | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1fd35efef..0bc927a13 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ Imports: ggplot2, globals, glue, - hardhat (>= 1.1.0), + hardhat (>= 1.2.0.9000), lifecycle, magrittr, pillar, @@ -76,3 +76,5 @@ Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.1.9000 +Remotes: + tidymodels/hardhat From 4757dc93882a52492da0eb3025a3a49f7e5334ec Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Dec 2022 13:16:39 -0800 Subject: [PATCH 03/21] more specific dev --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 0bc927a13..e8a202fa3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -77,4 +77,4 @@ LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.1.9000 Remotes: - tidymodels/hardhat + tidymodels/hardhat@extract_fit_time From 7091618a7c65d039a43cef792840fac87782d991 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 15 Jun 2023 14:39:40 -0700 Subject: [PATCH 04/21] reknit --- man/extract-parsnip.Rd | 4 ++++ man/reexports.Rd | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index a9a8d8b60..9ce9300a8 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -6,6 +6,7 @@ \alias{extract_fit_engine.model_fit} \alias{extract_parameter_set_dials.model_spec} \alias{extract_parameter_dials.model_spec} +\alias{extract_fit_time.model_fit} \title{Extract elements of a parsnip model object} \usage{ \method{extract_spec_parsnip}{model_fit}(x, ...) @@ -15,6 +16,8 @@ \method{extract_parameter_set_dials}{model_spec}(x, ...) \method{extract_parameter_dials}{model_spec}(x, parameter, ...) + +\method{extract_fit_time}{model_fit}(x, summarize = FALSE, ...) } \arguments{ \item{x}{A parsnip \code{model_fit} object or a parsnip \code{model_spec} object.} @@ -37,6 +40,7 @@ a parsnip model fit. For example, when using \code{\link[=linear_reg]{linear_reg with the \code{"lm"} engine, this returns the underlying \code{lm} object. \item \code{extract_parameter_dials()} returns a single dials parameter object. \item \code{extract_parameter_set_dials()} returns a set of dials parameter objects. +\item \code{extract_fit_time()} returns a tibble with fit times. } } \details{ diff --git a/man/reexports.Rd b/man/reexports.Rd index b498b5b7e..f87bde459 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -18,6 +18,7 @@ \alias{tune} \alias{frequency_weights} \alias{importance_weights} +\alias{extract_fit_time} \alias{varying_args} \title{Objects exported from other packages} \keyword{internal} @@ -30,7 +31,7 @@ below to see their documentation. \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} - \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_time}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} }} From 3cff7041d74ef21576d0ced163043689dc0313b4 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 15 Jun 2023 14:39:54 -0700 Subject: [PATCH 05/21] add summarise argument to extract_fit_time.model_fit() --- R/extract.R | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/R/extract.R b/R/extract.R index 19f8af3f6..9b649fc9b 100644 --- a/R/extract.R +++ b/R/extract.R @@ -132,10 +132,13 @@ extract_parameter_dials.model_spec <- function(x, parameter, ...) { #' @export #' @rdname extract-parsnip -extract_fit_time.model_fit <- function(x, ...) { +extract_fit_time.model_fit <- function(x, summarize = FALSE, ...) { + if (summarize == TRUE) { + rlang::abort("`summarize = TRUE` is not supported for `model_fit` objects.") + } + dplyr::tibble( id = class(x$spec)[1], - part = "fit", time = x$elapsed$elapsed[["elapsed"]] ) } From cfbb47fca80e4953d1e09d9542350b0b2b01f6f3 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:09:35 -0700 Subject: [PATCH 06/21] set summarize = TRUE as default in extract_fit_time() --- R/extract.R | 11 ++++++++--- man/extract-parsnip.Rd | 6 +++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/R/extract.R b/R/extract.R index 9b649fc9b..86bf40f46 100644 --- a/R/extract.R +++ b/R/extract.R @@ -18,6 +18,9 @@ #' #' @param x A parsnip `model_fit` object or a parsnip `model_spec` object. #' @param parameter A single string for the parameter ID. +#' @param summarize A logical for whether the elapsed fit time should be +#' returned as a single row or multiple rows. Doesn't support `FALSE` for +#' parsnip models. #' @param ... Not currently used. #' @details #' Extracting the underlying engine fit can be helpful for describing the @@ -132,9 +135,11 @@ extract_parameter_dials.model_spec <- function(x, parameter, ...) { #' @export #' @rdname extract-parsnip -extract_fit_time.model_fit <- function(x, summarize = FALSE, ...) { - if (summarize == TRUE) { - rlang::abort("`summarize = TRUE` is not supported for `model_fit` objects.") +extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { + if (summarize == FALSE) { + rlang::abort( + "`summarize = FALSE` is not supported for `model_fit` objects." + ) } dplyr::tibble( diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index 9ce9300a8..56750d251 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -17,7 +17,7 @@ \method{extract_parameter_dials}{model_spec}(x, parameter, ...) -\method{extract_fit_time}{model_fit}(x, summarize = FALSE, ...) +\method{extract_fit_time}{model_fit}(x, summarize = TRUE, ...) } \arguments{ \item{x}{A parsnip \code{model_fit} object or a parsnip \code{model_spec} object.} @@ -25,6 +25,10 @@ \item{...}{Not currently used.} \item{parameter}{A single string for the parameter ID.} + +\item{summarize}{A logical for whether the elapsed fit time should be +returned as a single row or multiple rows. Doesn't support \code{FALSE} for +parsnip models.} } \value{ The extracted value from the parsnip object, \code{x}, as described in the description From 934ea932457da2b67cc5a1cb08972ec8b240027e Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:10:24 -0700 Subject: [PATCH 07/21] id -> process_id in extract_fit_time() --- R/extract.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/extract.R b/R/extract.R index 86bf40f46..acf5f697a 100644 --- a/R/extract.R +++ b/R/extract.R @@ -143,7 +143,7 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { } dplyr::tibble( - id = class(x$spec)[1], + process_id = class(x$spec)[1], time = x$elapsed$elapsed[["elapsed"]] ) } From 41c73e772561ab88e38590a168e2a23a6af70ac0 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:16:28 -0700 Subject: [PATCH 08/21] add backwards compatibility --- R/extract.R | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/R/extract.R b/R/extract.R index acf5f697a..895572005 100644 --- a/R/extract.R +++ b/R/extract.R @@ -142,8 +142,16 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { ) } + time <- x$elapsed$elapsed[["elapsed"]] + + if (is.na(time)) { + rlang::abort( + "This model was fit before `extract_fit_time()` was added." + ) + } + dplyr::tibble( process_id = class(x$spec)[1], - time = x$elapsed$elapsed[["elapsed"]] + time = time ) } From c57a823f01bf6a9fbc6d9fb64bdc4653d4870390 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:21:27 -0700 Subject: [PATCH 09/21] better error catching --- R/extract.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/extract.R b/R/extract.R index 895572005..01892e03f 100644 --- a/R/extract.R +++ b/R/extract.R @@ -142,9 +142,9 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { ) } - time <- x$elapsed$elapsed[["elapsed"]] + time <- x[["elapsed"]][["elapsed"]][["elapsed"]] - if (is.na(time)) { + if (is.na(time) || is.null(time)) { rlang::abort( "This model was fit before `extract_fit_time()` was added." ) From 075efff7cd4f9ff19bf2cdb9c7fd01b509ef7a37 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:21:43 -0700 Subject: [PATCH 10/21] add tests for extract_fit_time() --- tests/testthat/_snaps/extract.md | 18 +++--------------- tests/testthat/test-extract.R | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 tests/testthat/test-extract.R diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index ce27245af..c0d88ac29 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -1,20 +1,8 @@ -# extract parameter set from model with no loaded implementation +# extract_fit_time() works - Code - extract_parameter_set_dials(bt_mod) - Condition - Error: - ! parsnip could not locate an implementation for `bag_tree` regression model specifications. - i The parsnip extension package baguette implements support for this specification. - i Please install (if needed) and load to continue. + `summarize = FALSE` is not supported for `model_fit` objects. --- - Code - extract_parameter_dials(bt_mod, parameter = "min_n") - Condition - Error: - ! parsnip could not locate an implementation for `bag_tree` regression model specifications. - i The parsnip extension package baguette implements support for this specification. - i Please install (if needed) and load to continue. + This model was fit before `extract_fit_time()` was added. diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R new file mode 100644 index 000000000..0f6b2cd9b --- /dev/null +++ b/tests/testthat/test-extract.R @@ -0,0 +1,21 @@ +test_that("extract_fit_time() works", { + lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) + + res <- extract_fit_time(lm_fit) + + expect_true(is_tibble(res)) + expect_identical(names(res), c("process_id", "time")) + expect_identical(res$process_id, "linear_reg") + expect_true(is.double(res$time)) + expect_true(res$time >= 0) + + expect_snapshot_error( + extract_fit_time(lm_fit, summarize = FALSE) + ) + + lm_fit$elapsed$elapsed <- NULL + + expect_snapshot_error( + extract_fit_time(lm_fit) + ) +}) From 7076ea6a472befef2b2920744277b52712a14e72 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 11:29:12 -0700 Subject: [PATCH 11/21] add news bullet --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index fef606ff8..de89662b1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,7 @@ * A few censored regression helper functions were exported: `.extract_surv_status()`, `.extract_surv_time()`, and `.time_as_binary_event()` (#973). +* New `extract_fit_time()` method has been added that return the time it took to train the recipe. (#853) # parsnip 1.1.0 From 48fe212566d582efa753aeb94a173036286619bb Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 16 Jun 2023 12:12:52 -0700 Subject: [PATCH 12/21] clarify news bullet --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index de89662b1..c506df12a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,7 +8,7 @@ * A few censored regression helper functions were exported: `.extract_surv_status()`, `.extract_surv_time()`, and `.time_as_binary_event()` (#973). -* New `extract_fit_time()` method has been added that return the time it took to train the recipe. (#853) +* New `extract_fit_time()` method has been added that return the time it took to train the model. (#853) # parsnip 1.1.0 From eba98329780f3422361c0341d371be7c25907907 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Wed, 27 Mar 2024 11:30:58 -0500 Subject: [PATCH 13/21] transition `expect_snapshot_error()` -> `expect_snapshot(error = TRUE)` --- tests/testthat/_snaps/extract.md | 12 ++++++++++-- tests/testthat/test-extract.R | 6 ++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index c0d88ac29..988b415ae 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -1,8 +1,16 @@ # extract_fit_time() works - `summarize = FALSE` is not supported for `model_fit` objects. + Code + extract_fit_time(lm_fit, summarize = FALSE) + Condition + Error in `extract_fit_time()`: + ! `summarize = FALSE` is not supported for `model_fit` objects. --- - This model was fit before `extract_fit_time()` was added. + Code + extract_fit_time(lm_fit) + Condition + Error in `extract_fit_time()`: + ! This model was fit before `extract_fit_time()` was added. diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R index 0f6b2cd9b..a1aab9a23 100644 --- a/tests/testthat/test-extract.R +++ b/tests/testthat/test-extract.R @@ -9,13 +9,15 @@ test_that("extract_fit_time() works", { expect_true(is.double(res$time)) expect_true(res$time >= 0) - expect_snapshot_error( + expect_snapshot( + error = TRUE, extract_fit_time(lm_fit, summarize = FALSE) ) lm_fit$elapsed$elapsed <- NULL - expect_snapshot_error( + expect_snapshot( + error = TRUE, extract_fit_time(lm_fit) ) }) From 982e91d20915259c702a96c75043c2d156ade819 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Wed, 27 Mar 2024 15:54:12 -0500 Subject: [PATCH 14/21] `time` -> `elapsed` --- R/extract.R | 6 +++--- tests/testthat/test-extract.R | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/extract.R b/R/extract.R index 983ca6b1b..c9f3d5dd3 100644 --- a/R/extract.R +++ b/R/extract.R @@ -142,9 +142,9 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { ) } - time <- x[["elapsed"]][["elapsed"]][["elapsed"]] + elapsed <- x[["elapsed"]][["elapsed"]][["elapsed"]] - if (is.na(time) || is.null(time)) { + if (is.na(elapsed) || is.null(elapsed)) { rlang::abort( "This model was fit before `extract_fit_time()` was added." ) @@ -152,6 +152,6 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { dplyr::tibble( process_id = class(x$spec)[1], - time = time + elapsed = elapsed ) } diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R index a1aab9a23..5edde148e 100644 --- a/tests/testthat/test-extract.R +++ b/tests/testthat/test-extract.R @@ -4,10 +4,10 @@ test_that("extract_fit_time() works", { res <- extract_fit_time(lm_fit) expect_true(is_tibble(res)) - expect_identical(names(res), c("process_id", "time")) + expect_identical(names(res), c("process_id", "elapsed")) expect_identical(res$process_id, "linear_reg") - expect_true(is.double(res$time)) - expect_true(res$time >= 0) + expect_true(is.double(res$elapsed)) + expect_true(res$elapsed >= 0) expect_snapshot( error = TRUE, From 5cd5f02560285e43029ada2631bfc738a113efdd Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Wed, 27 Mar 2024 16:08:03 -0500 Subject: [PATCH 15/21] merge extract tests, generate snaps --- tests/testthat/_snaps/extract.md | 20 ++++++++++++++++++++ tests/testthat/test-extract.R | 23 ----------------------- tests/testthat/test_extract.R | 24 ++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 23 deletions(-) delete mode 100644 tests/testthat/test-extract.R diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index 988b415ae..3fe075e21 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -1,3 +1,23 @@ +# extract parameter set from model with no loaded implementation + + Code + extract_parameter_set_dials(bt_mod) + Condition + Error: + ! parsnip could not locate an implementation for `bag_tree` regression model specifications. + i The parsnip extension package baguette implements support for this specification. + i Please install (if needed) and load to continue. + +--- + + Code + extract_parameter_dials(bt_mod, parameter = "min_n") + Condition + Error: + ! parsnip could not locate an implementation for `bag_tree` regression model specifications. + i The parsnip extension package baguette implements support for this specification. + i Please install (if needed) and load to continue. + # extract_fit_time() works Code diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R deleted file mode 100644 index 5edde148e..000000000 --- a/tests/testthat/test-extract.R +++ /dev/null @@ -1,23 +0,0 @@ -test_that("extract_fit_time() works", { - lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) - - res <- extract_fit_time(lm_fit) - - expect_true(is_tibble(res)) - expect_identical(names(res), c("process_id", "elapsed")) - expect_identical(res$process_id, "linear_reg") - expect_true(is.double(res$elapsed)) - expect_true(res$elapsed >= 0) - - expect_snapshot( - error = TRUE, - extract_fit_time(lm_fit, summarize = FALSE) - ) - - lm_fit$elapsed$elapsed <- NULL - - expect_snapshot( - error = TRUE, - extract_fit_time(lm_fit) - ) -}) diff --git a/tests/testthat/test_extract.R b/tests/testthat/test_extract.R index ebd5dbc8b..90bf303c4 100644 --- a/tests/testthat/test_extract.R +++ b/tests/testthat/test_extract.R @@ -95,3 +95,27 @@ test_that("extract_parameter_dials doesn't error if namespaced args are used", { NA ) }) + +test_that("extract_fit_time() works", { + lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) + + res <- extract_fit_time(lm_fit) + + expect_true(is_tibble(res)) + expect_identical(names(res), c("process_id", "elapsed")) + expect_identical(res$process_id, "linear_reg") + expect_true(is.double(res$elapsed)) + expect_true(res$elapsed >= 0) + + expect_snapshot( + error = TRUE, + extract_fit_time(lm_fit, summarize = FALSE) + ) + + lm_fit$elapsed$elapsed <- NULL + + expect_snapshot( + error = TRUE, + extract_fit_time(lm_fit) + ) +}) From 52c39281e66cca002aa6a28d643e4571e2543e33 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 13:15:33 -0700 Subject: [PATCH 16/21] remove check for wrong input in extract_fit_time --- R/extract.R | 6 ------ tests/testthat/_snaps/extract.md | 8 -------- tests/testthat/test_extract.R | 5 ----- 3 files changed, 19 deletions(-) diff --git a/R/extract.R b/R/extract.R index c9f3d5dd3..881e25e59 100644 --- a/R/extract.R +++ b/R/extract.R @@ -136,12 +136,6 @@ extract_parameter_dials.model_spec <- function(x, parameter, ...) { #' @export #' @rdname extract-parsnip extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { - if (summarize == FALSE) { - rlang::abort( - "`summarize = FALSE` is not supported for `model_fit` objects." - ) - } - elapsed <- x[["elapsed"]][["elapsed"]][["elapsed"]] if (is.na(elapsed) || is.null(elapsed)) { diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index 3fe075e21..d4363fb70 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -20,14 +20,6 @@ # extract_fit_time() works - Code - extract_fit_time(lm_fit, summarize = FALSE) - Condition - Error in `extract_fit_time()`: - ! `summarize = FALSE` is not supported for `model_fit` objects. - ---- - Code extract_fit_time(lm_fit) Condition diff --git a/tests/testthat/test_extract.R b/tests/testthat/test_extract.R index 90bf303c4..2435243e7 100644 --- a/tests/testthat/test_extract.R +++ b/tests/testthat/test_extract.R @@ -107,11 +107,6 @@ test_that("extract_fit_time() works", { expect_true(is.double(res$elapsed)) expect_true(res$elapsed >= 0) - expect_snapshot( - error = TRUE, - extract_fit_time(lm_fit, summarize = FALSE) - ) - lm_fit$elapsed$elapsed <- NULL expect_snapshot( From 4a7bb42289a37cfcd7050996695ddf05c774ee2a Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 13:25:32 -0700 Subject: [PATCH 17/21] refactor fit time logic in print.model_fit() --- R/fit.R | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/R/fit.R b/R/fit.R index 6d12f8017..288fc64a2 100644 --- a/R/fit.R +++ b/R/fit.R @@ -453,14 +453,15 @@ allow_sparse <- function(x) { #' @export print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if (is.null(x$elapsed$print)) { - if (!is.na(x$elapsed[["elapsed"]])) { - cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n") - } - } else { - if (x$elapsed$print) { - cat("Fit time: ", prettyunits::pretty_sec(x$elapsed$elapsed[["elapsed"]]), "\n") - } + + if (is.null(x$elapsed$print) && !is.na(x$elapsed[["elapsed"]])) { + elapsed <- x$elapsed[["elapsed"]] + cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n") + } + + if (isTRUE(x$elapsed$print)) { + elapsed <- x$elapsed$elapsed[["elapsed"]] + cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n") } if (inherits(x$fit, "try-error")) { From 834a3c9a4a0a7639307893481dfe14b3ce2be72d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 13:33:13 -0700 Subject: [PATCH 18/21] process_id -> stage_id --- R/extract.R | 2 +- tests/testthat/test_extract.R | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/extract.R b/R/extract.R index 881e25e59..049186b58 100644 --- a/R/extract.R +++ b/R/extract.R @@ -145,7 +145,7 @@ extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) { } dplyr::tibble( - process_id = class(x$spec)[1], + stage_id = class(x$spec)[1], elapsed = elapsed ) } diff --git a/tests/testthat/test_extract.R b/tests/testthat/test_extract.R index 2435243e7..e48727ee9 100644 --- a/tests/testthat/test_extract.R +++ b/tests/testthat/test_extract.R @@ -102,8 +102,8 @@ test_that("extract_fit_time() works", { res <- extract_fit_time(lm_fit) expect_true(is_tibble(res)) - expect_identical(names(res), c("process_id", "elapsed")) - expect_identical(res$process_id, "linear_reg") + expect_identical(names(res), c("stage_id", "elapsed")) + expect_identical(res$stage_id, "linear_reg") expect_true(is.double(res$elapsed)) expect_true(res$elapsed >= 0) From 631ae7ecfd67752d8a3e5e5812db0608324b3a55 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 14:03:22 -0700 Subject: [PATCH 19/21] Update R/extract.R Co-authored-by: Simon P. Couch --- R/extract.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/extract.R b/R/extract.R index 049186b58..5445e0d01 100644 --- a/R/extract.R +++ b/R/extract.R @@ -14,7 +14,9 @@ #' #' - `extract_parameter_set_dials()` returns a set of dials parameter objects. #' -#' - `extract_fit_time()` returns a tibble with fit times. +#' - `extract_fit_time()` returns a tibble with fit times. The fit times correspond to +#' the time for the parsnip engine to fit and do not include other portions of the +#' elapsed time in [parsnip::fit.model_spec()]. #' #' @param x A parsnip `model_fit` object or a parsnip `model_spec` object. #' @param parameter A single string for the parameter ID. From 09fdfb7cc7e354ea4553764ed50062d3b2de370d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 14:04:59 -0700 Subject: [PATCH 20/21] redocument --- R/extract.R | 6 +++--- man/extract-parsnip.Rd | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/R/extract.R b/R/extract.R index 5445e0d01..d85f05d3f 100644 --- a/R/extract.R +++ b/R/extract.R @@ -14,9 +14,9 @@ #' #' - `extract_parameter_set_dials()` returns a set of dials parameter objects. #' -#' - `extract_fit_time()` returns a tibble with fit times. The fit times correspond to -#' the time for the parsnip engine to fit and do not include other portions of the -#' elapsed time in [parsnip::fit.model_spec()]. +#' - `extract_fit_time()` returns a tibble with fit times. The fit times +#' correspond to the time for the parsnip engine to fit and do not include +#' other portions of the elapsed time in [parsnip::fit.model_spec()]. #' #' @param x A parsnip `model_fit` object or a parsnip `model_spec` object. #' @param parameter A single string for the parameter ID. diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index eb3d1e631..f6544fbc9 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -44,7 +44,9 @@ a parsnip model fit. For example, when using \code{\link[=linear_reg]{linear_reg with the \code{"lm"} engine, this returns the underlying \code{lm} object. \item \code{extract_parameter_dials()} returns a single dials parameter object. \item \code{extract_parameter_set_dials()} returns a set of dials parameter objects. -\item \code{extract_fit_time()} returns a tibble with fit times. +\item \code{extract_fit_time()} returns a tibble with fit times. The fit times +correspond to the time for the parsnip engine to fit and do not include +other portions of the elapsed time in \code{\link[=fit.model_spec]{fit.model_spec()}}. } } \details{ From 6d76a59e30370aecbf209be36f7568849e7a9e02 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 14:08:22 -0700 Subject: [PATCH 21/21] update Remotes --- DESCRIPTION | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index e3c5b5eef..dc7494571 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ Imports: ggplot2, globals, glue, - hardhat (>= 1.2.0.9000), + hardhat (>= 1.3.1.9000), lifecycle, magrittr, pillar, @@ -78,5 +78,5 @@ Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) Remotes: - tidymodels/hardhat@extract_fit_time + tidymodels/hardhat RoxygenNote: 7.3.1