From 779ff2a2386e35391167463b36c88cff94a26893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 18 Jul 2024 13:52:09 -0400 Subject: [PATCH 1/3] add quantile mode for lm --- DESCRIPTION | 4 +-- R/augment.R | 15 ++++++-- R/linear_reg_data.R | 76 ++++++++++++++++++++++++++++++++++++++++ R/predict.R | 4 +-- R/predict_quantile.R | 2 +- man/augment.Rd | 2 +- man/predict.model_fit.Rd | 4 +-- 7 files changed, 96 insertions(+), 11 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index bfa80fed1..c94f64995 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -77,6 +77,6 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -Remotes: +Remotes: tidymodels/hardhat -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 diff --git a/R/augment.R b/R/augment.R index 7ecea0d6f..aaf405f03 100644 --- a/R/augment.R +++ b/R/augment.R @@ -78,12 +78,12 @@ #' augment(cls_xy, cls_tst) #' augment(cls_xy, cls_tst[, -3]) #' -augment.model_fit <- function(x, new_data, eval_time = NULL, ...) { +augment.model_fit <- function(x, new_data, eval_time = NULL, quantile = NULL, ...) { new_data <- tibble::new_tibble(new_data) res <- switch( x$spec$mode, - "regression" = augment_regression(x, new_data), + "regression" = augment_regression(x, new_data, quantile = quantile), "classification" = augment_classification(x, new_data), "censored regression" = augment_censored(x, new_data, eval_time = eval_time), rlang::abort(paste("Unknown mode:", x$spec$mode)) @@ -91,9 +91,17 @@ augment.model_fit <- function(x, new_data, eval_time = NULL, ...) { tibble::new_tibble(res) } -augment_regression <- function(x, new_data) { +augment_regression <- function(x, new_data, quantile = NULL) { ret <- new_data check_spec_pred_type(x, "numeric") + + if (spec_has_pred_type(x, "quantile") & !is.null(quantile)) { + ret <- + dplyr::bind_cols( + predict(x, new_data = new_data, type = "quantile", quantile = quantile), + ret) + } + ret <- dplyr::bind_cols(predict(x, new_data = new_data), ret) if (length(x$preproc$y_var) > 0) { y_nm <- x$preproc$y_var @@ -101,6 +109,7 @@ augment_regression <- function(x, new_data) { ret <- dplyr::mutate(ret, .resid = !!rlang::sym(y_nm) - .pred) } } + dplyr::relocate(ret, dplyr::starts_with(".pred"), dplyr::starts_with(".resid")) } diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index bdf6a3753..709817f05 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -73,6 +73,7 @@ set_pred( ) ) ) + set_pred( model = "linear_reg", eng = "lm", @@ -97,6 +98,24 @@ set_pred( ) ) +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "quantile", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "lm_quantile"), + args = + list( + object = expr(object$fit), + new_data = expr(new_data), + quantile = expr(quantile) + ) + ) +) + set_pred( model = "linear_reg", eng = "lm", @@ -582,3 +601,60 @@ set_pred( ) ) +# ------------------------------------------------------------------------------ +# Helper functions + +lm_quantile <- function(object, new_data, quantile = (1:9)/10) { + quantile <- sort(unique(quantile)) + + .row <- 1:nrow(new_data) + + if ( any(quantile == 0.5) ) { + preds <- + tibble::tibble(.quantile = 1/2, + .pred_quantile =predict(object, new_data), + .row = .row) + } else { + preds <- NULL + } + + upper_quantile <- quantile[quantile > .5] + lower_quantile <- quantile[quantile < .5] + + if ( length(upper_quantile) > 0 ) { + # Convert (1 - level) / 2 to actual quantile + # so using level = 0.95 will give you the 0.975 value; to actually get 0.95 + # we need to decrease it a bit + rev_quant = 1 - upper_quantile + upper_adjusted <- 1 + -2 * rev_quant + } + if ( length(lower_quantile) > 0 ) { + upper_adjusted <- 2 * lower_quantile + } + not_center <- c(lower_quantile, upper_quantile) + adjusted <- c(upper_adjusted, upper_adjusted) + + for ( i in seq_along(not_center) ) { + tmp_pred <- predict(object, new_data, interval = "prediction", level = adjusted[i]) + if ( not_center[i] > 0.5) { + tmp_pred <- tmp_pred[, "upr"] + } else { + tmp_pred <- tmp_pred[, "lwr"] + } + tmp_pred <- + tibble::tibble(.quantile = not_center[i], + .pred_quantile = tmp_pred, + .row = .row) + preds <- dplyr::bind_rows(preds, tmp_pred) + } + + preds <- preds[order(preds$.row, preds$.quantile), ] + preds <- + vctrs::vec_split( + x = preds[setdiff(colnames(preds), ".row")], + by = preds$.row + ) + tibble::new_tibble(list(.pred_quantile = preds$val)) +} + + diff --git a/R/predict.R b/R/predict.R index 327bd80ef..82d891e04 100644 --- a/R/predict.R +++ b/R/predict.R @@ -86,9 +86,9 @@ #' produces for class probabilities (or other non-scalar outputs), #' the columns are named `.pred_lower_classlevel` and so on. #' -#' For `type = "quantile"`, the tibble has a `.pred` column, which is +#' For `type = "quantile"`, the tibble has a `.pred_quantile` column, which is #' a list-column. Each list element contains a tibble with columns -#' `.pred` and `.quantile` (and perhaps other columns). +#' `.pred_quantile` and `.quantile`. #' #' For `type = "time"`, the tibble has a `.pred_time` column. #' diff --git a/R/predict_quantile.R b/R/predict_quantile.R index c2817e48b..039e39ab1 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -27,7 +27,7 @@ predict_quantile.model_fit <- function(object, new_data <- object$spec$method$pred$quantile$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$p <- quantile + object$spec$method$pred$quantile$args$quantile <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) diff --git a/man/augment.Rd b/man/augment.Rd index 100645abd..8e27c0bf4 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -4,7 +4,7 @@ \alias{augment.model_fit} \title{Augment data with predictions} \usage{ -\method{augment}{model_fit}(x, new_data, eval_time = NULL, ...) +\method{augment}{model_fit}(x, new_data, eval_time = NULL, quantile = NULL, ...) } \arguments{ \item{x}{A \code{model_fit} object produced by \code{\link[=fit.model_spec]{fit.model_spec()}} or diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index eaf3364ef..0bb399370 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -70,9 +70,9 @@ the confidence level. In the case where intervals can be produces for class probabilities (or other non-scalar outputs), the columns are named \code{.pred_lower_classlevel} and so on. -For \code{type = "quantile"}, the tibble has a \code{.pred} column, which is +For \code{type = "quantile"}, the tibble has a \code{.pred_quantile} column, which is a list-column. Each list element contains a tibble with columns -\code{.pred} and \code{.quantile} (and perhaps other columns). +\code{.pred_quantile} and \code{.quantile}. For \code{type = "time"}, the tibble has a \code{.pred_time} column. From c9fb07041d3c0733683c47a66f82d9481a92b470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 18 Jul 2024 15:36:11 -0400 Subject: [PATCH 2/3] update news --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index c51afb0e7..a324d1462 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,6 +7,7 @@ * New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). +* Adding `"quantile"` prediction methods for engines using `lm`, `stan`, and `dbarts`. # parsnip 1.2.1 From f0d1c2f52a0d323a77cae28893a4d1645c70b349 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 18 Jul 2024 15:36:16 -0400 Subject: [PATCH 3/3] small updates --- R/linear_reg_data.R | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 709817f05..8268e56e1 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -612,31 +612,33 @@ lm_quantile <- function(object, new_data, quantile = (1:9)/10) { if ( any(quantile == 0.5) ) { preds <- tibble::tibble(.quantile = 1/2, - .pred_quantile =predict(object, new_data), + .pred_quantile = predict(object, new_data), .row = .row) } else { preds <- NULL } + # Convert (1 - level) / 2 to actual quantile since predict.lm() does two-sided + # intervals. For example, using level = 0.95 will give you the intervals + # based on c(0.25, 0.975). To actually get c(0,05, 0.95), we need to make + # an adjustment. + upper_quantile <- quantile[quantile > .5] lower_quantile <- quantile[quantile < .5] if ( length(upper_quantile) > 0 ) { - # Convert (1 - level) / 2 to actual quantile - # so using level = 0.95 will give you the 0.975 value; to actually get 0.95 - # we need to decrease it a bit rev_quant = 1 - upper_quantile upper_adjusted <- 1 + -2 * rev_quant } if ( length(lower_quantile) > 0 ) { - upper_adjusted <- 2 * lower_quantile + lower_adjusted <- 2 * lower_quantile } not_center <- c(lower_quantile, upper_quantile) - adjusted <- c(upper_adjusted, upper_adjusted) + adjusted <- c(lower_adjusted, upper_adjusted) for ( i in seq_along(not_center) ) { tmp_pred <- predict(object, new_data, interval = "prediction", level = adjusted[i]) - if ( not_center[i] > 0.5) { + if ( not_center[i] > 0.5 ) { tmp_pred <- tmp_pred[, "upr"] } else { tmp_pred <- tmp_pred[, "lwr"] @@ -648,13 +650,18 @@ lm_quantile <- function(object, new_data, quantile = (1:9)/10) { preds <- dplyr::bind_rows(preds, tmp_pred) } - preds <- preds[order(preds$.row, preds$.quantile), ] - preds <- + # Now convert to list columns + quant_to_list(preds) +} + +quant_to_list <- function(x) { + x <- x[order(x$.row, x$.quantile), ] + x <- vctrs::vec_split( - x = preds[setdiff(colnames(preds), ".row")], - by = preds$.row + x = x[setdiff(colnames(x), ".row")], + by = x$.row ) - tibble::new_tibble(list(.pred_quantile = preds$val)) + tibble::new_tibble(list(.pred_quantile = x$val)) }