Skip to content

Commit

Permalink
Activation checks (#1065)
Browse files Browse the repository at this point in the history
* remove mlp activation check for #1019

* unit test for #1019

* update news

* version requirement in skip

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
  • Loading branch information
topepo and ‘topepo’ authored Feb 14, 2024
1 parent 25fc2c6 commit 525baee
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* parsnip now lets the engines for [mlp()] check for acceptable values of the activation function (#1019)

* Tightened logic for outcome checking. This resolves issues—some errors and some silent failures—when atomic outcome variables have an attribute (#1060, #1061).

* `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).
Expand Down
25 changes: 5 additions & 20 deletions R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#' @param activation A single character string denoting the type of relationship
#' between the original predictors and the hidden unit layer. The activation
#' function between the hidden and output layers is automatically set to either
#' "linear" or "softmax" depending on the type of outcome. Possible values are:
#' "linear", "softmax", "relu", and "elu"
#' "linear" or "softmax" depending on the type of outcome. Possible values
#' depend on the engine being used.
#'
#' @templateVar modeltype mlp
#' @template spec-details
Expand Down Expand Up @@ -142,24 +142,6 @@ check_args.mlp <- function(object) {
if (args$dropout > 0 & args$penalty > 0)
rlang::abort("Both weight decay and dropout should not be specified.")


if (object$engine == "brulee") {
act_funs <- c("linear", "relu", "elu", "tanh")
} else if (object$engine == "keras") {
act_funs <- c("linear", "softmax", "relu", "elu")
} else if (object$engine == "h2o") {
act_funs <- c("relu", "tanh")
}

if (is.character(args$activation)) {
if (!any(args$activation %in% c(act_funs))) {
rlang::abort(
glue::glue("`activation` should be one of: ",
glue::glue_collapse(glue::glue("'{act_funs}'"), sep = ", "))
)
}
}

invisible(object)
}

Expand Down Expand Up @@ -210,6 +192,9 @@ keras_mlp <-
seeds = sample.int(10^5, size = 3),
...) {

act_funs <- c("linear", "softmax", "relu", "elu")
rlang::arg_match(activation, act_funs,)

if (penalty > 0 & dropout > 0) {
rlang::abort("Please use either dropoput or weight decay.", call. = FALSE)
}
Expand Down
25 changes: 25 additions & 0 deletions tests/testthat/test_mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,28 @@ test_that("nnet_softmax", {
expect_equal(res$b, 1 - res$a)
})

test_that("more activations for brulee", {
skip_if_not_installed("brulee", minimum_version = "0.3.0")
skip_on_cran()

data(ames, package = "modeldata")

ames$Sale_Price <- log10(ames$Sale_Price)

set.seed(122)
in_train <- sample(1:nrow(ames), 2000)
ames_train <- ames[ in_train,]
ames_test <- ames[-in_train,]

set.seed(1)
fit <-
try(
mlp(penalty = 0.10, activation = "softplus") %>%
set_mode("regression") %>%
set_engine("brulee") %>%
fit_xy(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
y = ames_train$Sale_Price),
silent = TRUE)
expect_true(inherits(fit$fit, "brulee_mlp"))
})

0 comments on commit 525baee

Please sign in to comment.