Skip to content

Commit

Permalink
add input type checkers to all steps
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Nov 7, 2024
1 parent e583e94 commit b05d073
Show file tree
Hide file tree
Showing 36 changed files with 630 additions and 46 deletions.
3 changes: 3 additions & 0 deletions R/collapse_cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ step_collapse_cart <-
id = rand_id("step_collapse_cart")) {
recipes_pkg_check(required_pkgs.step_discretize_cart())

check_number_decimal(cost_complexity, min = 0)
check_number_whole(min_n, min = 1)

add_step(
recipe,
step_collapse_cart_new(
Expand Down
5 changes: 2 additions & 3 deletions R/collapse_stringdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ step_collapse_stringdist <-
columns = NULL,
skip = FALSE,
id = rand_id("collapse_stringdist")) {
if (is.null(distance)) {
cli::cli_abort("The {.arg distance} argument must be set.")
}
check_number_decimal(distance, min = 0)
check_string(method)

add_step(
recipe,
Expand Down
4 changes: 4 additions & 0 deletions R/discretize_cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ cart_binning <- function(predictor, term, outcome, cost_complexity, tree_depth,
prep.step_discretize_cart <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_decimal(x$cost_complexity, min = 0, arg = "cost_complexity")
check_number_decimal(x$tree_depth, min = 0, arg = "tree_depth")
check_number_decimal(x$min_n, min = 0, arg = "min_n")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts)
if (isFALSE(were_weights_used)) {
Expand Down
6 changes: 6 additions & 0 deletions R/discretize_xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ xgb_binning <- function(df, outcome, predictor, sample_val, learn_rate,
prep.step_discretize_xgb <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_decimal(x$sample_val, min = 0, max = 1, arg = "sample_val")
check_number_decimal(x$learn_rate, min = 0, arg = "learn_rate")
check_number_whole(x$num_breaks, min = 0, arg = "num_breaks")
check_number_whole(x$tree_depth, min = 0, arg = "tree_depth")
check_number_whole(x$min_n, min = 0, arg = "min_n")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts)
if (isFALSE(were_weights_used) || is.null(wts)) {
Expand Down
3 changes: 3 additions & 0 deletions R/embed.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ step_embed_new <-
prep.step_embed <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_terms, min = 0, arg = "num_terms")
check_number_whole(x$hidden_units, min = 0, arg = "hidden_units")

if (length(col_names) > 0) {
check_type(training[, col_names], types = c("string", "factor", "ordered"))
y_name <- recipes_eval_select(x$outcome, training, info)
Expand Down
2 changes: 2 additions & 0 deletions R/feature_hash.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ step_feature_hash_new <-
prep.step_feature_hash <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_hash, min = 0, arg = "num_hash")

if (length(col_names) > 0) {
check_type(training[, col_names], types = c("string", "factor", "ordered"))
}
Expand Down
3 changes: 3 additions & 0 deletions R/lencode_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ step_lencode_bayes <-
if (is.null(outcome)) {
cli::cli_abort("Please list a variable in {.code outcome}.")
}

check_bool(verbose)

add_step(
recipe,
step_lencode_bayes_new(
Expand Down
7 changes: 7 additions & 0 deletions R/pca_sparse.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ step_pca_sparse <- function(recipe,
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("pca_sparse")) {
check_string(prefix)

add_step(
recipe,
step_pca_sparse_new(
Expand Down Expand Up @@ -144,6 +146,11 @@ step_pca_sparse_new <-
prep.step_pca_sparse <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_comp, min = 0, arg = "num_comp")
check_number_decimal(
x$predictor_prop, min = 0, max = 1, arg = "predictor_prop"
)

if (length(col_names) > 0 && x$num_comp > 0) {
check_type(training[, col_names], types = c("double", "integer"))

Expand Down
11 changes: 10 additions & 1 deletion R/pca_sparse_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ step_pca_sparse_bayes <- function(recipe,
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("pca_sparse_bayes")) {

check_string(prefix)

add_step(
recipe,
step_pca_sparse_bayes_new(
Expand Down Expand Up @@ -169,6 +170,14 @@ step_pca_sparse_bayes_new <-
prep.step_pca_sparse_bayes <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_comp, min = 0, arg = "num_comp")
check_number_decimal(
x$prior_slab_dispersion, min = 0, arg = "prior_slab_dispersion"
)
check_number_decimal(
x$prior_mixture_threshold, min = 0, max = 1, arg = "prior_mixture_threshold"
)

if (length(col_names) > 0) {
check_type(training[, col_names], types = c("double", "integer"))

Expand Down
4 changes: 4 additions & 0 deletions R/pca_truncated.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ step_pca_truncated <- function(recipe,
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("pca_truncated")) {
check_string(prefix)

add_step(
recipe,
step_pca_truncated_new(
Expand Down Expand Up @@ -140,6 +142,8 @@ prep.step_pca_truncated <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

check_number_whole(x$num_comp, min = 0, arg = "num_comp")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts, unsupervised = TRUE)
if (isFALSE(were_weights_used)) {
Expand Down
23 changes: 22 additions & 1 deletion R/umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ step_umap <-
keep_original_cols <- retain
}

check_string(prefix)

recipes_pkg_check(required_pkgs.step_umap())
if (is.numeric(seed) && !is.integer(seed)) {
seed <- as.integer(seed)
Expand Down Expand Up @@ -229,6 +231,14 @@ umap_fit_call <- function(obj, y = NULL) {
prep.step_umap <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_whole(x$num_comp, min = 0, arg = "num_comp")
check_number_whole(x$neighbors, min = 0, arg = "neighbors")
check_number_decimal(x$min_dist, arg = "min_dist")
check_number_decimal(x$learn_rate, min = 0, arg = "learn_rate")
check_number_whole(x$epochs, min = 0, allow_null = TRUE, arg = "epochs")
rlang::arg_match0(x$initial, initial_umap_values, arg_nm = "initial")
check_number_decimal(x$target_weight, min = 0, max = 1, arg = "target_weight")

if (length(col_names) > 0) {
if (length(x$outcome) > 0) {
y_name <- recipes_eval_select(x$outcome, training, info)
Expand Down Expand Up @@ -355,11 +365,22 @@ tunable.step_umap <- function(x, ...) {
list(pkg = "dials", fun = "min_dist", range = c(-4, -0.69897)),
list(pkg = "dials", fun = "learn_rate"),
list(pkg = "dials", fun = "epochs", range = c(100, 700)),
list(pkg = "dials", fun = "initial_umap", values = c("spectral", "normlaplacian", "random", "lvrandom", "laplacian", "pca", "spca", "agspectral")),
list(pkg = "dials", fun = "initial_umap", values = initial_umap_values),
list(pkg = "dials", fun = "target_weight", range = c(0, 1))
),
source = "recipe",
component = "step_umap",
component_id = x$id
)
}

# Initialization methods accepted for the `initial` argument of step_umap().
# Kept in one place so that the prep-time argument check and the tunable
# parameter definition refer to the same set of values.
initial_umap_values <- c(
  "spectral", "normlaplacian", "random", "lvrandom",
  "laplacian", "pca", "spca", "agspectral"
)
4 changes: 4 additions & 0 deletions R/woe.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ step_woe <- function(recipe,
cli::cli_abort("The {.arg outcome} argument is missing, with no default.")
}

check_string(prefix)

add_step(
recipe,
step_woe_new(
Expand Down Expand Up @@ -423,6 +425,8 @@ add_woe <- function(.data, outcome, ..., dictionary = NULL, prefix = "woe") {
prep.step_woe <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

check_number_decimal(x$Laplace, arg = "Laplace")

if (length(col_names) > 0) {
outcome_name <- recipes_eval_select(x$outcome, training, info)

Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/collapse_cart.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# bad args

Code
recipe(~., data = mtcars) %>% step_collapse_cart(cost_complexity = -4)
Condition
Error in `step_collapse_cart()`:
! `cost_complexity` must be a number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_collapse_cart(min_n = -4)
Condition
Error in `step_collapse_cart()`:
! `min_n` must be a whole number larger than or equal to 1, not the number -4.

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/collapse_stringdist.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# bad args

Code
recipe(~., data = mtcars) %>% step_collapse_stringdist(cost_complexity = -4)
Condition
Error in `step_collapse_stringdist()`:
! `distance` must be a number, not `NULL`.

---

Code
recipe(~., data = mtcars) %>% step_collapse_stringdist(min_n = -4)
Condition
Error in `step_collapse_stringdist()`:
! `distance` must be a number, not `NULL`.

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/_snaps/discretize_cart.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@
x All columns selected for the step should be double or integer.
* 1 factor variable found: `w`

---

Code
recipe(~., data = mtcars) %>% step_discretize_cart(outcome = vars("mpg"),
cost_complexity = -4) %>% prep()
Condition
Error in `step_discretize_cart()`:
Caused by error in `prep()`:
! `cost_complexity` must be a number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_cart(outcome = vars("mpg"),
min_n = -4) %>% prep()
Condition
Error in `step_discretize_cart()`:
Caused by error in `prep()`:
! `min_n` must be a number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_cart(outcome = vars("mpg"),
tree_depth = -4) %>% prep()
Condition
Error in `step_discretize_cart()`:
Caused by error in `prep()`:
! `tree_depth` must be a number larger than or equal to 0, not the number -4.

# tidy method

Code
Expand Down
50 changes: 50 additions & 0 deletions tests/testthat/_snaps/discretize_xgb.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,56 @@
-- Operations
* Discretizing variables using xgboost: x and z | Trained, weighted

# bad args

Code
recipe(~., data = mtcars) %>% step_discretize_xgb(outcome = "class",
sample_val = -4) %>% prep()
Condition
Error in `step_discretize_xgb()`:
Caused by error in `prep()`:
! `sample_val` must be a number between 0 and 1, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_xgb(outcome = "class",
learn_rate = -4) %>% prep()
Condition
Error in `step_discretize_xgb()`:
Caused by error in `prep()`:
! `learn_rate` must be a number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_xgb(outcome = "class",
num_breaks = -4) %>% prep()
Condition
Error in `step_discretize_xgb()`:
Caused by error in `prep()`:
! `num_breaks` must be a whole number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_xgb(outcome = "class",
tree_depth = -4) %>% prep()
Condition
Error in `step_discretize_xgb()`:
Caused by error in `prep()`:
! `tree_depth` must be a whole number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_discretize_xgb(outcome = "class", min_n = -4) %>%
prep()
Condition
Error in `step_discretize_xgb()`:
Caused by error in `prep()`:
! `min_n` must be a whole number larger than or equal to 0, not the number -4.

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/embed.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@
x All columns selected for the step should be string, factor, or ordered.
* 1 double variable found: `Sepal.Length`

---

Code
recipe(~., data = mtcars) %>% step_embed(outcome = vars(mpg), num_terms = -4) %>%
prep()
Condition
Error in `step_embed()`:
Caused by error in `prep()`:
! `num_terms` must be a whole number larger than or equal to 0, not the number -4.

---

Code
recipe(~., data = mtcars) %>% step_embed(outcome = vars(mpg), hidden_units = -4) %>%
prep()
Condition
Error in `step_embed()`:
Caused by error in `prep()`:
! `hidden_units` must be a whole number larger than or equal to 0, not the number -4.

# check_name() is used

Code
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/_snaps/feature_hash.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@
! Name collision occurred. The following variable names already exist:
* `x3_hash_01`

# bad args

Code
recipe(~., data = mtcars) %>% step_feature_hash(num_hash = -4) %>% prep()
Condition
Warning:
`step_feature_hash()` was deprecated in embed 0.2.0.
i Please use `textrecipes::step_dummy_hash()` instead.
Error in `step_feature_hash()`:
Caused by error in `prep()`:
! `num_hash` must be a whole number larger than or equal to 0, not the number -4.

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/lencode_bayes.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@
-- Operations
* Linear embedding for factors via Bayesian GLM for: x3 | Trained, weighted

# bad args

Code
recipe(~., data = mtcars) %>% step_lencode_bayes(outcome = vars(mpg), verbose = "yes")
Condition
Error in `step_lencode_bayes()`:
! `verbose` must be `TRUE` or `FALSE`, not the string "yes".

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
Loading

0 comments on commit b05d073

Please sign in to comment.