Add quantile regression mode #1209

Merged on Oct 11, 2024 (29 commits)
Showing changes from 18 commits
Commits (29):
e7ce07b
add a quantile regression mode to test with
topepo Aug 5, 2024
cbf636f
update type checkers
Aug 29, 2024
6bec5cb
avoid confusion with global all_models object
Aug 29, 2024
5f6db8c
add quantile_level argument to set_mode()
Aug 29, 2024
d753196
initial data for quantreg
Aug 29, 2024
5176e86
some initial tests
Aug 29, 2024
abcd97d
fix some issues
Aug 29, 2024
a153df0
enable quantile prediction
Aug 29, 2024
6168556
tests for quantreg
topepo Sep 5, 2024
3bdb471
Quantile predictions output constructor (#1191)
dajmcdon Sep 13, 2024
bef131b
quantile regression updates for new hardhat model (#1207)
topepo Sep 25, 2024
861a64a
Change to `quantile` argument to `quantile levels` (#1208)
topepo Sep 26, 2024
db551ff
Merge branch 'main' into quantile-mode
topepo Sep 26, 2024
44ebdf7
post conflict merge updates
topepo Sep 26, 2024
481dbc3
update news
topepo Sep 26, 2024
69e754e
version bump and fix typo
topepo Sep 26, 2024
9ebcae0
revert GHA branches
topepo Sep 26, 2024
2fcedb9
small bug fix
Sep 27, 2024
58025ca
Apply suggestions from code review
topepo Oct 10, 2024
a870db4
don't export median
topepo Oct 10, 2024
f5442a7
add call arg
topepo Oct 10, 2024
76d4ff6
Merge branch 'quantile-mode' of https://github.com/tidymodels/parsnip…
topepo Oct 10, 2024
4db8ca6
added documentation on model
topepo Oct 10, 2024
603f47a
add mode
topepo Oct 10, 2024
d8cae62
convert error to warning
topepo Oct 10, 2024
07513b6
remove rankdeficient
topepo Oct 10, 2024
b8fe3a1
added skip
topepo Oct 10, 2024
d88543c
add deprecated `quantile` arg back in
topepo Oct 10, 2024
46fa0db
remove numeric prediction
topepo Oct 10, 2024
11 changes: 6 additions & 5 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
-Version: 1.2.1.9002
+Version: 1.2.1.9003
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
@@ -25,7 +25,7 @@ Imports:
ggplot2,
globals,
glue,
-hardhat (>= 1.4.0),
+hardhat (>= 1.4.0.9002),
lifecycle,
magrittr,
pillar,
@@ -40,8 +40,8 @@ Imports:
vctrs (>= 0.6.0),
withr
Suggests:
-C50,
bench,
+C50,
covr,
dials (>= 1.1.0),
earth,
@@ -69,6 +69,9 @@ Suggests:
xgboost (>= 1.5.0.1)
VignetteBuilder:
knitr
+Remotes:
+r-lib/sparsevctrs,
+tidymodels/hardhat
ByteCompile: true
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
@@ -79,6 +82,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
-Remotes:
-r-lib/sparsevctrs
RoxygenNote: 7.3.2
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -264,6 +264,7 @@ export(make_classes)
export(make_engine_list)
export(make_seealso_list)
export(mars)
export(matrix_to_quantile_pred)
export(max_mtry_formula)
export(maybe_data_frame)
export(maybe_matrix)
@@ -402,6 +403,7 @@ importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,model.offset)
24 changes: 20 additions & 4 deletions NEWS.md
@@ -1,25 +1,41 @@
# parsnip (development version)

## New Features

* A new model mode (`"quantile regression"`) was added. Including:
* A `linear_reg()` engine for `"quantreg"`.
* Predictions are encoded via a custom vector type. See [hardhat::quantile_pred()].
* Predicted quantile levels are designated when the new mode is specified. See `?set_mode`.

* `fit_xy()` can now take dgCMatrix input for `x` argument (#1121).

* `fit_xy()` can now take sparse tibbles as data values (#1165).

* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).

-* Transitioned package errors and warnings to use cli (#1147 and #1148 by
-@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
-#1161, #1081).
+* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

## Other Changes

* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).

* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).

* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).

-* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
## Bug Fixes

* Ensure that `knit_engine_docs()` has the required packages installed (#1156).

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).

## Breaking Change

* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.

* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.


# parsnip 1.2.1

* Added a missing `tidy()` method for survival analysis glmnet models (#1086).
21 changes: 16 additions & 5 deletions R/aaa-import-standalone-types-check.R
@@ -1,7 +1,3 @@
-# Standalone file: do not edit by hand
-# Source: <https://github.com/r-lib/rlang/blob/main/R/standalone-types-check.R>
-# ----------------------------------------------------------------------
-#
# ---
# repo: r-lib/rlang
# file: standalone-types-check.R
@@ -13,6 +9,9 @@
#
# ## Changelog
#
# 2024-08-15:
# - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724)
#
# 2023-03-13:
# - Improved error messages of number checkers (@teunbrand)
# - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich).
@@ -461,15 +460,28 @@ check_formula <- function(x,

# Vectors -----------------------------------------------------------------

# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE`

check_character <- function(x,
...,
allow_na = TRUE,
allow_null = FALSE,
arg = caller_arg(x),
call = caller_env()) {

if (!missing(x)) {
if (is_character(x)) {
if (!allow_na && any(is.na(x))) {
abort(
sprintf("`%s` can't contain NA values.", arg),
arg = arg,
call = call
)
}

return(invisible(NULL))
}

if (allow_null && is_null(x)) {
return(invisible(NULL))
}
@@ -479,7 +491,6 @@ check_character <- function(x,
x,
"a character vector",
...,
-allow_na = FALSE,
allow_null = allow_null,
arg = arg,
call = call
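
The `allow_na` behaviour that this hunk adds to `check_character()` can be illustrated with a minimal base R sketch. `check_chr()` is a hypothetical stand-in, not the rlang standalone helper; it only mirrors the new branch that rejects `NA` values when `allow_na = FALSE`.

```r
# Hypothetical stand-in for check_character(); base R only.
check_chr <- function(x, allow_na = TRUE) {
  if (!is.character(x)) {
    stop("`x` must be a character vector.")
  }
  # Branch mirrored from the diff: NA values are rejected only on request.
  if (!allow_na && anyNA(x)) {
    stop("`x` can't contain NA values.")
  }
  invisible(NULL)
}

# Accepted, because allow_na defaults to TRUE.
check_chr(c("a", NA))

# Rejected once the caller opts out of NA values; capture the message.
na_msg <- tryCatch(
  check_chr(c("a", NA), allow_na = FALSE),
  error = function(e) conditionMessage(e)
)
```

Keeping `allow_na = TRUE` as the default means callers that never pass the argument see no change in behaviour.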
8 changes: 4 additions & 4 deletions R/aaa_models.R
@@ -1,6 +1,6 @@
# Initialize model environments

-all_modes <- c("classification", "regression", "censored regression")
+all_modes <- c("classification", "regression", "censored regression", "quantile regression")

# ------------------------------------------------------------------------------

@@ -194,9 +194,9 @@ stop_missing_engine <- function(cls, call) {
)
}

-check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) {
-all_modes <- get_from_env(paste0(cls, "_modes"))
-if (!(mode %in% all_modes)) {
+check_mode_for_new_engine <- function(cls, eng, mode) {
+model_modes <- get_from_env(paste0(cls, "_modes"))
+if (!(mode %in% model_modes)) {
cli::cli_abort(
"{.val {mode}} is not a known mode for model {.fn {cls}}.",
call = call
17 changes: 17 additions & 0 deletions R/aaa_quantiles.R
@@ -0,0 +1,17 @@
#' Reformat quantile predictions
#'
#' @param x A matrix of predictions with rows as samples and columns as quantile
#' levels.
#' @param object A parsnip `model_fit` object from a quantile regression model.
#' @keywords internal
#' @export
matrix_to_quantile_pred <- function(x, object) {
if (!is.matrix(x)) {
x <- as.matrix(x)
}
rownames(x) <- NULL
n_pred_quantiles <- ncol(x)
quantile_levels <- object$spec$quantile_levels

tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)))
}
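
As a rough illustration of what `matrix_to_quantile_pred()` returns, the sketch below applies the same wrapping step to a small hand-made matrix. It assumes a hardhat version that provides `quantile_pred()` (the DESCRIPTION change in this PR requires >= 1.4.0.9002); `pred_mat` and `q_levels` are made-up values, not output from a fitted model.

```r
library(hardhat)
library(tibble)

# Hypothetical predictions: 3 samples (rows) by 2 quantile levels (columns).
pred_mat <- matrix(c(10, 11, 12, 20, 21, 22), nrow = 3)
q_levels <- c(0.25, 0.75)

# One vector element per row; each element carries both quantile values.
qp <- quantile_pred(pred_mat, quantile_levels = q_levels)

# Same wrapping step as in matrix_to_quantile_pred(), with nrow given explicitly.
res <- new_tibble(list(.pred_quantile = qp), nrow = nrow(pred_mat))
```

The result has one row per sample and a single `.pred_quantile` column, so it can sit alongside other prediction columns regardless of how many quantile levels were requested.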
23 changes: 20 additions & 3 deletions R/arguments.R
@@ -49,6 +49,8 @@ check_eng_args <- function(args, obj, core_args) {
#' set_args(mtry = 3, importance = TRUE) %>%
#' set_mode("regression")
#'
#' linear_reg() %>%
#' set_mode("quantile regression", quantile_levels = c(0.2, 0.5, 0.8))
#' @export
set_args <- function(object, ...) {
UseMethod("set_args")
@@ -89,12 +91,17 @@ set_args.default <- function(object,...) {

#' @rdname set_args
#' @export
-set_mode <- function(object, mode) {
+set_mode <- function(object, mode, ...) {
UseMethod("set_mode")
}

#' @rdname set_args
#' @param quantile_levels A vector of values between zero and one (only for the
#' `"quantile regression"` mode); otherwise, it is `NULL`. The model uses these
#' values to appropriately train quantile regression models to make predictions
#' for these values (e.g., `quantile_levels = 0.5` is the median).
#' @export
-set_mode.model_spec <- function(object, mode) {
+set_mode.model_spec <- function(object, mode, quantile_levels = NULL, ...) {
cls <- class(object)[1]
if (rlang::is_missing(mode)) {
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
@@ -111,11 +118,21 @@ set_mode.model_spec <- function(object, mode) {

object$mode <- mode
object$user_specified_mode <- TRUE
if (mode == "quantile regression") {
hardhat::check_quantile_levels(quantile_levels)
} else {
if (!is.null(quantile_levels)) {
cli::cli_abort("{.arg quantile_levels} is only used when the mode is
{.val quantile regression}.")
}
}

object$quantile_levels <- quantile_levels
object
}

#' @export
-set_mode.default <- function(object, mode) {
+set_mode.default <- function(object, mode, ...) {
error_set_object(object, func = "set_mode")

invisible(FALSE)
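
The validation branch added to `set_mode.model_spec()` in this file can be approximated in base R as follows. `validate_levels()` is a hypothetical helper written for this note, and its checks are only an approximation of `hardhat::check_quantile_levels()`, whose exact rules are not shown in this diff.

```r
# Hypothetical stand-in for the branch added to set_mode.model_spec();
# base R checks replace hardhat::check_quantile_levels() here.
validate_levels <- function(mode, quantile_levels = NULL) {
  if (mode == "quantile regression") {
    ok <- is.numeric(quantile_levels) &&
      length(quantile_levels) > 0 &&
      !anyNA(quantile_levels) &&
      all(quantile_levels >= 0 & quantile_levels <= 1)
    if (!ok) {
      stop("`quantile_levels` must be numeric values between zero and one.")
    }
  } else if (!is.null(quantile_levels)) {
    # Any other mode must leave quantile_levels unset.
    stop("`quantile_levels` is only used when the mode is \"quantile regression\".")
  }
  invisible(quantile_levels)
}

# Both of these pass silently.
validate_levels("quantile regression", c(0.2, 0.5, 0.8))
validate_levels("regression")

# Supplying levels for another mode is an error; capture the message.
wrong_mode_msg <- tryCatch(
  validate_levels("regression", 0.5),
  error = function(e) conditionMessage(e)
)
```

Failing early in `set_mode()` means a mismatched mode and `quantile_levels` pair is reported when the specification is built rather than at fit time.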
9 changes: 8 additions & 1 deletion R/fit.R
@@ -176,6 +176,10 @@ fit.model_spec <-
eval_env$formula <- formula
eval_env$weights <- wts

if (!is.null(object$quantile_levels)) {
eval_env$quantile_levels <- object$quantile_levels
}

data <- materialize_sparse_tibble(data, object, "data")

fit_interface <-
@@ -187,7 +191,6 @@ fit.model_spec <-
with a spark data object."
)


# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)

@@ -295,6 +298,10 @@ fit_xy.model_spec <-
eval_env$y_var <- y_var
eval_env$weights <- weights_to_numeric(case_weights, object)

if (!is.null(object$quantile_levels)) {
eval_env$quantile_levels <- object$quantile_levels
}

# TODO case weights: pass in eval_env not individual elements
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
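
Copying `quantile_levels` into `eval_env` matters because the quantreg engine registers its default as the unevaluated expression `tau = expr(quantile_levels)`. The sketch below shows the general mechanism with rlang; it is an assumption about how the lookup plays out, not a copy of parsnip's internal evaluation code, and `tau_default` is a name made up for this note.

```r
library(rlang)

# Stand-in for the evaluation environment that fit() populates.
eval_env <- env()
eval_env$quantile_levels <- c(0.1, 0.5, 0.9)

# The engine default is stored unevaluated, as in `tau = expr(quantile_levels)`.
tau_default <- expr(quantile_levels)

# Evaluating the stored expression in that environment yields the levels.
tau <- eval_tidy(tau_default, env = eval_env)
```

Without the assignment into `eval_env`, the symbol `quantile_levels` would have nothing to resolve to when the engine call is evaluated.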

65 changes: 65 additions & 0 deletions R/linear_reg_data.R
@@ -1,6 +1,7 @@
set_new_model("linear_reg")

set_model_mode("linear_reg", "regression")
set_model_mode("linear_reg", "quantile regression")

# ------------------------------------------------------------------------------

@@ -582,3 +583,67 @@ set_pred(
)
)

# ------------------------------------------------------------------------------

set_model_engine("linear_reg", "quantile regression", "quantreg")
set_dependency("linear_reg", "quantreg", "quantreg")

set_fit(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
value = list(
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "quantreg", fun = "rq"),
defaults = list(tau = expr(quantile_levels))
)
)

set_encoding(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
options = list(
predictor_indicators = "traditional",
compute_intercept = TRUE,
remove_intercept = TRUE,
allow_sparse_x = FALSE
)
)

set_pred(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
type = "numeric",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "response",
rankdeficient = "simple"
)
)
)

set_pred(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
type = "quantile",
value = list(
pre = NULL,
post = matrix_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data)
)
)
)
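
Taken together, the registrations above should allow a workflow along the following lines. This is a sketch under assumptions: it needs a parsnip build that includes this PR, the matching hardhat version, and the quantreg package installed. The formula and the `mtcars` data are arbitrary choices for illustration.

```r
library(parsnip)

# Specify the new mode; quantile levels are fixed at specification time.
qr_spec <- linear_reg() %>%
  set_engine("quantreg") %>%
  set_mode("quantile regression", quantile_levels = c(0.1, 0.5, 0.9))

# Fit through the formula interface registered in set_fit().
qr_fit <- fit(qr_spec, mpg ~ wt + hp, data = mtcars)

# Request the "quantile" prediction type registered in set_pred().
qr_preds <- predict(qr_fit, new_data = head(mtcars, 3), type = "quantile")
```

Because the levels live on the specification, `predict()` takes no quantile argument here; the fitted model already knows which quantiles it was trained for.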
2 changes: 1 addition & 1 deletion R/parsnip-package.R
@@ -21,7 +21,7 @@
#' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef
#' @importFrom stats delete.response model.frame model.matrix model.offset
#' @importFrom stats model.response model.weights na.omit na.pass predict qnorm
-#' @importFrom stats qt quantile setNames terms update
+#' @importFrom stats qt quantile setNames terms update median
#' @importFrom tibble as_tibble is_tibble tibble
#' @importFrom tidyr gather
#' @importFrom utils capture.output getFromNamespace globalVariables head
6 changes: 4 additions & 2 deletions R/predict.R
@@ -201,12 +201,14 @@ check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
regression = "numeric",
classification = "class",
"censored regression" = "time",
"quantile regression" = "quantile",
cli::cli_abort(
-"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
+"{.arg type} should be one of {.val {all_modes}}.",
call = call
)
)
}

if (!(type %in% pred_types))
cli::cli_abort(
"{.arg type} should be one of {.or {.arg {pred_types}}}.",
@@ -373,7 +375,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

-other_args <- c("interval", "level", "std_error", "quantile",
+other_args <- c("interval", "level", "std_error", "quantile_levels",
"time", "eval_time", "increasing")

eval_time_types <- c("survival", "hazard")