Skip to content

Commit

Permalink
Merge pull request #520 from tidymodels/all-snapshots
Browse files Browse the repository at this point in the history
All snapshots
  • Loading branch information
EmilHvitfeldt authored Oct 30, 2024
2 parents 2b0c096 + e258025 commit 8ab8903
Show file tree
Hide file tree
Showing 27 changed files with 219 additions and 13 deletions.
6 changes: 5 additions & 1 deletion R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ metric_set <- function(...) {
)) {
make_survival_metric_function(fns)
} else {
# should not be reachable
cli::cli_abort(
"{.fn validate_function_class} should have errored on unknown classes.",
.internal = TRUE
Expand Down Expand Up @@ -345,6 +346,7 @@ get_quo_label <- function(quo) {
out <- as_label(quo)

if (length(out) != 1L) {
# should not be reachable
cli::cli_abort(
"{.code as_label(quo)} resulted in a character vector of length >1.",
.internal = TRUE
Expand Down Expand Up @@ -573,7 +575,9 @@ make_survival_metric_function <- function(fns) {

validate_not_empty <- function(x, call = caller_env()) {
if (is_empty(x)) {
cli::cli_abort("At least 1 function supplied to `...`.", call = call)
cli::cli_abort(
"At least 1 function must be supplied to {.code ...}.", call = call
)
}
}

Expand Down
1 change: 1 addition & 0 deletions R/class-mcc.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ mcc_multiclass_impl <- function(C) {

check_mcc_data <- function(data) {
if (!is.double(data) && !is.matrix(data)) {
# should not be reachable
cli::cli_abort(
"{.arg data} should be a double matrix at this point.",
.internal = TRUE
Expand Down
14 changes: 6 additions & 8 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
# Column name extractors

pos_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
colnames(xtab)[[1]]
} else {
Expand All @@ -15,10 +11,6 @@ pos_val <- function(xtab, event_level) {
}

neg_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
colnames(xtab)[[2]]
} else {
Expand Down Expand Up @@ -196,13 +188,15 @@ yardstick_cov <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
Expand Down Expand Up @@ -250,13 +244,15 @@ yardstick_cor <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
Expand Down Expand Up @@ -489,13 +485,15 @@ yardstick_truth_table <- function(truth, ..., case_weights = NULL) {
abort_if_class_pred(truth)

if (!is.factor(truth)) {
# should be unreachable
cli::cli_abort("{.arg truth} must be a factor.", .internal = TRUE)
}

levels <- levels(truth)
n_levels <- length(levels)

if (n_levels < 2) {
# should be unreachable
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
.internal = TRUE
Expand Down
5 changes: 5 additions & 0 deletions R/prob-binary-thresholds.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,36 @@ binary_threshold_curve <- function(truth,
case_weights <- vec_cast(case_weights, to = double())

if (!is.factor(truth)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} must be a factor, not {.obj_type_friendly {truth}}.",
.internal = TRUE
)
}
if (length(levels(truth)) != 2L) {
# should be unreachable
cli::cli_abort(
"{.arg truth} must have two levels, not {length(levels(truth))}.",
.internal = TRUE
)
}
if (!is.numeric(estimate)) {
# should be unreachable
cli::cli_abort(
"{.arg estimate} must be numeric vector, not {.obj_type_friendly {estimate}}.",
.internal = TRUE
)
}
if (length(truth) != length(estimate)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({length(truth)}) and
{.arg estimate} ({length(estimate)}) must be the same length.",
.internal = TRUE
)
}
if (length(truth) != length(case_weights)) {
# should be unreachable
cli::cli_abort(
"{.arg truth} ({length(truth)}) and
{.arg case_weights} ({length(case_weights)}) must be the same length.",
Expand Down
1 change: 1 addition & 0 deletions R/prob-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ auc <- function(x, y, na_rm = TRUE) {
}

if (is.unsorted(x, na.rm = TRUE, strictly = FALSE)) {
# should not be reachable
cli::cli_abort(
"{.arg x} must already be in weakly increasing order.",
.internal = TRUE
Expand Down
1 change: 1 addition & 0 deletions R/prob-roc_auc.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ roc_auc_estimator_impl <- function(truth,
roc_auc_binary(truth, estimate, event_level, case_weights)
} else if (estimator == "hand_till") {
if (!is.null(case_weights)) {
# should be unreachable
cli::cli_abort(
"{.arg case_weights} should be `NULL` at this point for hand-till.",
.internal = TRUE
Expand Down
4 changes: 3 additions & 1 deletion R/template.R
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ prob_estimate_convert <- function(estimate) {
n_estimate <- ncol(estimate)

if (n_estimate == 0L) {
# should be unreachable
cli::cli_abort(
"{.arg estimate} should have errored during tidy-selection.",
.internal = TRUE
Expand Down Expand Up @@ -780,7 +781,8 @@ yardstick_eval_select <- function(expr,

if (length(out) != 1L) {
cli::cli_abort(
"{.arg arg} must select exactly 1 column from `data`, not {length(out)}.",
"{.arg {arg}} must select exactly 1 column from {.arg data},
not {length(out)}.",
call = error_call
)
}
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/aaa-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,11 @@
! Can't select columns that don't exist.
x Column `weight` doesn't exist.

# metric_set() errors on empty input

Code
metric_set()
Condition
Error in `metric_set()`:
! At least 1 function must be supplied to `...`.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/conf_mat.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@
! Can't select columns that don't exist.
x Column `not_predicted` doesn't exist.

# conf_mat() error on 1-level factor truth

Code
conf_mat(table(1, 1))
Condition
Error in `conf_mat()`:
! There must be at least 2 factors levels in the `data`.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/estimator-helpers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# get_weights() errors with wrong estimator

Code
get_weights(mtcars, "wrong")
Condition
Error in `get_weights()`:
! `estimator` type "wrong" is unknown.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/event-level.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@
i Please use the metric function argument `event_level` instead.
i The global option is being ignored entirely.

# validate_event_level() works

Code
recall(two_class_example, truth, predicted, event_level = "wrong")
Condition
Error in `recall()`:
! `event_level` must be "first" or "second".

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,11 @@
Error in `weighted_quantile()`:
! `probabilities` can't have missing values.

# work with class_pred input

Code
accuracy_vec(fct_truth, cp_estimate)
Condition
Error in `as_factor_from_class_pred()`:
! A <class_pred> input was detected, but the probably package isn't installed. Install probably to be able to convert <class_pred> to <factor>.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/num-mase.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@
Error in `mase_vec()`:
! `mae_train` must be a number or `NULL`, not the string "x".

# mase() errors if m is larger than number of observations

Code
mase(mtcars, mpg, disp, m = 100)
Condition
Error in `mase()`:
! `truth` (32) must have a length greater than `m` (100) to compute the out-of-sample naive mean absolute error.

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/prob-gain_curve.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
Error in `gain_curve()`:
! `truth` should be a factor, not a a number.

# na_rm = FALSE errors if missing values are present

Code
gain_curve_vec(df$truth, df$Class1, na_rm = FALSE)
Condition
Error in `gain_curve_vec()`:
x Missing values were detected and `na_ra = FALSE`.
i Not able to perform calculations.

# errors with class_pred input

Code
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/_snaps/prob-pr_curve.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,12 @@
Error in `pr_curve_vec()`:
! `truth` should not a <class_pred> object.

# na_rm = FALSE errors if missing values are present

Code
pr_curve_vec(df$truth, df$Class1, na_rm = FALSE)
Condition
Error in `pr_curve_vec()`:
x Missing values were detected and `na_ra = FALSE`.
i Not able to perform calculations.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/template.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# errors are thrown truth or estimate selects more than 1 column

Code
rmse(mtcars, mpg, tidyselect::starts_with("d"))
Condition
Error in `rmse()`:
! `estimate` must select exactly 1 column from `data`, not 2.

---

Code
rmse(mtcars, tidyselect::starts_with("d"), mpg)
Condition
Error in `rmse()`:
! `truth` must select exactly 1 column from `data`, not 2.

# numeric_metric_summarizer()'s errors when wrong things are passes

Code
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/validation.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@
Error:
! `estimate` should be a list, not a a double vector.

---

Code
validate_surv_truth_list_estimate(lung_surv$surv_obj[1:5, ], lung_surv$.pred)
Condition
Error:
! `truth` (5) and `estimate` (228) must be the same length.

---

Code
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,10 @@ test_that("metric_tweak and metric_set plays nicely together (#351)", {
ref
)
})

test_that("metric_set() errors on empty input", {
expect_snapshot(
error = TRUE,
metric_set()
)
})
7 changes: 7 additions & 0 deletions tests/testthat/test-conf_mat.R
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,10 @@ test_that("conf_mat()'s errors when wrong things are passes", {
)
)
})

test_that("conf_mat() error on 1-level factor truth", {
expect_snapshot(
error = TRUE,
conf_mat(table(1, 1))
)
})
6 changes: 6 additions & 0 deletions tests/testthat/test-estimator-helpers.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
test_that("get_weights() errors with wrong estimator", {
expect_snapshot(
error = TRUE,
get_weights(mtcars, "wrong")
)
})
8 changes: 8 additions & 0 deletions tests/testthat/test-event-level.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ test_that("`yardstick_event_level()` ignores option - FALSE, with a warning", {
expect_snapshot(out <- yardstick_event_level())
expect_identical(out, "first")
})

test_that("validate_event_level() works", {
expect_snapshot(
error = TRUE,
recall(two_class_example, truth, predicted, event_level = "wrong")
)
})

25 changes: 25 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,28 @@ test_that("`probabilities` must be in [0, 1]", {
test_that("`probabilities` can't be missing", {
expect_snapshot(error = TRUE, weighted_quantile(1, 1, NA))
})

test_that("work with class_pred input", {
skip_if_not_installed("probably")

cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1)
cp_estimate <- probably::as_class_pred(two_class_example$predicted, which = 2)

fct_truth <- two_class_example$truth
fct_truth[1] <- NA

fct_estimate <- two_class_example$predicted
fct_estimate[2] <- NA

local_mocked_bindings(
.package = "rlang",
detect_installed = function(pkg, ...) {
FALSE
}
)

expect_snapshot(
error = TRUE,
accuracy_vec(fct_truth, cp_estimate)
)
})
Loading

0 comments on commit 8ab8903

Please sign in to comment.