Skip to content

Commit

Permalink
Merge pull request #80 from mayer79/na_rm
Browse files Browse the repository at this point in the history
Missing values
  • Loading branch information
mayer79 authored Oct 19, 2023
2 parents ebf6e38 + d66e077 commit de0d70d
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 38 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).
- Missing grid values: `partial_dep()` and `ice()` have received a `na.rm = TRUE` argument that controls if missing values are dropped during grid creation. The default is compatible with earlier releases.

# hstats 0.3.0

Expand Down
17 changes: 12 additions & 5 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ ice <- function(object, ...) {
ice.default <- function(object, v, X, pred_fun = stats::predict,
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 100L, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100L, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
Expand All @@ -69,7 +70,7 @@ ice.default <- function(object, v, X, pred_fun = stats::predict,
# Prepare grid
if (is.null(grid)) {
grid <- multivariate_grid(
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
)
} else {
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
Expand Down Expand Up @@ -142,7 +143,8 @@ ice.ranger <- function(object, v, X,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 100, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100, ...) {
ice.default(
object = object,
v = v,
Expand All @@ -153,6 +155,7 @@ ice.ranger <- function(object, v, X,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
...
)
Expand All @@ -163,7 +166,8 @@ ice.ranger <- function(object, v, X,
ice.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 100L, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100L, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -177,6 +181,7 @@ ice.Learner <- function(object, v, X,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
...
)
Expand All @@ -188,7 +193,8 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
pred_fun = object[["predict_function"]],
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 100, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100, ...) {
ice.default(
object = object[["model"]],
v = v,
Expand All @@ -199,6 +205,7 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
...
)
Expand Down
23 changes: 13 additions & 10 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ partial_dep <- function(object, ...) {
partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 1000L,
w = NULL, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = NULL, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
Expand All @@ -110,7 +110,7 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
# Care about grid
if (is.null(grid)) {
grid <- multivariate_grid(
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
)
} else {
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
Expand All @@ -130,7 +130,7 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
out <- partial_dep.default(
object = object,
v = v,
X = X[BY2$BY %in% b, , drop = FALSE],
X = X[BY2$BY %in% b, , drop = FALSE], # works also when by is NA
pred_fun = pred_fun,
grid = grid,
n_max = n_max,
Expand Down Expand Up @@ -185,8 +185,8 @@ partial_dep.ranger <- function(object, v, X,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 1000L,
w = NULL, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = NULL, ...) {
partial_dep.default(
object = object,
v = v,
Expand All @@ -198,6 +198,7 @@ partial_dep.ranger <- function(object, v, X,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
w = w,
...
Expand All @@ -210,8 +211,8 @@ partial_dep.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 1000L,
w = NULL, ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -226,6 +227,7 @@ partial_dep.Learner <- function(object, v, X,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
w = w,
...
Expand All @@ -238,8 +240,8 @@ partial_dep.explainer <- function(object, v, X = object[["data"]],
pred_fun = object[["predict_function"]],
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 1000L,
w = object[["weights"]], ...) {
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = object[["weights"]], ...) {
partial_dep.default(
object = object[["model"]],
v = v,
Expand All @@ -251,6 +253,7 @@ partial_dep.explainer <- function(object, v, X = object[["data"]],
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
w = w,
...
Expand Down
2 changes: 1 addition & 1 deletion R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ wrowmean <- function(x, ngroups = 1L, w = NULL) {
}
list(
X = X[!X_dup, , drop = FALSE],
w = c(rowsum(w, group = x_not_v, reorder = FALSE))
w = c(rowsum(w, group = x_not_v, reorder = FALSE)) # warning if missing in x_not_v
)
}

Expand Down
35 changes: 24 additions & 11 deletions R/utils_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' of grid values. Set to `0:1` for no trimming.
#' @param strategy How to find grid values of non-discrete numeric columns?
#' Either "uniform" or "quantile", see description of [univariate_grid()].
#' @param na.rm Should missing values be dropped from grid? Default is `TRUE`.
#' @returns A vector or factor of evaluation points.
#' @seealso [multivariate_grid()]
#' @export
Expand All @@ -35,24 +36,29 @@
#' univariate_grid(x, grid_size = 5) # Quantile binning
#' univariate_grid(x, grid_size = 5, strategy = "uniform") # Uniform
univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile")) {
strategy = c("uniform", "quantile"), na.rm = TRUE) {
strategy <- match.arg(strategy)
uni <- unique(z)
if (!is.numeric(z) || length(uni) <= grid_size) {
return(sort(uni))
out <- if (na.rm) sort(uni) else sort(uni, na.last = TRUE)
return(out)
}

# Non-discrete numeric
if (strategy == "quantile") {
p <- seq(trim[1L], trim[2L], length.out = grid_size)
g <- stats::quantile(z, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
return(unique(g))
out <- unique(g)
} else {
# strategy = "uniform" (could use range() if trim = 0:1)
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
# pretty(r, n = grid_size) # Until version 0.2.0
out <- seq(r[1L], r[2L], length.out = grid_size)
}

# strategy = "uniform" (could use range() if trim = 0:1)
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
# pretty(r, n = grid_size) # Until version 0.2.0
seq(r[1L], r[2L], length.out = grid_size)
if (!na.rm && anyNA(z)) {
out <- c(out, NA)
}
return(out)
}

#' Multivariate Grid
Expand All @@ -72,14 +78,17 @@ univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
#' multivariate_grid(iris$Species) # Works also in the univariate case
#' @export
multivariate_grid <- function(x, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile")) {
strategy = c("uniform", "quantile"), na.rm = TRUE) {
strategy <- match.arg(strategy)
p <- NCOL(x)
if (p == 1L) {
if (is.data.frame(x)) {
x <- x[[1L]]
}
return(univariate_grid(x, grid_size = grid_size, trim = trim, strategy = strategy))
out <- univariate_grid(
x, grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
)
return(out)
}
grid_size <- ceiling(grid_size^(1/p)) # take p's root of grid_size
is_mat <- is.matrix(x)
Expand All @@ -89,7 +98,11 @@ multivariate_grid <- function(x, grid_size = 49L, trim = c(0.01, 0.99),
out <- expand.grid(
lapply(
x,
FUN = univariate_grid, grid_size = grid_size, trim = trim, strategy = strategy
FUN = univariate_grid,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm
)
)
if (is_mat) as.matrix(out) else out
Expand Down
6 changes: 6 additions & 0 deletions man/ice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/multivariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/partial_dep.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/univariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit de0d70d

Please sign in to comment.