Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantile predictions output constructor #1191

Merged
merged 33 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f7e25d5
small change to predict checks
dajmcdon Aug 15, 2024
55897a5
merge remote
dajmcdon Sep 3, 2024
1980cb4
Merge branch 'quantile-mode' of https://github.com/dajmcdon/parsnip i…
dajmcdon Sep 9, 2024
43d1918
add vctrs for quantiles and test, refactor *_rq_preds
dajmcdon Sep 9, 2024
728c046
revise tests
dajmcdon Sep 9, 2024
32ea877
Apply some of the suggestions from code review
dajmcdon Sep 9, 2024
cc1f8de
rename tests on suggestion from code review
dajmcdon Sep 9, 2024
1d27996
export missing funs from vctrs for formatting
dajmcdon Sep 9, 2024
4a996c2
convert errors to snapshot tests
dajmcdon Sep 9, 2024
f03bcc3
pass call through input check
dajmcdon Sep 9, 2024
73e43e9
update snapshots for caller_env
dajmcdon Sep 9, 2024
7ca367e
rename to parsnip_quantiles, add format snapshot tests
dajmcdon Sep 9, 2024
49cc02e
Apply suggestions from @topepo
dajmcdon Sep 10, 2024
3ff6930
rename parsnip_quantiles to quantile_pred
dajmcdon Sep 10, 2024
8e601c5
rename parsnip_quantiles to quantile_pred and add vector probability …
dajmcdon Sep 10, 2024
f4c90ca
fix: two bugs introduced earlier
dajmcdon Sep 10, 2024
13b6010
add formatting tests for single quantile
dajmcdon Sep 10, 2024
f3ac33e
replace walk with a loop to avoid "Error in map()"
dajmcdon Sep 10, 2024
7ffcb38
remove row/col names
dajmcdon Sep 10, 2024
90655c9
adjust quantile_pred format
dajmcdon Sep 10, 2024
e8feed3
as_tibble method
topepo Sep 11, 2024
2748d06
updated NEWS file
topepo Sep 11, 2024
5b09175
add PR number
topepo Sep 11, 2024
30760de
small new update
topepo Sep 11, 2024
926d587
helper methods
topepo Sep 12, 2024
b575c34
update docs
Sep 12, 2024
83c744b
re-enable quantiles prediction for #1203
topepo Sep 12, 2024
11dd169
update some tests
topepo Sep 12, 2024
9fa5bf0
no longer needed
Sep 13, 2024
1e74bae
use tibble::new_tibble
Sep 13, 2024
9122073
braces
Sep 13, 2024
9ee98e9
test as_tibble
Sep 13, 2024
9ce72c0
remove print methods
Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(.censoring_weights_graf,default)
S3method(.censoring_weights_graf,model_fit)
S3method(as.matrix,quantile_pred)
S3method(as_tibble,quantile_pred)
S3method(augment,model_fit)
S3method(autoplot,glmnet)
S3method(autoplot,model_fit)
Expand Down Expand Up @@ -36,10 +38,12 @@ S3method(extract_spec_parsnip,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
S3method(format,quantile_pred)
S3method(glance,model_fit)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
S3method(has_multi_predict,workflow)
S3method(median,quantile_pred)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
Expand All @@ -54,6 +58,7 @@ S3method(multi_predict_args,default)
S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(obj_print_footer,quantile_pred)
S3method(predict,"_elnet")
S3method(predict,"_glmnetfit")
S3method(predict,"_lognet")
Expand Down Expand Up @@ -172,6 +177,8 @@ S3method(update,svm_rbf)
S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
S3method(vec_ptype_abbr,quantile_pred)
S3method(vec_ptype_full,quantile_pred)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
Expand Down Expand Up @@ -226,6 +233,7 @@ export(extract_fit_engine)
export(extract_fit_time)
export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(extract_quantile_levels)
export(extract_spec_parsnip)
export(find_engine_files)
export(fit)
Expand Down Expand Up @@ -280,6 +288,7 @@ export(new_model_spec)
export(null_model)
export(null_value)
export(nullmodel)
export(obj_print_footer)
export(parsnip_addin)
export(pls)
export(poisson_reg)
Expand Down Expand Up @@ -307,6 +316,7 @@ export(prepare_data)
export(print_model_spec)
export(prompt_missing_implementation)
export(proportional_hazards)
export(quantile_pred)
export(rand_forest)
export(repair_call)
export(req_pkgs)
Expand Down Expand Up @@ -350,6 +360,8 @@ export(update_model_info_file)
export(update_spec)
export(varying)
export(varying_args)
export(vec_ptype_abbr)
export(vec_ptype_full)
export(xgb_predict)
export(xgb_train)
import(rlang)
Expand Down Expand Up @@ -402,6 +414,7 @@ importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,model.offset)
Expand All @@ -426,5 +439,8 @@ importFrom(utils,globalVariables)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,stack)
importFrom(vctrs,obj_print_footer)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_size)
importFrom(vctrs,vec_unique)
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# parsnip (development version)


