From 227ab92dfb803f5a6f42373e30cc15f9332712cd Mon Sep 17 00:00:00 2001 From: "Simon P. Couch" Date: Mon, 9 Sep 2024 11:23:43 -0500 Subject: [PATCH] use type checkers in remaining functions (#1186) --------- Co-authored-by: Emil Hvitfeldt --- R/autoplot.R | 15 +++------------ R/control_parsnip.R | 8 ++------ R/required_pkgs.R | 2 ++ R/tidy_glmnet.R | 10 ++-------- R/tune_args.R | 5 +++-- 5 files changed, 12 insertions(+), 28 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 785f0a5e4..5ace8f2e0 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -34,6 +34,9 @@ autoplot.model_fit <- function(object, ...) { #' @rdname autoplot.model_fit autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL, top_n = 3L) { + check_number_decimal(min_penalty, min = 0, max = 1) + check_number_decimal(best_penalty, min = 0, max = 1, allow_null = TRUE) + check_number_whole(top_n, min = 1, max = Inf, allow_infinite = TRUE) autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...) } @@ -87,8 +90,6 @@ top_coefs <- function(x, top_n = 5) { } autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { - check_penalty_value(min_penalty) - tidy_coefs <- map_glmnet_coefs(x) %>% dplyr::filter(penalty >= min_penalty) @@ -138,7 +139,6 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, } if (!is.null(best_penalty)) { - check_penalty_value(best_penalty) p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3) } @@ -159,13 +159,4 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, p } -check_penalty_value <- function(x) { - cl <- match.call() - arg_val <- as.character(cl$x) - if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) { - cli::cli_abort("{.arg {arg_val}} should be a single, non-negative value.") - } - invisible(x) -} - # nocov end diff --git a/R/control_parsnip.R b/R/control_parsnip.R index 52f6cbd29..bdb9aaeb1 100644 --- a/R/control_parsnip.R +++ b/R/control_parsnip.R @@ -37,12 +37,8 @@ check_control <- function(x, call = rlang::caller_env()) { and {.field catch}.", call = call ) - # based on ?is.integer - int_check <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol - if (!int_check(x$verbosity)) - cli::cli_abort("{.arg verbosity} should be an integer.", call = call) - if (!is.logical(x$catch)) - cli::cli_abort("{.arg catch} should be a logical.", call = call) + check_number_whole(x$verbosity, call = call) + check_bool(x$catch, call = call) x } diff --git a/R/required_pkgs.R b/R/required_pkgs.R index e025f7c01..56b709284 100644 --- a/R/required_pkgs.R +++ b/R/required_pkgs.R @@ -26,12 +26,14 @@ required_pkgs.model_spec <- function(x, infra = TRUE, ...) { if (is.null(x$engine)) { cli::cli_abort("Please set an engine.") } + check_bool(infra) get_pkgs(x, infra) } #' @export #' @rdname required_pkgs.model_spec required_pkgs.model_fit <- function(x, infra = TRUE, ...) { + check_bool(infra) get_pkgs(x$spec, infra) } diff --git a/R/tidy_glmnet.R b/R/tidy_glmnet.R index 59761ca68..55dd8e232 100644 --- a/R/tidy_glmnet.R +++ b/R/tidy_glmnet.R @@ -55,15 +55,9 @@ get_glmn_coefs <- function(x, penalty = 0.01) { res } -tidy_glmnet <- function(x, penalty = NULL, ...) { +tidy_glmnet <- function(x, penalty = NULL, ..., call = caller_env()) { check_installs(x$spec) load_libs(x$spec, quiet = TRUE, attach = TRUE) - if (is.null(penalty)) { - if (isTRUE(is.numeric(x$spec$args$penalty))){ - penalty <- x$spec$args$penalty - } else { - rlang::abort("Please pick a single value of `penalty`.") - } - } + check_number_decimal(penalty, min = 0, max = 1, allow_null = TRUE, call = call) get_glmn_coefs(x$fit, penalty = penalty) } diff --git a/R/tune_args.R b/R/tune_args.R index dc4a4a607..a37e68fe9 100644 --- a/R/tune_args.R +++ b/R/tune_args.R @@ -56,9 +56,10 @@ tune_tbl <- function(name = character(), source = character(), component = character(), component_id = character(), - full = FALSE) { - + full = FALSE, + call = caller_env()) { + check_bool(full, call = call) complete_id <- id[!is.na(id)] dups <- duplicated(complete_id) if (any(dups)) {