Skip to content

Commit

Permalink
Merge pull request #457 from tidymodels/cli-phase-two
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Nov 15, 2023
2 parents ec51a38 + 3a82666 commit 00ea962
Show file tree
Hide file tree
Showing 39 changed files with 190 additions and 149 deletions.
8 changes: 6 additions & 2 deletions R/class-kap.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,19 @@ make_weighting_matrix <- function(weighting, n_levels, call = caller_env()) {

validate_weighting <- function(x, call = caller_env()) {
if (!is_string(x)) {
abort("`weighting` must be a string.", call = call)
cli::cli_abort("{.arg weighting} must be a string.", call = call)
}

ok <- is_no_weighting(x) ||
is_linear_weighting(x) ||
is_quadratic_weighting(x)

if (!ok) {
abort("`weighting` must be 'none', 'linear', or 'quadratic'.", call = call)
cli::cli_abort(
"{.arg weighting} must be {.val none}, {.val linear}, or \\
{.val quadratic}, not {.val {x}}.",
call = call
)
}

invisible(x)
Expand Down
5 changes: 4 additions & 1 deletion R/class-mcc.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ mcc_multiclass_impl <- function(C) {

check_mcc_data <- function(data) {
if (!is.double(data) && !is.matrix(data)) {
abort("`data` should be a double matrix at this point.", .internal = TRUE)
cli::cli_abort(
"{.arg data} should be a double matrix at this point.",
.internal = TRUE
)
}
invisible()
}
9 changes: 7 additions & 2 deletions R/conf_mat.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ conf_mat_impl <- function(truth, estimate, case_weights, call = caller_env()) {
check_class_metric(truth, estimate, case_weights, estimator, call = call)

if (length(levels(truth)) < 2) {
abort("`truth` must have at least 2 factor levels.", call = call)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
call = call
)
}

yardstick_table(
Expand All @@ -245,7 +248,9 @@ conf_mat.table <- function(data, ...) {
num_lev <- length(class_lev)

if (num_lev < 2) {
abort("There must be at least 2 factors levels in the `data`")
cli::cli_abort(
"There must be at least 2 factors levels in the {.arg data}."
)
}

structure(
Expand Down
79 changes: 57 additions & 22 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

pos_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
abort("Only relevant for 2x2 tables")
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
Expand All @@ -16,7 +16,7 @@ pos_val <- function(xtab, event_level) {

neg_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
abort("Only relevant for 2x2 tables")
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
Expand Down Expand Up @@ -67,19 +67,19 @@ as_factor_from_class_pred <- function(x) {
}

if (!is_installed("probably")) {
abort(paste0(
"A <class_pred> input was detected, but the probably package ",
"isn't installed. Install probably to be able to convert <class_pred> ",
"to <factor>."
))
cli::cli_abort(
"A {.cls class_pred} input was detected, but the {.pkg probably} \\
package isn't installed. Install {.pkg probably} to be able to convert \\
{.cls class_pred} to {.cls factor}."
)
}
probably::as.factor(x)
}

abort_if_class_pred <- function(x, call = caller_env()) {
if (is_class_pred(x)) {
abort(
"`truth` should not a `class_pred` object.",
cli::cli_abort(
"{.arg truth} should not a {.cls class_pred} object.",
call = call
)
}
Expand Down Expand Up @@ -186,10 +186,18 @@ yardstick_cov <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
abort("`truth` and `estimate` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
abort("`truth` and `case_weights` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
.internal = TRUE
)
}

if (size == 0L || size == 1L) {
Expand Down Expand Up @@ -232,10 +240,18 @@ yardstick_cor <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
abort("`truth` and `estimate` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
abort("`truth` and `case_weights` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
.internal = TRUE
)
}

if (size == 0L || size == 1L) {
Expand Down Expand Up @@ -345,14 +361,17 @@ weighted_quantile <- function(x, weights, probabilities) {

size <- vec_size(x)
if (size != vec_size(weights)) {
abort("`x` and `weights` must have the same size.")
cli::cli_abort(
"{.arg x} ({vec_size(x)}) and {.arg weights} ({vec_size(weights)}) \\
must have the same size."
)
}

if (any(is.na(probabilities))) {
abort("`probabilities` can't be missing.")
cli::cli_abort("{.arg probabilities} can't have missing values.")
}
if (any(probabilities > 1 | probabilities < 0)) {
abort("`probabilities` must be within `[0, 1]`.")
cli::cli_abort("{.arg probabilities} must be within `[0, 1]`.")
}

if (size == 0L) {
Expand Down Expand Up @@ -397,20 +416,33 @@ yardstick_table <- function(truth, estimate, ..., case_weights = NULL) {
}

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must be a factor, not {.obj_type_friendly {truth}}.",
.internal = TRUE
)
}
if (!is.factor(estimate)) {
abort("`estimate` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg estimate} must be a factor, not {.obj_type_friendly {estimate}}.",
.internal = TRUE
)
}

levels <- levels(truth)
n_levels <- length(levels)

if (!identical(levels, levels(estimate))) {
abort("`truth` and `estimate` must have the same levels in the same order.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} and {.arg estimate} must have the same levels in the same \\
order.",
.internal = TRUE
)
}
if (n_levels < 2) {
abort("`truth` must have at least 2 factor levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
.internal = TRUE
)
}

# Supply `estimate` first to get it to correspond to the row names.
Expand Down Expand Up @@ -447,14 +479,17 @@ yardstick_truth_table <- function(truth, ..., case_weights = NULL) {
abort_if_class_pred(truth)

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort("{.arg truth} must be a factor.", .internal = TRUE)
}

levels <- levels(truth)
n_levels <- length(levels)

if (n_levels < 2) {
abort("`truth` must have at least 2 factor levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
.internal = TRUE
)
}

# Always return a double matrix for type stability
Expand Down
7 changes: 1 addition & 6 deletions R/num-huber_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,7 @@ huber_loss_impl <- function(truth,
# Weighted Huber Loss implementation confirmed against matlab:
# https://www.mathworks.com/help/deeplearning/ref/dlarray.huber.html

if (!is_bare_numeric(delta, n = 1L)) {
abort("`delta` must be a single numeric value.", call = call)
}
if (!(delta >= 0)) {
abort("`delta` must be a positive value.", call = call)
}
check_number_decimal(delta, min = 0, call = call)

a <- truth - estimate
abs_a <- abs(a)
Expand Down
48 changes: 7 additions & 41 deletions R/num-mase.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ mase_impl <- function(truth,
mae_train = NULL,
case_weights = NULL,
call = caller_env()) {
validate_m(m, call = call)
validate_mae_train(mae_train, call = call)
check_number_whole(m, min = 0, call = call)
check_number_decimal(mae_train, min = 0, allow_null = TRUE, call = call)

if (is.null(mae_train)) {
validate_truth_m(truth, m, call = call)
Expand All @@ -139,46 +139,12 @@ mase_impl <- function(truth,
out
}

validate_m <- function(m, call = caller_env()) {
abort_msg <- "`m` must be a single positive integer value."

if (!is_integerish(m, n = 1L)) {
abort(abort_msg, call = call)
}

if (!(m > 0)) {
abort(abort_msg, call = call)
}

invisible(m)
}

validate_mae_train <- function(mae_train, call = caller_env()) {
if (is.null(mae_train)) {
return(invisible(mae_train))
}

is_single_numeric <- is_bare_numeric(mae_train, n = 1L)
abort_msg <- "`mae_train` must be a single positive numeric value."

if (!is_single_numeric) {
abort(abort_msg, call = call)
}

if (!(mae_train > 0)) {
abort(abort_msg, call = call)
}

invisible(mae_train)
}

validate_truth_m <- function(truth, m, call = caller_env()) {
if (length(truth) <= m) {
abort(paste0(
"`truth` must have a length greater than `m` ",
"to compute the out-of-sample naive mean absolute error."
), call = call)
cli::cli_abort(
"{.arg truth} ({length(truth)}) must have a length greater than \\
{.arg m} ({m}) to compute the out-of-sample naive mean absolute error.",
call = call
)
}

invisible(truth)
}
7 changes: 1 addition & 6 deletions R/num-pseudo_huber_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,7 @@ huber_loss_pseudo_impl <- function(truth,
delta,
case_weights,
call = caller_env()) {
if (!is_bare_numeric(delta, n = 1L)) {
abort("`delta` must be a single numeric value.", call = call)
}
if (!(delta >= 0)) {
abort("`delta` must be a positive value.", call = call)
}
check_number_decimal(delta, min = 0, call = call)

a <- truth - estimate
loss <- delta^2 * (sqrt(1 + (a / delta)^2) - 1)
Expand Down
27 changes: 22 additions & 5 deletions R/prob-binary-thresholds.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,36 @@ binary_threshold_curve <- function(truth,
case_weights <- vec_cast(case_weights, to = double())

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must be a factor, not {.obj_friendly_type {truth}}.",
.internal = TRUE
)
}
if (length(levels(truth)) != 2L) {
abort("`truth` must have two levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have two levels, not {length(levels(truth))}.",
.internal = TRUE
)
}
if (!is.numeric(estimate)) {
abort("`estimate` must be numeric.", .internal = TRUE)
cli::cli_abort(
"{.arg estimate} must be numeric, not {.obj_friendly_type {estimate}}.",
.internal = TRUE
)
}
if (length(truth) != length(estimate)) {
abort("`truth` and `estimate` must be the same length.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({length(truth)}) and \\
{.arg estimate} ({length(estimate)}) must be the same length.",
.internal = TRUE
)
}
if (length(truth) != length(case_weights)) {
abort("`truth` and `case_weights` must be the same length.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({length(truth)}) and \\
{.arg case_weights} ({length(case_weights)}) must be the same length.",
.internal = TRUE
)
}

truth <- unclass(truth)
Expand Down
Loading

0 comments on commit 00ea962

Please sign in to comment.