* A new model mode (`"quantile regression"`) was added, including:
* A function to create a new vector class called `quantile_pred()` was added (#1191).
* A `linear_reg()` engine for `"quantreg"`.

* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).

* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
Expand Down
225 changes: 202 additions & 23 deletions R/aaa_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,214 @@ check_quantile_level <- function(x, object, call) {
{.arg quantile_level} must be specified for quantile regression models.")
}
}
if ( any(is.na(x)) ) {
topepo marked this conversation as resolved.
Show resolved Hide resolved
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
call = call)
}
x <- sort(unique(x))
# TODO we need better vectorization here, otherwise we get things like:
# "Error during wrapup: i In index: 2." in the traceback.
res <-
purrr::map(x,
~ check_number_decimal(.x, min = 0, max = 1,
arg = "quantile_level", call = call,
allow_infinite = FALSE)
)
check_vector_probability(x, arg = "quantile_level", call = call)
x
}

# Assumes the columns have the same order as quantile_level
restructure_rq_pred <- function(x, object) {
num_quantiles <- NCOL(x)
if ( num_quantiles == 1L ){
x <- matrix(x, ncol = 1)

# -------------------------------------------------------------------------
# A column vector of quantiles with an attribute

#' @importFrom vctrs vec_ptype_abbr
#' @export
vctrs::vec_ptype_abbr

#' @importFrom vctrs vec_ptype_full
#' @export
vctrs::vec_ptype_full


#' @export
vec_ptype_abbr.quantile_pred <- function(x, ...) {
  # Abbreviated type label: built from the number of quantile levels stored
  # in the "quantile_levels" attribute. cli handles the `{?s}` plural marker.
  lvls <- attr(x, "quantile_levels")
  n_lvls <- length(lvls)
  cli::format_inline("qtl{?s}({n_lvls})")
}

#' @export
vec_ptype_full.quantile_pred <- function(x, ...) {
  # Full type name for the class; the same constant for every object.
  "quantiles"
}

new_quantile_pred <- function(values = list(), quantile_levels = double()) {
  # Low-level constructor: the only processing is casting the levels to
  # double; `values` is stored as given. Input validation is done by callers.
  lvls_dbl <- vctrs::vec_cast(quantile_levels, double())
  vctrs::new_vctr(
    values,
    quantile_levels = lvls_dbl,
    class = "quantile_pred"
  )
}

#' Create a vector containing sets of quantiles
#'
#' [quantile_pred()] is a special vector class used to efficiently store
#' predictions from a quantile regression model. It requires the same quantile
#' levels for each row being predicted.
#'
#' @param values A matrix of values. Each column should correspond to one of
#' the quantile levels.
#' @param quantile_levels A vector of probabilities corresponding to `values`.
#' @param x An object produced by [quantile_pred()].
#' @param .rows,.name_repair,rownames Arguments not used but required by the
#' original S3 method.
#' @param ... Not currently used.
#'
#' @export
#' @return
#' * [quantile_pred()] returns a vector of values associated with the
#' quantile levels.
#' * [extract_quantile_levels()] returns a numeric vector of levels.
#' * [as_tibble()] returns a tibble with columns `".pred_quantile"`,
#' `".quantile_levels"`, and `".row"`.
#' * [as.matrix()] returns an unnamed matrix with rows as samples, columns as
#' quantile levels, and entries are predictions.
#' @examples
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
#'
#' unclass(.pred_quantile)
#'
#' # Access the underlying information
#' extract_quantile_levels(.pred_quantile)
#'
#' # Matrix format
#' as.matrix(.pred_quantile)
#'
#' # Tidy format
#' tibble::as_tibble(.pred_quantile)
quantile_pred <- function(values, quantile_levels = double()) {
  # Validate the inputs up front (matrix `values`, acceptable levels).
  check_quantile_pred_inputs(values, quantile_levels)

  quantile_levels <- vctrs::vec_cast(quantile_levels, double())
  num_lvls <- length(quantile_levels)

  # Each column of `values` must line up with exactly one quantile level.
  if (ncol(values) != num_lvls) {
    cli::cli_abort(
      "The number of columns in {.arg values} must be equal to the length of
      {.arg quantile_levels}."
    )
  }

  # Strip dimnames before splitting the matrix.
  rownames(values) <- NULL
  colnames(values) <- NULL

  # Split the matrix by row and drop each slice to a plain vector, giving a
  # list with one vector of predictions per row.
  values <- lapply(vctrs::vec_chop(values), drop)
  new_quantile_pred(values, quantile_levels)
}

check_quantile_pred_inputs <- function(values, levels, call = caller_env()) {
  # Validates the two inputs of `quantile_pred()`; aborts with an error
  # attributed to `call` on the first problem found and otherwise returns
  # `NULL` invisibly.

  # Missing levels are rejected before any other check.
  if (anyNA(levels)) {
    cli::cli_abort(
      "Missing values are not allowed in {.arg quantile_levels}.",
      call = call
    )
  }

  if (!is.matrix(values)) {
    cli::cli_abort(
      "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.",
      call = call
    )
  }

  # Every level must be a number in [0, 1].
  check_vector_probability(levels, arg = "quantile_levels", call = call)

  # Non-strict check: ties between adjacent levels are not flagged here.
  if (is.unsorted(levels)) {
    cli::cli_abort(
      "{.arg quantile_levels} must be sorted in increasing order.",
      call = call
    )
  }
  invisible(NULL)
}

#' @export
format.quantile_pred <- function(x, ...) {
  quantile_levels <- attr(x, "quantile_levels")
  if (length(quantile_levels) == 1L) {
    # A single quantile level: show the prediction itself, rounded to three
    # decimals; missing entries stay NA.
    x <- unlist(x)
    out <- round(x, 3L)
    out[is.na(x)] <- NA_real_
  } else {
    # Several quantile levels: summarise each element by its median (see the
    # `median()` method for this class), shown in square brackets.
    m <- median(x)
    out <- paste0("[", round(m, 3L), "]")
  }
  out
}

#' @importFrom vctrs obj_print_footer
#' @export
vctrs::obj_print_footer

#' @export
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) {
  # Footer line printed after the vector: the quantile levels, formatted to
  # `digits` significant digits.
  shown_levels <- format(attr(x, "quantile_levels"), digits = digits)
  cat("# Quantile levels: ", shown_levels, "\n", sep = " ")
}

check_vector_probability <- function(x, ...,
                                     allow_na = FALSE,
                                     allow_null = FALSE,
                                     arg = caller_arg(x),
                                     call = caller_env()) {
  # Check each element of `x` separately as a finite number in [0, 1].
  # A plain `for` loop is used so that a failure is reported directly
  # against `arg`/`call` rather than through a mapping wrapper.
  for (prob in x) {
    check_number_decimal(
      prob,
      min = 0,
      max = 1,
      arg = arg,
      call = call,
      allow_na = allow_na,
      allow_null = allow_null,
      allow_infinite = FALSE
    )
  }
  invisible(NULL)
}

#' @export
median.quantile_pred <- function(x, ...) {
  # Returns one median per element of `x`, always as a double vector.
  lvls <- attr(x, "quantile_levels")
  # Iterate over the underlying list of per-row prediction vectors.
  preds <- unclass(x)

  # Case 1: 0.5 is (up to floating-point tolerance) one of the stored levels,
  # so the median is read off directly. The first match is used.
  is_median <- abs(lvls - 0.5) < sqrt(.Machine$double.eps)
  if (any(is_median)) {
    idx <- which(is_median)[1L]
    return(vapply(preds, function(v) v[idx], numeric(1)))
  }

  # Case 2: 0.5 is not bracketed by the available levels, so it cannot be
  # interpolated. Use NA_real_ so the return type matches the other branches.
  if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
    return(rep(NA_real_, length(preds)))
  }

  # Case 3: interpolate linearly between the stored levels. `approx()` needs
  # at least two non-missing points, so sparser elements yield NA rather
  # than an error.
  vapply(
    preds,
    function(v) {
      if (sum(!is.na(v)) < 2L) {
        return(NA_real_)
      }
      stats::approx(lvls, v, xout = 0.5)$y
    },
    numeric(1)
  )
}

