Skip to content

Commit

Permalink
fix model fit for spark tbls
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jan 16, 2024
1 parent d65dde2 commit d12ee60
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* Fixed a bug in fitting some model types with the `"spark"` engine (#1045).

* Improved errors and documentation related to special terms in formulas. See `?model_formula` to learn more. (#770, #1014)

* Improved errors in cases where the outcome column is mis-specified. (#1003)
Expand Down
10 changes: 9 additions & 1 deletion R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,11 @@ min_cols <- function(num_cols, source) {
#' @export
#' @rdname min_cols
min_rows <- function(num_rows, source, offset = 0) {
n <- nrow(source)
if (inherits(source, "tbl_spark")) {
n <- nrow_spark(source)
} else {
n <- nrow(source)
}

if (num_rows > n - offset) {
msg <- paste0(num_rows, " samples were requested but there were ", n,
Expand All @@ -340,3 +344,7 @@ min_rows <- function(num_rows, source, offset = 0) {
as.integer(num_rows)
}

# Row count for a Spark-backed table.
#
# `source` is handed straight to `sparklyr::sdf_nrow()`; the sparklyr
# package is checked for (via `rlang::check_installed()`) before it is used,
# since it is referenced only through `::` here.
# Returns whatever `sparklyr::sdf_nrow()` returns for `source`.
nrow_spark <- function(source) {
  rlang::check_installed("sparklyr")
  n_rows <- sparklyr::sdf_nrow(source)
  n_rows
}
7 changes: 7 additions & 0 deletions tests/testthat/test_boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ test_that('bad input', {
## -----------------------------------------------------------------------------

test_that('argument checks for data dimensions', {
skip_if_not_installed("sparklyr")
library(sparklyr)
skip_if(nrow(spark_installed_versions()) == 0)

spec <-
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
Expand All @@ -36,6 +39,10 @@ test_that('argument checks for data dimensions', {

args <- translate(spec)$method$fit$args
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))

sc = spark_connect(master = "local")
cars = copy_to(sc, mtcars, overwrite = TRUE)
expect_equal(min_rows(10, cars), 10)
})

test_that('boost_tree can be fit with 1 predictor if validation is used', {
Expand Down

0 comments on commit d12ee60

Please sign in to comment.