Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slight optimization #115

Merged
merged 5 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- Slight speed-up of `permshap()` by saving calculations for the two special permutations of all 0 and all 1.
- Consequently, the `m_exact` component in the output is reduced by 2.
- Slight speed-up of `permshap()` by optimizing internal function `shapley_formula()`.
- Slight speed-up of `kernelshap()` and `permshap()` for single-output predictions.

# kernelshap 0.4.0

Expand Down
7 changes: 6 additions & 1 deletion R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
verbose = TRUE, ...) {
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
p <- length(feature_names)
stopifnot("Permutation SHAP only supported for up to 14 features" = p <= 14L)
if (p <= 1L) {
stop("Case p = 1 not implemented. Use kernelshap() instead.")
}
if (p > 14L) {
stop("Permutation SHAP only supported for up to 14 features")
}
n <- nrow(X)
bg_n <- nrow(bg_X)
if (!is.null(bg_w)) {
Expand Down
27 changes: 27 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
preds <- align_pred(pred_fun(object, X, ...))

# Aggregate
if (ncol(preds) == 1L) {
return(wrowmean_vector(preds, ngroups = m, w = w))
}
if (is.null(w)) {
return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
}
Expand Down Expand Up @@ -243,6 +246,30 @@ fdummy <- function(x) {
out
}

#' Grouped Means for Single-Column Matrices (adapted from {hstats})
#'
#' Grouped means for matrix with single column over fixed-length groups.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Matrix with one column.
#' @param ngroups Number of subsequent, equals sized groups.
#' @param w Optional vector of case weights of length `NROW(x) / ngroups`.
#' @returns Matrix with one column.
wrowmean_vector <- function(x, ngroups = 1L, w = NULL) {
if (ncol(x) != 1L) {
stop("x must have a single column")
}
nm <- colnames(x)
dim(x) <- c(length(x) %/% ngroups, ngroups)
out <- as.matrix(if (is.null(w)) colMeans(x) else colSums(x * w) / sum(w))
if (!is.null(nm)) {
colnames(out) <- nm
}
out
}

#' Basic Input Checks
#'
#' @noRd
Expand Down
19 changes: 9 additions & 10 deletions R/utils_permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,17 @@ permshap_one <- function(x, v1, object, pred_fun, bg_w, v0, precalc, ...) {
#' @param vz Named vector of vz values.
#' @returns SHAP values organized as (p x K) matrix.
shapley_formula <- function(Z, vz) {
out <- matrix(
nrow = ncol(Z), ncol = ncol(vz), dimnames = list(colnames(Z), colnames(vz))
)
for (v in colnames(Z)) {
s1 <- Z[, v] == 1L
p <- ncol(Z)
out <- matrix(nrow = p, ncol = ncol(vz), dimnames = list(colnames(Z), colnames(vz)))
for (j in seq_len(p)) {
s1 <- Z[, j] == 1L
vz1 <- vz[s1, , drop = FALSE]
Z0 <- Z[s1, , drop = FALSE]
Z0[, v] <- 0L
s0 <- rowpaste(Z0)
L <- rowSums(Z[s1, -j, drop = FALSE]) # how many players are playing with j?
s0 <- rownames(vz1)
substr(s0, j, j) <- "0"
vz0 <- vz[s0, , drop = FALSE]
w <- shapley_weights(ncol(Z), rowSums(Z0))
out[v, ] <- wcolMeans(vz1 - vz0, w = w)
w <- shapley_weights(p, L)
out[j, ] <- wcolMeans(vz1 - vz0, w = w)
}
out
}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The package contains two workhorses to calculate SHAP values for any model:

### Kernel SHAP or permutation SHAP?

Kernel SHAP was introduced as an approximation of permutation SHAP. For up to $8-10$ features, exact calculations are feasible for both algorithms and take the same amount of time. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, permutation SHAP should be preferred in this case, even if they agree for most of the models. A situation where the two approaches give different results: The model has interactions of order three or higher *and* correlated features.
Kernel SHAP was introduced as an approximation of permutation SHAP. For up to $8-10$ features, exact calculations are feasible for both algorithms and take the same amount of time. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, permutation SHAP should be preferred in this case. A situation where the two approaches give different results: The model has interactions of order three or higher *and* correlated features.

### Typical workflow to explain any model

Expand Down
22 changes: 0 additions & 22 deletions tests/testthat/test-permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ test_that("Matrix input gives error with inconsistent feature_names", {
)
})


## Now with case weights
fit <- lm(
Sepal.Length ~ poly(Petal.Width, degree = 2L) * Species, data = iris,
Expand Down Expand Up @@ -138,27 +137,6 @@ test_that("Decomposing a single row works with case weights", {
expect_equal(rowSums(s$S) + s$baseline, preds[1L])
})

fit <- lm(
Sepal.Length ~ poly(Petal.Width, degree = 2L),
data = iris,
weights = Petal.Length
)
x <- "Petal.Width"
preds <- unname(predict(fit, iris))

test_that("Special case p = 1 works with case weights", {
s <- permshap(
fit,
iris[1:5, x, drop = FALSE],
bg_X = iris,
bg_w = iris$Petal.Length,
verbose = FALSE
)

expect_equal(s$baseline, weighted.mean(iris$Sepal.Length, iris$Petal.Length))
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
})

fit <- lm(
Sepal.Length ~ . , data = iris[c(1L, 3L, 4L)], weights = iris$Sepal.Width
)
Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,25 @@ test_that("fdummy() respects factor level order", {
expect_equal(d1, d2[, colnames(d1)])
expect_equal(colnames(d1), rev(colnames(d2)))
})

test_that("wrowmean_vector() works for 1D matrices", {
x2 <- cbind(a = 6:1)
out2 <- wrowmean_vector(x2, ngroups = 2L)
expec <- rowsum(x2, group = rep(1:2, each = 3)) / 3
rownames(expec) <- NULL

expect_error(wrowmean_vector(matrix(1:4, ncol = 2L)))
expect_equal(out2, expec)

expect_equal(wrowmean_vector(x2, ngroups = 3L), cbind(a = c(5.5, 3.5, 1.5)))

# Constant weights have no effect
expect_equal(wrowmean_vector(x2, ngroups = 2L, w = c(1, 1, 1)), out2)
expect_equal(wrowmean_vector(x2, ngroups = 2L, w = c(4, 4, 4)), out2)

# Non-constant weights
a <- weighted.mean(6:4, 1:3)
b <- weighted.mean(3:1, 1:3)
out2 <- wrowmean_vector(x2, ngroups = 2L, w = 1:3)
expect_equal(out2, cbind(a = c(a, b)))
})