dajmcdon marked this conversation as resolved.
Show resolved Hide resolved
restructure_rq_pred <- function(x, object) {
  # Wraps raw quantile predictions in a one-column tibble (`.pred_quantile`)
  # holding a `quantile_pred` vector.
  # Assumes the columns of `x` have the same order as the model
  # specification's `quantile_level`.
  if (!is.matrix(x)) {
    x <- as.matrix(x)
  }
  rownames(x) <- NULL
  quantile_level <- object$spec$quantile_level

  tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level)))
}

#' @export
#' @rdname quantile_pred
extract_quantile_levels <- function(x) {
  # Return the stored levels for a `quantile_pred` object; anything else is
  # an error.
  if (inherits(x, "quantile_pred")) {
    return(attr(x, "quantile_levels"))
  }
  cli::cli_abort("{.arg x} should have class {.val quantile_pred}.")
}

#' @export
#' @rdname quantile_pred
as_tibble.quantile_pred <-
  function(x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) {
    # Long ("tidy") representation: one row per (sample, quantile level) pair.
    # `.rows`, `.name_repair` and `rownames` are accepted for compatibility
    # with the generic but are not used.
    lvls <- attr(x, "quantile_levels")
    n_samp <- length(x)
    n_quant <- length(lvls)

    # Predictions are flattened sample by sample. `unlist()` gives NULL when
    # there is nothing to flatten; substitute an empty double so the column
    # is kept.
    preds <- unlist(x)
    if (is.null(preds)) {
      preds <- double()
    }

    tibble::tibble(
      .pred_quantile = preds,
      .quantile_levels = rep(lvls, n_samp),
      # seq_len() (not 1:n) so that an empty vector yields zero rows.
      .row = rep(seq_len(n_samp), each = n_quant)
    )
  }

