Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantile predictions output constructor #1191

Merged
merged 33 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f7e25d5
small change to predict checks
dajmcdon Aug 15, 2024
55897a5
merge remote
dajmcdon Sep 3, 2024
1980cb4
Merge branch 'quantile-mode' of https://github.com/dajmcdon/parsnip i…
dajmcdon Sep 9, 2024
43d1918
add vctrs for quantiles and test, refactor *_rq_preds
dajmcdon Sep 9, 2024
728c046
revise tests
dajmcdon Sep 9, 2024
32ea877
Apply some of the suggestions from code review
dajmcdon Sep 9, 2024
cc1f8de
rename tests on suggestion from code review
dajmcdon Sep 9, 2024
1d27996
export missing funs from vctrs for formatting
dajmcdon Sep 9, 2024
4a996c2
convert errors to snapshot tests
dajmcdon Sep 9, 2024
f03bcc3
pass call through input check
dajmcdon Sep 9, 2024
73e43e9
update snapshots for caller_env
dajmcdon Sep 9, 2024
7ca367e
rename to parsnip_quantiles, add format snapshot tests
dajmcdon Sep 9, 2024
49cc02e
Apply suggestions from @topepo
dajmcdon Sep 10, 2024
3ff6930
rename parsnip_quantiles to quantile_pred
dajmcdon Sep 10, 2024
8e601c5
rename parsnip_quantiles to quantile_pred and add vector probability …
dajmcdon Sep 10, 2024
f4c90ca
fix: two bugs introduced earlier
dajmcdon Sep 10, 2024
13b6010
add formatting tests for single quantile
dajmcdon Sep 10, 2024
f3ac33e
replace walk with a loop to avoid "Error in map()"
dajmcdon Sep 10, 2024
7ffcb38
remove row/col names
dajmcdon Sep 10, 2024
90655c9
adjust quantile_pred format
dajmcdon Sep 10, 2024
e8feed3
as_tibble method
topepo Sep 11, 2024
2748d06
updated NEWS file
topepo Sep 11, 2024
5b09175
add PR number
topepo Sep 11, 2024
30760de
small new update
topepo Sep 11, 2024
926d587
helper methods
topepo Sep 12, 2024
b575c34
update docs
Sep 12, 2024
83c744b
re-enable quantiles prediction for #1203
topepo Sep 12, 2024
11dd169
update some tests
topepo Sep 12, 2024
9fa5bf0
no longer needed
Sep 13, 2024
1e74bae
use tibble::new_tibble
Sep 13, 2024
9122073
braces
Sep 13, 2024
9ee98e9
test as_tibble
Sep 13, 2024
9ce72c0
remove print methods
Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ S3method(extract_spec_parsnip,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
S3method(format,vctrs_quantiles)
S3method(glance,model_fit)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
Expand All @@ -54,6 +55,7 @@ S3method(multi_predict_args,default)
S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(obj_print_footer,vctrs_quantiles)
S3method(predict,"_elnet")
S3method(predict,"_glmnetfit")
S3method(predict,"_lognet")
Expand Down Expand Up @@ -172,6 +174,8 @@ S3method(update,svm_rbf)
S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
S3method(vec_ptype_abbr,vctrs_quantiles)
S3method(vec_ptype_full,vctrs_quantiles)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
Expand Down Expand Up @@ -280,6 +284,7 @@ export(new_model_spec)
export(null_model)
export(null_value)
export(nullmodel)
export(obj_print_footer)
export(parsnip_addin)
export(pls)
export(poisson_reg)
Expand Down Expand Up @@ -350,6 +355,7 @@ export(update_model_info_file)
export(update_spec)
export(varying)
export(varying_args)
export(vec_quantiles)
export(xgb_predict)
export(xgb_train)
import(rlang)
Expand Down Expand Up @@ -396,6 +402,8 @@ importFrom(purrr,map)
importFrom(purrr,map_chr)
importFrom(purrr,map_dbl)
importFrom(purrr,map_lgl)
importFrom(rlang,"!!!")
importFrom(rlang,is_double)
importFrom(stats,.checkMFClasses)
importFrom(stats,.getXlevels)
importFrom(stats,as.formula)
Expand Down Expand Up @@ -426,5 +434,6 @@ importFrom(utils,globalVariables)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,stack)
importFrom(vctrs,obj_print_footer)
importFrom(vctrs,vec_size)
importFrom(vctrs,vec_unique)
107 changes: 92 additions & 15 deletions R/aaa_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,100 @@ check_quantile_level <- function(x, object, call) {
x
}

# Assumes the columns have the same order as quantile_level
restructure_rq_pred <- function(x, object) {
num_quantiles <- NCOL(x)
if ( num_quantiles == 1L ){
x <- matrix(x, ncol = 1)

# -------------------------------------------------------------------------
# A column vector of quantiles with an attribute

#' @export
vec_ptype_abbr.vctrs_quantiles <- function(x, ...) {
  # Abbreviated type name reported through the `vec_ptype_abbr()` generic.
  "qntls"
}

#' @export
vec_ptype_full.vctrs_quantiles <- function(x, ...) {
  # Full type name reported through the `vec_ptype_full()` generic.
  "quantiles"
}

#' @importFrom rlang is_double !!!
new_vec_quantiles <- function(values = list(), quantile_levels = double()) {
  # Low-level constructor: casts the levels to double and attaches them to
  # `values` as the "quantile_levels" attribute of a "vctrs_quantiles" vctr.
  # No validation of `values` happens here.
  levels_dbl <- vctrs::vec_cast(quantile_levels, double())
  vctrs::new_vctr(
    values,
    quantile_levels = levels_dbl,
    class = "vctrs_quantiles"
  )
}


#' Create a vector containing sets of quantiles
#'
#' @param values A matrix of values. Each column should correspond to one of
#' the quantile levels.
#' @param quantile_levels A vector of probabilities corresponding to `values`.
#' Each level must be a number between 0 and 1, and the levels must be sorted
#' in increasing order.
#'
#' @export
#' @return A vector of values associated with the quantile levels. It has
#' class `vctrs_quantiles`, one element per row of `values`, and carries the
#' levels in its `quantile_levels` attribute.
#'
#' @examples
#' v <- vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
#'
#' # Access the underlying information
#' attr(v, "quantile_levels")
#' vctrs::vec_data(v)
vec_quantiles <- function(values, quantile_levels = double()) {
  check_vec_quantiles_inputs(values, quantile_levels)
  quantile_levels <- vctrs::vec_cast(quantile_levels, double())
  num_lvls <- length(quantile_levels)

  if (ncol(values) != num_lvls) {
    cli::cli_abort(
      "The number of columns in {.arg values} must be equal to the length of
      {.arg quantile_levels}."
    )
  }

  # Split the matrix into a list with one element per row; `drop()` turns each
  # 1 x p slice into a plain vector of length p.
  values <- lapply(vctrs::vec_chop(values), drop)
  new_vec_quantiles(values, quantile_levels)
}

# Validate the inputs of `vec_quantiles()`.
#
# `values` must be a matrix; every element of `levels` must be a number in
# [0, 1]; and `levels` must be sorted in increasing order. Errors are reported
# against `call` (by default, the function that called this helper) and name
# the user-facing argument `quantile_levels`. Returns `NULL` invisibly.
check_vec_quantiles_inputs <- function(values, levels,
                                       call = rlang::caller_env()) {
  if (!is.matrix(values)) {
    cls <- class(values)[1]
    cli::cli_abort(
      "{.arg values} must be a {.cls matrix} not a {.cls {cls}}.",
      call = call
    )
  }

  # A plain loop (instead of `purrr::walk()`) keeps a failing check from being
  # wrapped in purrr's "In index: i" error, so the user sees only the message
  # about `quantile_levels`.
  for (level in levels) {
    check_number_decimal(
      level,
      min = 0,
      max = 1,
      arg = "quantile_levels",
      call = call
    )
  }

  if (is.unsorted(levels)) {
    cli::cli_abort(
      "{.arg quantile_levels} must be sorted in increasing order.",
      call = call
    )
  }
  invisible(NULL)
}

#' @export
format.vctrs_quantiles <- function(x, ...) {
  # Formatting method for quantile vectors.
  #
  # With a single quantile level, each element is shown as its value rounded
  # to 3 decimals. Otherwise each element is shown as "[min, max]" of its
  # quantile values (also rounded to 3 decimals). Elements containing missing
  # values are shown as NA.

  # Nothing to format. Returning early also stops `paste0()` from recycling
  # the literal brackets into a spurious "[, ]" entry.
  if (length(x) == 0L) {
    return(character())
  }

  # The constructor stores the levels under "quantile_levels"; reading
  # "levels" always gave NULL and made the single-quantile branch unreachable.
  quantile_levels <- attr(x, "quantile_levels")
  if (length(quantile_levels) == 1L) {
    x <- unlist(x)
    out <- round(x, 3L)
    out[is.na(x)] <- NA
  } else {
    # `vapply()` guarantees a 2 x n matrix regardless of the input.
    rng <- vapply(x, range, numeric(2))
    out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]")
    out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA
  }
  out
}

#' @importFrom vctrs obj_print_footer
#' @export
vctrs::obj_print_footer

#' @export
obj_print_footer.vctrs_quantiles <- function(x, ...) {
  # Print a footer line listing the quantile levels stored on `x`
  # (formatted with `digits = 3`).
  level_text <- format(attr(x, "quantile_levels"), digits = 3)
  cat("# Quantile levels: ", level_text, "\n", sep = " ")
}

dajmcdon marked this conversation as resolved.
Show resolved Hide resolved
# Convert raw quantile predictions into a one-column tibble.
#
# `x` holds the predicted values (coerced to a matrix if needed), with one
# column per quantile level; the columns are assumed to be in the same order
# as `object$spec$quantile_level`. The result is a tibble whose single column,
# `.pred_quantile`, is built by `vec_quantiles()`, which also checks that the
# number of columns matches the number of quantile levels.
restructure_rq_pred <- function(x, object) {
  if (!is.matrix(x)) {
    x <- as.matrix(x)
  }
  rownames(x) <- NULL
  quantile_level <- object$spec$quantile_level
  tibble::tibble(.pred_quantile = vec_quantiles(x, quantile_level))
}

5 changes: 4 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions man/vec_quantiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 16 additions & 12 deletions tests/testthat/test-linear_reg_quantreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ test_that('linear quantile regression via quantreg - single quantile', {
expect_true(nrow(one_quant_pred) == nrow(sac_test))
expect_named(one_quant_pred, ".pred_quantile")
expect_true(is.list(one_quant_pred[[1]]))
expect_s3_class(one_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
expect_named(one_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
expect_true(nrow(one_quant_pred$.pred_quantile[[1]]) == 1L)
expect_s3_class(one_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list"))
expect_identical(class(one_quant_pred$.pred_quantile[[1]]), "numeric")
expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L)
expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5)

###

one_quant_one_row <- predict(one_quant, new_data = sac_test[1,])
expect_true(nrow(one_quant_one_row) == 1L)
expect_named(one_quant_one_row, ".pred_quantile")
expect_true(is.list(one_quant_one_row[[1]]))
expect_s3_class(one_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
expect_named(one_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
expect_true(nrow(one_quant_one_row$.pred_quantile[[1]]) == 1L)
expect_s3_class(one_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list"))
expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric")
expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L)
expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5)
})

test_that('linear quantile regression via quantreg - multiple quantiles', {
Expand Down Expand Up @@ -65,19 +67,21 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
expect_true(nrow(ten_quant_pred) == nrow(sac_test))
expect_named(ten_quant_pred, ".pred_quantile")
expect_true(is.list(ten_quant_pred[[1]]))
expect_s3_class(ten_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
expect_named(ten_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
expect_true(nrow(ten_quant_pred$.pred_quantile[[1]]) == 10L)
expect_s3_class(ten_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list"))
expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric")
expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L)
expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9)

###

ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])
expect_true(nrow(ten_quant_one_row) == 1L)
expect_named(ten_quant_one_row, ".pred_quantile")
expect_true(is.list(ten_quant_one_row[[1]]))
expect_s3_class(ten_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
expect_named(ten_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
expect_true(nrow(ten_quant_one_row$.pred_quantile[[1]]) == 10L)
expect_s3_class(ten_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list"))
expect_identical(class(ten_quant_one_row$.pred_quantile[[1]]), "numeric")
expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L)
expect_identical(attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), (0:9)/9)
})


Expand Down
25 changes: 25 additions & 0 deletions tests/testthat/test-vec_quantiles.R
dajmcdon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("vec_quantiles error types", {
  # Shared 5 x 4 input for the checks that need a valid matrix.
  vals <- matrix(1:20, 5)

  # `values` is not a matrix
  expect_error(vec_quantiles(1:10, 1:4 / 5), "matrix")

  # a level outside [0, 1]
  expect_error(
    vec_quantiles(vals, -1:4 / 5),
    "`quantile_levels` must be a number between 0 and 1"
  )

  # five levels for four columns
  expect_error(
    vec_quantiles(vals, 1:5 / 6),
    "The number of columns in `values` must be equal to"
  )

  # levels in decreasing order
  expect_error(
    vec_quantiles(vals, 4:1 / 5),
    "must be sorted in increasing order"
  )
})

test_that("vec_quantiles outputs", {
  vals <- matrix(1:20, 5)
  lvls <- 1:4 / 5
  v <- vec_quantiles(vals, lvls)

  # class and stored levels
  expect_s3_class(v, "vctrs_quantiles")
  expect_identical(attr(v, "quantile_levels"), lvls)

  # underlying data: one plain vector per row of the input matrix
  expected_rows <- lapply(vctrs::vec_chop(vals), drop)
  expect_identical(vctrs::vec_data(v), expected_rows)
})