diff --git a/NAMESPACE b/NAMESPACE index d48f33586..55d0da025 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,8 @@ S3method(.censoring_weights_graf,default) S3method(.censoring_weights_graf,model_fit) +S3method(as.matrix,quantile_pred) +S3method(as_tibble,quantile_pred) S3method(augment,model_fit) S3method(autoplot,glmnet) S3method(autoplot,model_fit) @@ -36,10 +38,12 @@ S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) +S3method(format,quantile_pred) S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) S3method(has_multi_predict,workflow) +S3method(median,quantile_pred) S3method(multi_predict,"_C5.0") S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") @@ -54,6 +58,7 @@ S3method(multi_predict_args,default) S3method(multi_predict_args,model_fit) S3method(multi_predict_args,workflow) S3method(nullmodel,default) +S3method(obj_print_footer,quantile_pred) S3method(predict,"_elnet") S3method(predict,"_glmnetfit") S3method(predict,"_lognet") @@ -172,6 +177,8 @@ S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) +S3method(vec_ptype_abbr,quantile_pred) +S3method(vec_ptype_full,quantile_pred) export("%>%") export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) @@ -226,6 +233,7 @@ export(extract_fit_engine) export(extract_fit_time) export(extract_parameter_dials) export(extract_parameter_set_dials) +export(extract_quantile_levels) export(extract_spec_parsnip) export(find_engine_files) export(fit) @@ -280,6 +288,7 @@ export(new_model_spec) export(null_model) export(null_value) export(nullmodel) +export(obj_print_footer) export(parsnip_addin) export(pls) export(poisson_reg) @@ -307,6 +316,7 @@ export(prepare_data) export(print_model_spec) export(prompt_missing_implementation) export(proportional_hazards) +export(quantile_pred) export(rand_forest) 
export(repair_call) export(req_pkgs) @@ -350,6 +360,8 @@ export(update_model_info_file) export(update_spec) export(varying) export(varying_args) +export(vec_ptype_abbr) +export(vec_ptype_full) export(xgb_predict) export(xgb_train) import(rlang) @@ -402,6 +414,7 @@ importFrom(stats,as.formula) importFrom(stats,binomial) importFrom(stats,coef) importFrom(stats,delete.response) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,model.offset) @@ -426,5 +439,8 @@ importFrom(utils,globalVariables) importFrom(utils,head) importFrom(utils,methods) importFrom(utils,stack) +importFrom(vctrs,obj_print_footer) +importFrom(vctrs,vec_ptype_abbr) +importFrom(vctrs,vec_ptype_full) importFrom(vctrs,vec_size) importFrom(vctrs,vec_unique) diff --git a/NEWS.md b/NEWS.md index c51afb0e7..62f6ea8da 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,9 @@ # parsnip (development version) - +* A new model mode (`"quantile regression"`) was added, including: + * A function to create a new vector class called `quantile_pred()` (#1191). + * A `linear_reg()` engine for `"quantreg"`. + * `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775). * Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083). 
diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index a2920757e..2b62dfcf8 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -1,43 +1,222 @@ # Helpers for quantile regression models check_quantile_level <- function(x, object, call) { - if ( object$mode != "quantile regression" ) { + if (object$mode != "quantile regression") { return(invisible(TRUE)) } else { - if ( is.null(x) ) { + if (is.null(x)) { cli::cli_abort("In {.fn check_mode}, at least one value of {.arg quantile_level} must be specified for quantile regression models.") } } + if (any(is.na(x))) { + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", + call = call) + } x <- sort(unique(x)) - # TODO we need better vectorization here, otherwise we get things like: - # "Error during wrapup: i In index: 2." in the traceback. - res <- - purrr::map(x, - ~ check_number_decimal(.x, min = 0, max = 1, - arg = "quantile_level", call = call, - allow_infinite = FALSE) - ) + check_vector_probability(x, arg = "quantile_level", call = call) x } -# Assumes the columns have the same order as quantile_level -restructure_rq_pred <- function(x, object) { - num_quantiles <- NCOL(x) - if ( num_quantiles == 1L ){ - x <- matrix(x, ncol = 1) + +# ------------------------------------------------------------------------- +# A column vector of quantiles with an attribute + +#' @importFrom vctrs vec_ptype_abbr +#' @export +vctrs::vec_ptype_abbr + +#' @importFrom vctrs vec_ptype_full +#' @export +vctrs::vec_ptype_full + + +#' @export +vec_ptype_abbr.quantile_pred <- function(x, ...) { + n_lvls <- length(attr(x, "quantile_levels")) + cli::format_inline("qtl{?s}({n_lvls})") +} + +#' @export +vec_ptype_full.quantile_pred <- function(x, ...) 
"quantiles" + +new_quantile_pred <- function(values = list(), quantile_levels = double()) { + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + vctrs::new_vctr( + values, quantile_levels = quantile_levels, class = "quantile_pred" + ) +} + +#' Create a vector containing sets of quantiles +#' +#' [quantile_pred()] is a special vector class used to efficiently store +#' predictions from a quantile regression model. It requires the same quantile +#' levels for each row being predicted. +#' +#' @param values A matrix of values. Each column should correspond to one of +#' the quantile levels. +#' @param quantile_levels A vector of probabilities corresponding to `values`. +#' @param x An object produced by [quantile_pred()]. +#' @param .rows,.name_repair,rownames Arguments not used but required by the +#' original S3 method. +#' @param ... Not currently used. +#' +#' @export +#' @return +#' * [quantile_pred()] returns a vector of values associated with the +#' quantile levels. +#' * [extract_quantile_levels()] returns a numeric vector of levels. +#' * [as_tibble()] returns a tibble with columns `".pred_quantile"`, +#' `".quantile_levels"`, and `".row"`. +#' * [as.matrix()] returns an unnamed matrix with rows as samples, columns as +#' quantile levels, and entries as predictions. 
+#' @examples +#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' +#' unclass(.pred_quantile) +#' +#' # Access the underlying information +#' extract_quantile_levels(.pred_quantile) +#' +#' # Matrix format +#' as.matrix(.pred_quantile) +#' +#' # Tidy format +#' tibble::as_tibble(.pred_quantile) +quantile_pred <- function(values, quantile_levels = double()) { + check_quantile_pred_inputs(values, quantile_levels) + + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + num_lvls <- length(quantile_levels) + + if (ncol(values) != num_lvls) { + cli::cli_abort( + "The number of columns in {.arg values} must be equal to the length of + {.arg quantile_levels}." + ) + } + rownames(values) <- NULL + colnames(values) <- NULL + values <- lapply(vctrs::vec_chop(values), drop) + new_quantile_pred(values, quantile_levels) +} + +check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { + if (any(is.na(levels))) { + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", + call = call) } - n <- nrow(x) + if (!is.matrix(values)) { + cli::cli_abort( + "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", + call = call + ) + } + check_vector_probability(levels, arg = "quantile_levels", call = call) + + if (is.unsorted(levels)) { + cli::cli_abort( + "{.arg quantile_levels} must be sorted in increasing order.", + call = call + ) + } + invisible(NULL) +} + +#' @export +format.quantile_pred <- function(x, ...) 
{ + quantile_levels <- attr(x, "quantile_levels") + if (length(quantile_levels) == 1L) { + x <- unlist(x) + out <- round(x, 3L) + out[is.na(x)] <- NA_real_ + } else { + rng <- sapply(x, range, na.rm = TRUE) + out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") + out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ + m <- median(x) + out <- paste0("[", round(m, 3L), "]") + } + out +} + +#' @importFrom vctrs obj_print_footer +#' @export +vctrs::obj_print_footer + +#' @export +obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { + lvls <- attr(x, "quantile_levels") + cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") +} + +check_vector_probability <- function(x, ..., + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + for (d in x) { + check_number_decimal( + d, min = 0, max = 1, + arg = arg, call = call, + allow_na = allow_na, + allow_null = allow_null, + allow_infinite = FALSE + ) + } +} + +#' @export +median.quantile_pred <- function(x, ...) 
{ + lvls <- attr(x, "quantile_levels") + loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) + if (any(loc_median)) { + return(map_dbl(x, ~ .x[min(which(loc_median))])) + } + if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { + return(rep(NA, vctrs::vec_size(x))) + } + map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) +} + +restructure_rq_pred <- function(x, object) { + if (!is.matrix(x)) { + x <- as.matrix(x) + } + rownames(x) <- NULL + n_pred_quantiles <- ncol(x) quantile_level <- object$spec$quantile_level - res <- - tibble::tibble( - .pred_quantile = as.vector(x), - .quantile_level = rep(quantile_level, each = n), - .row = rep(1:n, num_quantiles)) - res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"]) - res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val))) - res$.row <- NULL - res + + tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level))) +} + +#' @export +#' @rdname quantile_pred +extract_quantile_levels <- function(x) { + if (!inherits(x, "quantile_pred")) { + cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") + } + attr(x, "quantile_levels") } +#' @export +#' @rdname quantile_pred +as_tibble.quantile_pred <- + function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { + lvls <- attr(x, "quantile_levels") + n_samp <- length(x) + n_quant <- length(lvls) + tibble::tibble( + .pred_quantile = unlist(x), + .quantile_levels = rep(lvls, n_samp), + .row = rep(1:n_samp, each = n_quant) + ) + } + +#' @export +#' @rdname quantile_pred +as.matrix.quantile_pred <- function(x, ...) 
{ + num_samp <- length(x) + matrix(unlist(x), nrow = num_samp) +} diff --git a/R/parsnip-package.R b/R/parsnip-package.R index 01f1f42c1..c4dd3c81d 100644 --- a/R/parsnip-package.R +++ b/R/parsnip-package.R @@ -21,7 +21,7 @@ #' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef #' @importFrom stats delete.response model.frame model.matrix model.offset #' @importFrom stats model.response model.weights na.omit na.pass predict qnorm -#' @importFrom stats qt quantile setNames terms update +#' @importFrom stats qt quantile setNames terms update median #' @importFrom tibble as_tibble is_tibble tibble #' @importFrom tidyr gather #' @importFrom utils capture.output getFromNamespace globalVariables head diff --git a/R/predict_quantile.R b/R/predict_quantile.R index f9154d6a9..fc2d91b15 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -6,7 +6,12 @@ #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export -predict_quantile.model_fit <- function(object, new_data, ...) { +predict_quantile.model_fit <- function(object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ...) { check_spec_pred_type(object, "quantile") @@ -23,7 +28,7 @@ predict_quantile.model_fit <- function(object, new_data, ...) { } # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level + object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) @@ -40,5 +45,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_quantile <- function (object, ...) +predict_quantile <- function (object, ...) 
{ UseMethod("predict_quantile") +} diff --git a/man/other_predict.Rd b/man/other_predict.Rd index bc1d104bf..6c997e28d 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -46,7 +46,14 @@ predict_linear_pred(object, ...) predict_numeric(object, ...) -\method{predict_quantile}{model_fit}(object, new_data, ...) +\method{predict_quantile}{model_fit}( + object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ... +) \method{predict_survival}{model_fit}( object, diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd new file mode 100644 index 000000000..abb34ca20 --- /dev/null +++ b/man/quantile_pred.Rd @@ -0,0 +1,60 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa_quantiles.R +\name{quantile_pred} +\alias{quantile_pred} +\alias{extract_quantile_levels} +\alias{as_tibble.quantile_pred} +\alias{as.matrix.quantile_pred} +\title{Create a vector containing sets of quantiles} +\usage{ +quantile_pred(values, quantile_levels = double()) + +extract_quantile_levels(x) + +\method{as_tibble}{quantile_pred}(x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) + +\method{as.matrix}{quantile_pred}(x, ...) +} +\arguments{ +\item{values}{A matrix of values. Each column should correspond to one of +the quantile levels.} + +\item{quantile_levels}{A vector of probabilities corresponding to \code{values}.} + +\item{x}{An object produced by \code{\link[=quantile_pred]{quantile_pred()}}.} + +\item{...}{Not currently used.} + +\item{.rows, .name_repair, rownames}{Arguments not used but required by the +original S3 method.} +} +\value{ +\itemize{ +\item \code{\link[=quantile_pred]{quantile_pred()}} returns a vector of values associated with the +quantile levels. +\item \code{\link[=extract_quantile_levels]{extract_quantile_levels()}} returns a numeric vector of levels. 
+\item \code{\link[=as_tibble]{as_tibble()}} returns a tibble with columns \code{".pred_quantile"}, +\code{".quantile_levels"}, and \code{".row"}. +\item \code{\link[=as.matrix]{as.matrix()}} returns an unnamed matrix with rows as samples, columns as +quantile levels, and entries as predictions. +} +} +\description{ +\code{\link[=quantile_pred]{quantile_pred()}} is a special vector class used to efficiently store +predictions from a quantile regression model. It requires the same quantile +levels for each row being predicted. +} +\examples{ +.pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) + +unclass(.pred_quantile) + +# Access the underlying information +extract_quantile_levels(.pred_quantile) + +# Matrix format +as.matrix(.pred_quantile) + +# Tidy format +tibble::as_tibble(.pred_quantile) +} diff --git a/man/reexports.Rd b/man/reexports.Rd index f87bde459..13baaa850 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,8 +1,11 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R, R/varying.R +% Please edit documentation in R/aaa_quantiles.R, R/reexports.R, R/varying.R \docType{import} \name{reexports} \alias{reexports} +\alias{vec_ptype_abbr} +\alias{vec_ptype_full} +\alias{obj_print_footer} \alias{autoplot} \alias{\%>\%} \alias{fit} @@ -34,5 +37,7 @@ below to see their documentation. 
\item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_time}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} + + \item{vctrs}{\code{\link[vctrs:obj_print]{obj_print_footer}}, \code{\link[vctrs:vec_ptype_full]{vec_ptype_abbr}}, \code{\link[vctrs]{vec_ptype_full}}} }} diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md new file mode 100644 index 000000000..0925f61df --- /dev/null +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -0,0 +1,114 @@ +# quantile_pred error types + + Code + quantile_pred(1:10, 1:4 / 5) + Condition + Error in `quantile_pred()`: + ! `values` must be a <matrix>, not an integer vector. + +--- + + Code + quantile_pred(matrix(1:20, 5), -1:4 / 5) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. + +--- + + Code + quantile_pred(matrix(1:20, 5), 1:5 / 6) + Condition + Error in `quantile_pred()`: + ! The number of columns in `values` must be equal to the length of `quantile_levels`. + +--- + + Code + quantile_pred(matrix(1:20, 5), 4:1 / 5) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` must be sorted in increasing order. 
+ +# quantile_pred formatting + + Code + v + Output + <quantiles[5]> + [1] [8.5] [9.5] [10.5] [11.5] [12.5] + # Quantile levels: 0.2 0.4 0.6 0.8 + +--- + + Code + quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3)) + Output + <quantiles[9]> + [1] [5.5] [6.5] [7.5] [8.5] [9.5] [10.5] [11.5] [12.5] [13.5] + # Quantile levels: 0.333 0.667 + +--- + + Code + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8)) + Output + <quantiles[3]> + [1] [0.304] [0.5] [0.696] + # Quantile levels: 0.2 0.8 + +--- + + Code + tibble(qntls = v) + Output + # A tibble: 5 x 1 + qntls + <qtls(4)> + 1 [8.5] + 2 [9.5] + 3 [10.5] + 4 [11.5] + 5 [12.5] + +--- + + Code + quantile_pred(m, 1:4 / 5) + Output + <quantiles[5]> + [1] [8.5] [9.5] [10.5] [11.5] [12.5] + # Quantile levels: 0.2 0.4 0.6 0.8 + +--- + + Code + one_quantile + Output + <quantiles[5]> + [1] 1 2 3 4 5 + # Quantile levels: 0.556 + +--- + + Code + tibble(qntls = one_quantile) + Output + # A tibble: 5 x 1 + qntls + <qtl(1)> + 1 1 + 2 2 + 3 3 + 4 4 + 5 5 + +--- + + Code + quantile_pred(m, 5 / 9) + Output + <quantiles[5]> + [1] 1 NA 3 4 5 + # Quantile levels: 0.556 + diff --git a/tests/testthat/_snaps/quantile-reg-specs.md b/tests/testthat/_snaps/quantile-reg-specs.md index f7c24584c..627a5248b 100644 --- a/tests/testthat/_snaps/quantile-reg-specs.md +++ b/tests/testthat/_snaps/quantile-reg-specs.md @@ -20,9 +20,7 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = 2) Condition - Error in `purrr::map()`: - i In index: 1. - Caused by error in `set_mode()`: + Error in `set_mode()`: ! `quantile_level` must be a number between 0 and 1, not the number 2. --- @@ -31,9 +29,7 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = 1:2) Condition - Error in `purrr::map()`: - i In index: 2. - Caused by error in `set_mode()`: + Error in `set_mode()`: ! `quantile_level` must be a number between 0 and 1, not the number 2. 
--- @@ -42,8 +38,6 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = NA_real_) Condition - Error in `purrr::map()`: - i In index: 1. - Caused by error in `set_mode()`: - ! `quantile_level` must be a number, not a numeric `NA`. + Error in `set_mode()`: + ! Missing values are not allowed in `quantile_levels`. diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index a9297a65a..14c3931fe 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -24,3 +24,16 @@ is_tf_ok <- function() { } res } + +# ------------------------------------------------------------------------------ +# for quantile regression tests + +data("Sacramento") + +Sacramento_small <- + modeldata::Sacramento %>% + dplyr::mutate(price = log10(price)) %>% + dplyr::select(price, beds, baths, sqft, latitude, longitude) + +sac_train <- Sacramento_small[-(1:5), ] +sac_test <- Sacramento_small[ 1:5 , ] diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R new file mode 100644 index 000000000..cdf71aa7d --- /dev/null +++ b/tests/testthat/test-aaa_quantiles.R @@ -0,0 +1,59 @@ +test_that("quantile_pred error types", { + expect_snapshot( + error = TRUE, + quantile_pred(1:10, 1:4 / 5) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), -1:4 / 5) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), 1:5 / 6) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), 4:1 / 5) + ) +}) + +test_that("quantile_pred outputs", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_s3_class(v, "quantile_pred") + expect_identical(attr(v, "quantile_levels"), 1:4 / 5) + expect_identical( + vctrs::vec_data(v), + lapply(vctrs::vec_chop(matrix(1:20, 5)), drop) + ) +}) + +test_that("quantile_pred formatting", { + # multiple quantiles + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_snapshot(v) + expect_snapshot(quantile_pred(matrix(1:18, 
9), c(1/3, 2/3))) + expect_snapshot( + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) + ) + expect_snapshot(tibble(qntls = v)) + m <- matrix(1:20, 5) + m[2, 3] <- NA + m[4, 2] <- NA + expect_snapshot(quantile_pred(m, 1:4 / 5)) + + # single quantile + m <- matrix(1:5) + one_quantile <- quantile_pred(m, 5/9) + expect_snapshot(one_quantile) + expect_snapshot(tibble(qntls = one_quantile)) + m[2] <- NA + expect_snapshot(quantile_pred(m, 5/9)) +}) + +test_that("as_tibble() for quantile_pred", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + tbl <- as_tibble(v) + expect_s3_class(tbl, c("tbl_df", "tbl", "data.frame")) + expect_named(tbl, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(tbl) == 20) +}) diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 97d310bbf..47a9f7c88 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -1,14 +1,7 @@ test_that('linear quantile regression via quantreg - single quantile', { skip_if_not_installed("quantreg") - data("Sacramento") - - Sacramento_small <- - Sacramento %>% - dplyr::select(price, beds, baths, sqft, latitude, longitude) - - sac_train <- Sacramento_small[-(1:5), ] - sac_test <- Sacramento_small[ 1:5 , ] + # data in `helper-objects.R` one_quant <- linear_reg() %>% @@ -24,9 +17,18 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_pred) == nrow(sac_test)) expect_named(one_quant_pred, ".pred_quantile") expect_true(is.list(one_quant_pred[[1]])) - expect_s3_class(one_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(one_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(one_quant_pred$.pred_quantile[[1]]) == 1L) + expect_s3_class( + one_quant_pred$.pred_quantile[1], + c("quantile_pred", "vctrs_vctr", "list") + ) + 
expect_identical(class(one_quant_pred$.pred_quantile[[1]]), "numeric") + expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L) + expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) + + one_quant_df <- as_tibble(one_quant_pred$.pred_quantile) + expect_s3_class(one_quant_df, c("tbl_df", "tbl", "data.frame")) + expect_named(one_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(one_quant_df) == nrow(sac_test) * 1) ### @@ -34,22 +36,24 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_one_row) == 1L) expect_named(one_quant_one_row, ".pred_quantile") expect_true(is.list(one_quant_one_row[[1]])) - expect_s3_class(one_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(one_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(one_quant_one_row$.pred_quantile[[1]]) == 1L) + expect_s3_class( + one_quant_one_row$.pred_quantile[1], + c("quantile_pred", "vctrs_vctr", "list") + ) + expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric") + expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L) + expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) + + one_quant_one_row_df <- as_tibble(one_quant_one_row$.pred_quantile) + expect_s3_class(one_quant_one_row_df, c("tbl_df", "tbl", "data.frame")) + expect_named(one_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(one_quant_one_row_df) == nrow(sac_test[1,]) * 1) }) test_that('linear quantile regression via quantreg - multiple quantiles', { skip_if_not_installed("quantreg") - data("Sacramento") - - Sacramento_small <- - Sacramento %>% - dplyr::select(price, beds, baths, sqft, latitude, longitude) - - sac_train <- Sacramento_small[-(1:5), ] - sac_test <- Sacramento_small[ 1:5 , ] + # data in `helper-objects.R` ten_quant <- linear_reg() %>% @@ -65,9 +69,18 @@ 
test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_pred) == nrow(sac_test)) expect_named(ten_quant_pred, ".pred_quantile") expect_true(is.list(ten_quant_pred[[1]])) - expect_s3_class(ten_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(ten_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(ten_quant_pred$.pred_quantile[[1]]) == 10L) + expect_s3_class( + ten_quant_pred$.pred_quantile[1], + c("quantile_pred", "vctrs_vctr", "list") + ) + expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric") + expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) + expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9) + + ten_quant_df <- as_tibble(ten_quant_pred$.pred_quantile) + expect_s3_class(ten_quant_df, c("tbl_df", "tbl", "data.frame")) + expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) ### @@ -75,9 +88,21 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_one_row) == 1L) expect_named(ten_quant_one_row, ".pred_quantile") expect_true(is.list(ten_quant_one_row[[1]])) - expect_s3_class(ten_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(ten_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(ten_quant_one_row$.pred_quantile[[1]]) == 10L) + expect_s3_class( + ten_quant_one_row$.pred_quantile[1], + c("quantile_pred", "vctrs_vctr", "list") + ) + expect_identical(class(ten_quant_one_row$.pred_quantile[[1]]), "numeric") + expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L) + expect_identical( + attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), + (0:9)/9 + ) + + ten_quant_one_row_df <- as_tibble(ten_quant_one_row$.pred_quantile) + expect_s3_class(ten_quant_one_row_df, 
c("tbl_df", "tbl", "data.frame")) + expect_named(ten_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(ten_quant_one_row_df) == nrow(sac_test[1,]) * 10) })