Skip to content

Commit

Permalink
Merge branch 'main' into sparse-input
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored Aug 28, 2024
2 parents f2faed9 + ed86b96 commit f0f92e9
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 109 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

* `fit_xy()` can now take `dgCMatrix` input for the `x` argument (#1121).

* Transitioned package errors and warnings to use cli (#1147 and #1148 by
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
#1161).

* `fit_xy()` previously raised an error for all `gen_additive_mod()` model specifications because the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).

* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
Expand Down
4 changes: 3 additions & 1 deletion R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ update.bart <-
#' @param std_err Attach column for standard error of prediction or not.
bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) {
if (obj$spec$mode == "classification") {
rlang::abort("In bartMachine: Prediction intervals are not possible for classification")
cli::cli_abort(
"Prediction intervals are not possible for classification"
)
}
get_std_err <- obj$spec$method$pred$pred_int$extras$std_error

Expand Down
58 changes: 34 additions & 24 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,9 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {

if (engine == "spark") {
if (x$mode == "unknown") {
rlang::abort(
glue::glue(
"For spark boosted trees models, the mode cannot be 'unknown' ",
"if the specification is to be translated."
)
cli::cli_abort(
"For spark boosted tree models, the mode cannot be {.val unknown}
if the specification is to be translated."
)
} else {
arg_vals$type <- x$mode
Expand Down Expand Up @@ -172,7 +170,7 @@ check_args.boost_tree <- function(object, call = rlang::caller_env()) {
check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")

invisible(object)
}

Expand Down Expand Up @@ -229,15 +227,15 @@ xgb_train <- function(
num_class <- length(levels(y))

if (!is.numeric(validation) || validation < 0 || validation >= 1) {
rlang::abort("`validation` should be on [0, 1).")
cli::cli_abort("{.arg validation} should be on [0, 1).")
}

if (!is.null(early_stop)) {
if (early_stop <= 1) {
rlang::abort(paste0("`early_stop` should be on [2, ", nrounds, ")."))
cli::cli_abort("{.arg early_stop} should be on [2, {nrounds}).")
} else if (early_stop >= nrounds) {
early_stop <- nrounds - 1
rlang::warn(paste0("`early_stop` was reduced to ", early_stop, "."))
cli::cli_warn("{.arg early_stop} was reduced to {early_stop}.")
}
}

Expand All @@ -252,7 +250,7 @@ xgb_train <- function(


if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
rlang::abort("`subsample` should be on [0, 1].")
cli::cli_abort("{.arg subsample} should be on [0, 1].")
}

# initialize
Expand All @@ -268,9 +266,13 @@ xgb_train <- function(
}

if (min_child_weight > n) {
msg <- paste0(min_child_weight, " samples were requested but there were ",
n, " rows in the data. ", n, " will be used.")
rlang::warn(msg)
cli::cli_warn(
c(
"!" = "{min_child_weight} samples were requested but there were {n} rows
in the data.",
"i" = "{n} will be used."
)
)
min_child_weight <- min(min_child_weight, n)
}

Expand Down Expand Up @@ -369,14 +371,16 @@ recalc_param <- function(x, counts, denom) {
x
}

maybe_proportion <- function(x, nm) {
maybe_proportion <- function(x, nm, call = rlang::caller_env()) {
if (x < 1) {
msg <- paste0(
"The option `counts = TRUE` was used but parameter `", nm,
"` was given as ", signif(x, 3), ". Please use a value >= 1 or use ",
"`counts = FALSE`."
cli::cli_abort(
c(
"The option `counts = TRUE` was used but {.arg {nm}} was given
as {signif(x, 3)}.",
"i" = "Please use a value >= 1 or use {.code counts = FALSE}."
),
call = call
)
rlang::abort(msg)
}
}

Expand Down Expand Up @@ -418,7 +422,9 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir
y <- as.numeric(y) - 1
}
} else {
if (event_level == "second") rlang::warn("`event_level` can only be set for binary variables.")
if (event_level == "second") {
cli::cli_warn("{.arg event_level} can only be set for binary outcomes.")
}
y <- as.numeric(y) - 1
}
}
Expand Down Expand Up @@ -573,15 +579,19 @@ C5.0_train <-

n <- nrow(x)
if (n == 0) {
rlang::abort("There are zero rows in the predictor set.")
cli::cli_abort("There are zero rows in the predictor set.")
}


ctrl <- call2("C5.0Control", .ns = "C50")
if (minCases > n) {
msg <- paste0(minCases, " samples were requested but there were ",
n, " rows in the data. ", n, " will be used.")
rlang::warn(msg)

cli::cli_warn(
c(
"!" = "{minCases} samples were requested but there were {n} rows in the data.",
"i" = "{n} will be used."
)
)
minCases <- n
}
ctrl$minCases <- minCases
Expand Down
8 changes: 3 additions & 5 deletions R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,9 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {

if (x$engine == "spark") {
if (x$mode == "unknown") {
rlang::abort(
glue::glue(
"For spark decision tree models, the mode cannot be 'unknown' ",
"if the specification is to be translated."
)
cli::cli_abort(
"For spark decision tree models, the mode cannot be {.val unknown}
if the specification is to be translated."
)
}
}
Expand Down
18 changes: 9 additions & 9 deletions R/nearest_neighbor_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ set_pred(
# model mode
pre = function(x, object) {
if (object$fit$response != "continuous") {
rlang::abort(
glue::glue("`kknn` model does not appear to use numeric predictions.",
" Was the model fit with a continuous response variable?")
cli::cli_abort(
c("`kknn` model does not appear to use numeric predictions.",
"i" = "Was the model fit with a continuous response variable?")
)
}
x
Expand Down Expand Up @@ -136,9 +136,9 @@ set_pred(
value = list(
pre = function(x, object) {
if (!(object$fit$response %in% c("ordinal", "nominal"))) {
rlang::abort(
glue::glue("`kknn` model does not appear to use class predictions.",
" Was the model fit with a factor response variable?")
cli::cli_abort(
c("`kknn` model does not appear to use class predictions.",
"i" = "Was the model fit with a factor response variable?")
)
}
x
Expand All @@ -162,9 +162,9 @@ set_pred(
value = list(
pre = function(x, object) {
if (!(object$fit$response %in% c("ordinal", "nominal"))) {
rlang::abort(
glue::glue("`kknn` model does not appear to use class predictions.",
" Was the model fit with a factor response variable?")
cli::cli_abort(
c("`kknn` model does not appear to use class predictions.",
"i" = "Was the model fit with a factor response variable?")
)
}
x
Expand Down
Loading

0 comments on commit f0f92e9

Please sign in to comment.