diff --git a/DESCRIPTION b/DESCRIPTION index a1b30100b..d7ca644f7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.1.1.9003 +Version: 1.1.1.9007 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), @@ -76,4 +76,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3.9000 +RoxygenNote: 7.2.3 diff --git a/NAMESPACE b/NAMESPACE index cf2cff73d..2465d298d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,7 +2,6 @@ S3method(.censoring_weights_graf,default) S3method(.censoring_weights_graf,model_fit) -S3method(.censoring_weights_graf,workflow) S3method(augment,model_fit) S3method(autoplot,glmnet) S3method(autoplot,model_fit) @@ -179,6 +178,7 @@ export(bag_tree) export(bart) export(bartMachine_interval_calc) export(boost_tree) +export(case_weights_allowed) export(cforest_train) export(check_empty_ellipse) export(check_final_param) diff --git a/NEWS.md b/NEWS.md index a0e2b6dd2..a479f5c59 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * Fixed bug in fitting some model types with the `"spark"` engine (#1045). +* `.filter_eval_time()` was moved to the survival standalone file. + * Improved errors and documentation related to special terms in formulas. See `?model_formula` to learn more. (#770, #1014) * Improved errors in cases where the outcome column is mis-specified. (#1003) @@ -12,6 +14,10 @@ * When computing censoring weights, the resulting vectors are no longer named (#1023). +* Fixed a bug in the integration with workflows where using a model formula with a formula preprocessor could result in a double intercept (#1033). + +* The `predict()` method for `censoring_model_reverse_km` objects now checks that `...` are empty (#1029). + # parsnip 1.1.1 * Fixed bug where prediction on rank deficient `lm()` models produced `.pred_res` instead of `.pred`. (#985) diff --git a/R/bag_tree.R b/R/bag_tree.R index 1f8e650e9..915a13deb 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -10,6 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' +#' @inheritParams boost_tree #' @inheritParams decision_tree #' @param class_cost A non-negative scalar for a class cost (where a cost of 1 #' means no extra cost). This is useful for when the first level of the outcome diff --git a/R/bart.R b/R/bart.R index 63c8d83e5..0251680b7 100644 --- a/R/bart.R +++ b/R/bart.R @@ -11,6 +11,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' +#' @inheritParams nearest_neighbor #' @inheritParams boost_tree #' @param prior_terminal_node_coef A coefficient for the prior probability that #' a node is a terminal node. Values are usually between 0 and one with diff --git a/R/boost_tree.R b/R/boost_tree.R index 59cbf6f79..f56da2fef 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -15,8 +15,8 @@ #' \url{https://www.tidymodels.org/}. #' #' @param mode A single character string for the prediction outcome mode. -#' Possible values for this model are "unknown", "regression", or -#' "classification". +#' Possible values for this model are "unknown", "regression", +#' "classification", or "censored regression". #' @param engine A single character string specifying what computational engine #' to use for fitting. #' @param mtry A number for the number (or proportion) of predictors that will diff --git a/R/c5_rules.R b/R/c5_rules.R index c48d0382b..87ec7bb83 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -13,7 +13,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the type of model. #' The only possible value for this model is "classification". #' @param trees A non-negative integer (no greater than 100) for the number diff --git a/R/case_weights.R b/R/case_weights.R index 8c5c0a790..5709286bb 100644 --- a/R/case_weights.R +++ b/R/case_weights.R @@ -67,6 +67,17 @@ patch_formula_environment_with_case_weights <- function(formula, # ------------------------------------------------------------------------------ +#' Determine if case weights are used +#' +#' Not all modeling engines can incorporate case weights into their +#' calculations. This function can determine whether they can be used. +#' +#' @param spec A parsnip model specification. +#' @return A single logical. +#' @examples +#' case_weights_allowed(linear_reg()) +#' case_weights_allowed(linear_reg(engine = "keras")) +#' @export case_weights_allowed <- function(spec) { mod_type <- class(spec)[1] mod_eng <- spec$engine diff --git a/R/convert_data.R b/R/convert_data.R index ae4d1c426..39494bb0a 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -45,6 +45,10 @@ rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") } + if (remove_intercept) { + data <- data[, colnames(data) != "(Intercept)", drop = FALSE] + } + ## Assemble model.frame call from call arguments mf_call <- quote(model.frame(formula, data)) mf_call$na.action <- match.call()$na.action # TODO this should work better diff --git a/R/cubist_rules.R b/R/cubist_rules.R index b1eb3f237..be5c8f783 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -10,7 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". #' @param committees A non-negative integer (no greater than 100) for the number diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index 3e3b04ec2..330e6bc25 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -11,7 +11,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @inheritParams discrim_linear #' @param num_terms The number of features that will be retained in the #' final model, including the intercept. diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 1b9709b9c..469b8dc8f 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -13,7 +13,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the type of model. The only #' possible value for this model is "classification". #' @param penalty An non-negative number representing the amount of diff --git a/R/discrim_quad.R b/R/discrim_quad.R index df89f8aab..eb1ad8715 100644 --- a/R/discrim_quad.R +++ b/R/discrim_quad.R @@ -13,7 +13,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the type of model. The only #' possible value for this model is "classification". #' @param regularization_method A character string for the type of regularized diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index c27630d4d..776369627 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -13,7 +13,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @inheritParams discrim_linear #' @param frac_common_cov,frac_identity Numeric values between zero and one. #' diff --git a/R/fit_helpers.R b/R/fit_helpers.R index a54557ce3..494797598 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -8,6 +8,15 @@ form_form <- if (inherits(env$data, "data.frame")) { check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object) + + encoding_info <- + get_encoding(class(object)[1]) %>% + dplyr::filter(mode == object$mode, engine == object$engine) + + remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) + if (remove_intercept) { + env$data <- env$data[, colnames(env$data) != "(Intercept)", drop = FALSE] + } } # prob rewrite this as simple subset/levels diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index 05e091245..8777a6b37 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -10,7 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param select_features `TRUE` or `FALSE.` If `TRUE`, the model has the #' ability to eliminate a predictor (via penalization). Increasing #' `adjust_deg_free` will increase the likelihood of removing predictors. diff --git a/R/mars.R b/R/mars.R index 3a29f0f70..86c5330ef 100644 --- a/R/mars.R +++ b/R/mars.R @@ -12,7 +12,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param num_terms The number of features that will be retained in the #' final model, including the intercept. #' @param prod_degree The highest possible interaction degree. diff --git a/R/mlp.R b/R/mlp.R index 1f134ffcb..909d982e0 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -10,6 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' +#' @inheritParams nearest_neighbor #' @inheritParams boost_tree #' @param hidden_units An integer for the number of units in the hidden model. #' @param penalty A non-negative numeric value for the amount of weight diff --git a/R/naive_Bayes.R b/R/naive_Bayes.R index 24877c32c..ffe91ae66 100644 --- a/R/naive_Bayes.R +++ b/R/naive_Bayes.R @@ -11,7 +11,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @inheritParams discrim_linear #' @param smoothness An non-negative number representing the the relative #' smoothness of the class boundary. Smaller examples result in model flexible diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 814673ef1..41b738095 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -11,7 +11,11 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @param mode A single character string for the prediction outcome mode. +#' Possible values for this model are "unknown", "regression", or +#' "classification". +#' @param engine A single character string specifying what computational engine +#' to use for fitting. #' @param neighbors A single integer for the number of neighbors #' to consider (often called `k`). For \pkg{kknn}, a value of 5 #' is used if `neighbors` is not specified. diff --git a/R/pls.R b/R/pls.R index 5b1d82b1d..491389d01 100644 --- a/R/pls.R +++ b/R/pls.R @@ -10,7 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param predictor_prop The maximum proportion of original predictors that can #' have _non-zero_ coefficients for each PLS component (via regularization). #' This value is used for all PLS components for X. diff --git a/R/poisson_reg.R b/R/poisson_reg.R index 71a1d56b0..b9e103ddd 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -10,7 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". #' @param penalty A non-negative number representing the total diff --git a/R/proportional_hazards.R b/R/proportional_hazards.R index b144ff4a3..038e19c94 100644 --- a/R/proportional_hazards.R +++ b/R/proportional_hazards.R @@ -10,7 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @inheritParams linear_reg #' @param mode A single character string for the prediction outcome mode. #' The only possible value for this model is "censored regression". diff --git a/R/rule_fit.R b/R/rule_fit.R index af6be8347..81a6ce48d 100644 --- a/R/rule_fit.R +++ b/R/rule_fit.R @@ -10,6 +10,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' +#' @inheritParams nearest_neighbor #' @inheritParams boost_tree #' @param penalty L1 regularization parameter. #' @details diff --git a/R/standalone-survival.R b/R/standalone-survival.R index df4808081..c655bbf5e 100644 --- a/R/standalone-survival.R +++ b/R/standalone-survival.R @@ -1,26 +1,34 @@ # --- # repo: tidymodels/parsnip # file: standalone-survival.R -# last-updated: 2023-06-14 +# last-updated: 2024-01-10 # license: https://unlicense.org # --- -# This file provides a portable set of helper functions for Surv objects +# This file provides a portable set of helper functions for survival analysis. +# # ## Changelog - -# 2023-02-28: -# * Initial version +# 2024-01-10 +# * .filter_eval_time() gives more informative warning. # -# 2023-05-18 -# * added time to factor conversion +# 2023-12-08 +# * move .filter_eval_time() to this file +# +# 2023-11-09 +# * make sure survival vectors are unnamed. # # 2023-06-14 # * removed time to factor conversion # -# 2023-11-09 -# * make sure survival vectors are unnamed. - +# 2023-05-18 +# * added time to factor conversion +# +# 2023-02-28: +# * Initial version +# +# ------------------------------------------------------------------------------ +# # @param surv A [survival::Surv()] object # @details # `.is_censored_right()` always returns a logical while @@ -51,17 +59,21 @@ attr(surv, "type") } -.check_cens_type <- function(surv, type = "right", fail = TRUE, call = rlang::caller_env()) { - .is_surv(surv, call = call) - obj_type <- .extract_surv_type(surv) - good_type <- all(obj_type %in% type) - if (!good_type && fail) { - c_list <- paste0("'", type, "'") - msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}") - rlang::abort(msg, call = call) +.check_cens_type <- + function(surv, + type = "right", + fail = TRUE, + call = rlang::caller_env()) { + .is_surv(surv, call = call) + obj_type <- .extract_surv_type(surv) + good_type <- all(obj_type %in% type) + if (!good_type && fail) { + c_list <- paste0("'", type, "'") + msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}") + rlang::abort(msg, call = call) + } + good_type } - good_type -} .is_censored_right <- function(surv) { .check_cens_type(surv, type = "right", fail = FALSE) @@ -88,7 +100,8 @@ .is_surv(surv) res <- surv[, "status"] un_vals <- sort(unique(res)) - event_type_to_01 <- !(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate")) + event_type_to_01 <- + !(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate")) if ( event_type_to_01 && (identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) ) { @@ -96,4 +109,64 @@ } unname(res) } + # nocov end + +# ------------------------------------------------------------------------------ + +# @param eval_time A vector of numeric time points +# @details +# `.filter_eval_time` checks the validity of the time points. +# +# @return A potentially modified vector of time points. +.filter_eval_time <- function(eval_time, fail = TRUE) { + if (!is.null(eval_time)) { + eval_time <- as.numeric(eval_time) + } + eval_time_0 <- eval_time + # will still propagate nulls: + eval_time <- eval_time[!is.na(eval_time)] + eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)] + eval_time <- unique(eval_time) + if (fail && identical(eval_time, numeric(0))) { + cli::cli_abort( + "There were no usable evaluation times (finite, non-missing, and >= 0).", + call = NULL + ) + } + if (!identical(eval_time, eval_time_0)) { + diffs <- length(eval_time_0) - length(eval_time) + + offenders <- character() + + n_na <- sum(is.na(eval_time_0)) + if (n_na > 0) { + offenders <- c(offenders, "*" = "{n_na} missing value{?s}.") + } + + n_inf <- sum(is.infinite(eval_time_0)) + if (n_inf > 0) { + offenders <- c(offenders, "*" = "{n_inf} infinite value{?s}.") + } + + n_neg <- sum(eval_time_0 < 0, na.rm = TRUE) + if (n_neg > 0) { + offenders <- c(offenders, "*" = "{n_neg} negative value{?s}.") + } + + n_dup <- diffs - n_na - n_inf - n_neg + if (n_dup > 0) { + offenders <- c(offenders, "*" = "{n_dup} duplicate value{?s}.") + } + + cli::cli_warn( + c( + "There {?was/were} {diffs} inappropriate evaluation time \\ + point{?s} that {?was/were} removed. {?It was/They were}:", + offenders + ), + call = NULL + ) + } + eval_time +} diff --git a/R/surv_reg.R b/R/surv_reg.R index 90d005d24..f5e72431d 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -11,7 +11,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the prediction outcome mode. #' The only possible value for this model is "regression". #' @param dist A character string for the probability distribution of the diff --git a/R/survival-censoring-model.R b/R/survival-censoring-model.R index f005a847c..0509882a8 100644 --- a/R/survival-censoring-model.R +++ b/R/survival-censoring-model.R @@ -62,6 +62,8 @@ predict.censoring_model <- function(object, ...) { #' @export predict.censoring_model_reverse_km <- function(object, new_data, time, as_vector = FALSE, ...) { + rlang::check_dots_empty() + rlang::check_installed("prodlim", version = "2022.10.13") rlang::check_installed("censored", version = "0.1.1.9002") diff --git a/R/survival-censoring-weights.R b/R/survival-censoring-weights.R index 28b935c0c..2b7b48df8 100644 --- a/R/survival-censoring-weights.R +++ b/R/survival-censoring-weights.R @@ -18,31 +18,6 @@ trunc_probs <- function(probs, trunc = 0.01) { probs } -.filter_eval_time <- function(eval_time, fail = TRUE) { - if (!is.null(eval_time)) { - eval_time <- as.numeric(eval_time) - } - eval_time_0 <- eval_time - # will still propagate nulls: - eval_time <- eval_time[!is.na(eval_time)] - eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)] - eval_time <- unique(eval_time) - if (fail && identical(eval_time, numeric(0))) { - rlang::abort( - "There were no usable evaluation times (finite, non-missing, and >= 0).", - call = NULL - ) - } - if (!identical(eval_time, eval_time_0)) { - diffs <- setdiff(eval_time_0, eval_time) - msg <- - cli::pluralize( - "There {?was/were} {length(diffs)} inappropriate evaluation time point{?s} that {?was/were} removed.") - rlang::warn(msg) - } - eval_time -} - # nocov start # these are tested in extratests @@ -203,19 +178,6 @@ graf_weight_time_vec <- function(surv_obj, eval_time, eps = 10^-10) { rlang::abort(msg) } - -#' @export -#' @rdname censoring_weights -.censoring_weights_graf.workflow <- function(object, - predictions, - cens_predictors = NULL, - trunc = 0.05, eps = 10^-10, ...) { - if (is.null(object$fit$fit)) { - rlang::abort("The workflow does not have a model fit object.") - } - .censoring_weights_graf(object$fit$fit, predictions, cens_predictors, trunc, eps) -} - #' @export #' @rdname censoring_weights .censoring_weights_graf.model_fit <- function(object, diff --git a/R/survival_reg.R b/R/survival_reg.R index fff9b179a..0ab4cb7a9 100644 --- a/R/survival_reg.R +++ b/R/survival_reg.R @@ -9,7 +9,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param mode A single character string for the prediction outcome mode. #' The only possible value for this model is "censored regression". #' @param dist A character string for the probability distribution of the diff --git a/R/svm_linear.R b/R/svm_linear.R index 4d3a27fa4..178368a40 100644 --- a/R/svm_linear.R +++ b/R/svm_linear.R @@ -13,7 +13,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param cost A positive number for the cost of predicting a sample within #' or on the wrong side of the margin #' @param margin A positive number for the epsilon in the SVM insensitive diff --git a/R/svm_poly.R b/R/svm_poly.R index c9bdce79f..d031a2224 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -14,7 +14,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param cost A positive number for the cost of predicting a sample within #' or on the wrong side of the margin #' @param degree A positive number for polynomial degree. diff --git a/R/svm_rbf.R b/R/svm_rbf.R index eb6cc7c60..8d2d69f8a 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -14,7 +14,7 @@ #' More information on how \pkg{parsnip} is used for modeling is at #' \url{https://www.tidymodels.org/}. #' -#' @inheritParams boost_tree +#' @inheritParams nearest_neighbor #' @param engine A single character string specifying what computational engine #' to use for fitting. Possible engines are listed below. The default for this #' model is `"kernlab"`. diff --git a/_pkgdown.yml b/_pkgdown.yml index 357aac829..78e5d56cc 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -71,6 +71,7 @@ reference: - add_rowindex - augment.model_fit - case_weights + - case_weights_allowed - descriptors - extract-parsnip - fit.model_spec diff --git a/man/bag_tree.Rd b/man/bag_tree.Rd index 98f3c2829..08d029433 100644 --- a/man/bag_tree.Rd +++ b/man/bag_tree.Rd @@ -15,16 +15,17 @@ bag_tree( } \arguments{ \item{mode}{A single character string for the prediction outcome mode. -Possible values for this model are "unknown", "regression", or -"classification".} +Possible values for this model are "unknown", "regression", +"classification", or "censored regression".} \item{cost_complexity}{A positive number for the the cost/complexity parameter (a.k.a. \code{Cp}) used by CART models (specific engines only).} -\item{tree_depth}{An integer for maximum depth of the tree.} +\item{tree_depth}{An integer for the maximum depth of the tree (i.e. number +of splits) (specific engines only).} \item{min_n}{An integer for the minimum number of data points -in a node that are required for the node to be split further.} +in a node that is required for the node to be split further.} \item{class_cost}{A non-negative scalar for a class cost (where a cost of 1 means no extra cost). This is useful for when the first level of the outcome diff --git a/man/boost_tree.Rd b/man/boost_tree.Rd index aa7ff82b8..a36a4de25 100644 --- a/man/boost_tree.Rd +++ b/man/boost_tree.Rd @@ -19,8 +19,8 @@ boost_tree( } \arguments{ \item{mode}{A single character string for the prediction outcome mode. -Possible values for this model are "unknown", "regression", or -"classification".} +Possible values for this model are "unknown", "regression", +"classification", or "censored regression".} \item{engine}{A single character string specifying what computational engine to use for fitting.} diff --git a/man/case_weights_allowed.Rd b/man/case_weights_allowed.Rd new file mode 100644 index 000000000..1c070b6b5 --- /dev/null +++ b/man/case_weights_allowed.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/case_weights.R +\name{case_weights_allowed} +\alias{case_weights_allowed} +\title{Determine if case weights are used} +\usage{ +case_weights_allowed(spec) +} +\arguments{ +\item{spec}{A parsnip model specification.} +} +\value{ +A single logical. +} +\description{ +Not all modeling engines can incorporate case weights into their +calculations. This function can determine whether they can be used. +} +\examples{ +case_weights_allowed(linear_reg()) +case_weights_allowed(linear_reg(engine = "keras")) +} diff --git a/man/censoring_weights.Rd b/man/censoring_weights.Rd index 91051e9d9..85e7907c8 100644 --- a/man/censoring_weights.Rd +++ b/man/censoring_weights.Rd @@ -4,7 +4,6 @@ \alias{censoring_weights} \alias{.censoring_weights_graf} \alias{.censoring_weights_graf.default} -\alias{.censoring_weights_graf.workflow} \alias{.censoring_weights_graf.model_fit} \title{Calculations for inverse probability of censoring weights (IPCW)} \usage{ @@ -12,15 +11,6 @@ \method{.censoring_weights_graf}{default}(object, ...) -\method{.censoring_weights_graf}{workflow}( - object, - predictions, - cens_predictors = NULL, - trunc = 0.05, - eps = 10^-10, - ... -) - \method{.censoring_weights_graf}{model_fit}( object, predictions, diff --git a/man/decision_tree.Rd b/man/decision_tree.Rd index f0ae8fce7..fbdb31742 100644 --- a/man/decision_tree.Rd +++ b/man/decision_tree.Rd @@ -14,8 +14,8 @@ decision_tree( } \arguments{ \item{mode}{A single character string for the prediction outcome mode. -Possible values for this model are "unknown", "regression", or -"classification".} +Possible values for this model are "unknown", "regression", +"classification", or "censored regression".} \item{engine}{A single character string specifying what computational engine to use for fitting.} diff --git a/man/glmnet-details.Rd b/man/glmnet-details.Rd index 4fd6a4f65..9a1c6c6fa 100644 --- a/man/glmnet-details.Rd +++ b/man/glmnet-details.Rd @@ -189,7 +189,7 @@ originally requested: ## 4 hp -0.0101 1 ## 5 drat 0 1 ## 6 wt -2.59 1 -## # … with 5 more rows +## # ℹ 5 more rows }\if{html}{\out{}} Note that there is a \code{tidy()} method for \code{glmnet} objects in the \code{broom} @@ -210,7 +210,7 @@ all_tidy_coefs ## 4 (Intercept) 4 24.7 3.89 0.347 ## 5 (Intercept) 5 26.0 3.55 0.429 ## 6 (Intercept) 6 27.2 3.23 0.497 -## # … with 634 more rows +## # ℹ 634 more rows }\if{html}{\out{}} \if{html}{\out{
}}\preformatted{length(unique(all_tidy_coefs$lambda)) diff --git a/man/rand_forest.Rd b/man/rand_forest.Rd index aaae5dddc..1ec974b0e 100644 --- a/man/rand_forest.Rd +++ b/man/rand_forest.Rd @@ -14,8 +14,8 @@ rand_forest( } \arguments{ \item{mode}{A single character string for the prediction outcome mode. -Possible values for this model are "unknown", "regression", or -"classification".} +Possible values for this model are "unknown", "regression", +"classification", or "censored regression".} \item{engine}{A single character string specifying what computational engine to use for fitting.} diff --git a/tests/testthat/_snaps/survival-censoring-weights.md b/tests/testthat/_snaps/standalone-survival.md similarity index 74% rename from tests/testthat/_snaps/survival-censoring-weights.md rename to tests/testthat/_snaps/standalone-survival.md index 1546e3fdb..4e1a67e42 100644 --- a/tests/testthat/_snaps/survival-censoring-weights.md +++ b/tests/testthat/_snaps/standalone-survival.md @@ -4,7 +4,8 @@ parsnip:::.filter_eval_time(times_duplicated) Condition Warning: - There were 0 inappropriate evaluation time points that were removed. + There were 11 inappropriate evaluation time points that were removed. They were: + * 11 duplicate values. Output [1] 0 1 2 3 4 5 6 7 8 9 10 @@ -22,7 +23,10 @@ parsnip:::.filter_eval_time(times_remove_plural) Condition Warning: - There were 3 inappropriate evaluation time points that were removed. + There were 3 inappropriate evaluation time points that were removed. They were: + * 1 missing value. + * 1 infinite value. + * 1 negative value. Output [1] 0 1 2 3 4 5 6 7 8 9 10 @@ -32,7 +36,8 @@ parsnip:::.filter_eval_time(times_remove_singular) Condition Warning: - There was 1 inappropriate evaluation time point that was removed. + There was 1 inappropriate evaluation time point that was removed. It was: + * 1 negative value. Output [1] 0 1 2 3 4 5 6 7 8 9 10 diff --git a/tests/testthat/test-standalone-survival.R b/tests/testthat/test-standalone-survival.R new file mode 100644 index 000000000..eac864921 --- /dev/null +++ b/tests/testthat/test-standalone-survival.R @@ -0,0 +1,28 @@ +test_that(".filter_eval_time()", { + times_basic <- 0:10 + expect_equal( + parsnip:::.filter_eval_time(times_basic), + times_basic + ) + + times_dont_reorder <- c(10, 1:9) + expect_equal( + parsnip:::.filter_eval_time(times_dont_reorder), + times_dont_reorder + ) + + expect_null(parsnip:::.filter_eval_time(NULL)) + + times_duplicated <- c(times_basic, times_basic) + expect_snapshot( + parsnip:::.filter_eval_time(times_duplicated) + ) + + expect_snapshot(error = TRUE, parsnip:::.filter_eval_time(-1)) + + times_remove_plural <- c(Inf, NA, -3, times_basic) + expect_snapshot(parsnip:::.filter_eval_time(times_remove_plural)) + + times_remove_singular <- c(-3, times_basic) + expect_snapshot(parsnip:::.filter_eval_time(times_remove_singular)) +}) diff --git a/tests/testthat/test-survival-censoring-weights.R b/tests/testthat/test-survival-censoring-weights.R index 3885c35a4..770b5f6a1 100644 --- a/tests/testthat/test-survival-censoring-weights.R +++ b/tests/testthat/test-survival-censoring-weights.R @@ -21,32 +21,3 @@ test_that("probability truncation via trunc_probs()", { probs ) }) - -test_that(".filter_eval_time()", { - times_basic <- 0:10 - expect_equal( - parsnip:::.filter_eval_time(times_basic), - times_basic - ) - - times_dont_reorder <- c(10, 1:9) - expect_equal( - parsnip:::.filter_eval_time(times_dont_reorder), - times_dont_reorder - ) - - expect_null(parsnip:::.filter_eval_time(NULL)) - - times_duplicated <- c(times_basic, times_basic) - expect_snapshot( - parsnip:::.filter_eval_time(times_duplicated) - ) - - expect_snapshot(error = TRUE, parsnip:::.filter_eval_time(-1)) - - times_remove_plural <- c(Inf, NA, -3, times_basic) - expect_snapshot(parsnip:::.filter_eval_time(times_remove_plural)) - - times_remove_singular <- c(-3, times_basic) - expect_snapshot(parsnip:::.filter_eval_time(times_remove_singular)) -})