From d12ee606c946287be8bf06bb2d76b9548a3af923 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 16 Jan 2024 08:29:52 -0600 Subject: [PATCH] fix model fit for spark tbls --- NEWS.md | 2 ++ R/arguments.R | 10 +++++++++- tests/testthat/test_boost_tree.R | 7 +++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index a39566d52..a0e2b6dd2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* Fixed bug in fitting some model types with the `"spark"` engine (#1045). + * Improved errors and documentation related to special terms in formulas. See `?model_formula` to learn more. (#770, #1014) * Improved errors in cases where the outcome column is mis-specified. (#1003) diff --git a/R/arguments.R b/R/arguments.R index 2ed1de0af..42721aac4 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -328,7 +328,11 @@ min_cols <- function(num_cols, source) { #' @export #' @rdname min_cols min_rows <- function(num_rows, source, offset = 0) { - n <- nrow(source) + if (inherits(source, "tbl_spark")) { + n <- nrow_spark(source) + } else { + n <- nrow(source) + } if (num_rows > n - offset) { msg <- paste0(num_rows, " samples were requested but there were ", n, @@ -340,3 +344,7 @@ min_rows <- function(num_rows, source, offset = 0) { as.integer(num_rows) } +nrow_spark <- function(source) { + rlang::check_installed("sparklyr") + sparklyr::sdf_nrow(source) +} diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 7abfcf9a4..f92216870 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -28,6 +28,9 @@ test_that('bad input', { ## ----------------------------------------------------------------------------- test_that('argument checks for data dimensions', { + skip_if_not_installed("sparklyr") + library(sparklyr) + skip_if(nrow(spark_installed_versions()) == 0) spec <- boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>% @@ -36,6 +39,10 @@ test_that('argument checks for data dimensions', { args <- translate(spec)$method$fit$args expect_equal(args$min_instances_per_node, expr(min_rows(1000, x))) + + sc = spark_connect(master = "local") + cars = copy_to(sc, mtcars, overwrite = TRUE) + expect_equal(min_rows(10, cars), 10) }) test_that('boost_tree can be fit with 1 predictor if validation is used', {