#' @export
#' @rdname quantile_pred
as.matrix.quantile_pred <- function(x, ...) {
  # Rebuild the samples-by-quantile-levels matrix of predictions.
  num_samp <- length(x)
  num_lvls <- length(attr(x, "quantile_levels"))

  # `unlist()` gives NULL when there are no values at all; use an empty
  # double so that a zero-extent matrix can still be built.
  vals <- unlist(x)
  if (is.null(vals)) {
    vals <- double()
  }

  # The flattened values run sample by sample (row-major), so the matrix has
  # to be filled by row; the default column-major fill would scramble them.
  matrix(vals, nrow = num_samp, ncol = num_lvls, byrow = TRUE)
}
2 changes: 1 addition & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef
#' @importFrom stats delete.response model.frame model.matrix model.offset
#' @importFrom stats model.response model.weights na.omit na.pass predict qnorm
#' @importFrom stats qt quantile setNames terms update
#' @importFrom stats qt quantile setNames terms update median
#' @importFrom tibble as_tibble is_tibble tibble
#' @importFrom tidyr gather
#' @importFrom utils capture.output getFromNamespace globalVariables head
Expand Down
15 changes: 10 additions & 5 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <- function(object, new_data, ...) {
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
interval = "none",
level = 0.95,
...) {

check_spec_pred_type(object, "quantile")

Expand All @@ -18,12 +23,11 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$quantile$pre)) {
if (!is.null(object$spec$method$pred$quantile$pre))
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
}
topepo marked this conversation as resolved.
Show resolved Hide resolved

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level
object$spec$method$pred$quantile$args$p <- quantile
pred_call <- make_pred_call(object$spec$method$pred$quantile)

res <- eval_tidy(pred_call)
Expand All @@ -40,5 +44,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_quantile <- function(object, ...) {
  # S3 generic: dispatches on the class of `object`.
  UseMethod("predict_quantile")
}
9 changes: 8 additions & 1 deletion man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading