Skip to content

Commit

Permalink
Merge pull request #115 from ModelOriented/shapley-formula-speedup
Browse files Browse the repository at this point in the history
Slight optimization
  • Loading branch information
mayer79 authored Nov 12, 2023
2 parents def8fe5 + 191b310 commit 8fb3104
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 34 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- Slight speed-up of `permshap()` by saving calculations for the two special permutations of all 0 and all 1.
- Consequently, the `m_exact` component in the output is reduced by 2.
- Slight speed-up of `permshap()` by optimizing internal function `shapley_formula()`.
- Slight speed-up of `kernelshap()` and `permshap()` for single-output predictions.

# kernelshap 0.4.0

Expand Down
7 changes: 6 additions & 1 deletion R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
verbose = TRUE, ...) {
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
p <- length(feature_names)
stopifnot("Permutation SHAP only supported for up to 14 features" = p <= 14L)
if (p <= 1L) {
stop("Case p = 1 not implemented. Use kernelshap() instead.")
}
if (p > 14L) {
stop("Permutation SHAP only supported for up to 14 features")
}
n <- nrow(X)
bg_n <- nrow(bg_X)
if (!is.null(bg_w)) {
Expand Down
27 changes: 27 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
preds <- align_pred(pred_fun(object, X, ...))

# Aggregate
if (ncol(preds) == 1L) {
return(wrowmean_vector(preds, ngroups = m, w = w))
}
if (is.null(w)) {
return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
}
Expand Down Expand Up @@ -243,6 +246,30 @@ fdummy <- function(x) {
out
}

#' Grouped Means for Single-Column Matrices (adapted from {hstats})
#'
#' Grouped means for matrix with single column over fixed-length groups.
#' The values of `x` are split into `ngroups` consecutive blocks of equal
#' length, and one (optionally weighted) mean is returned per block.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Matrix with one column.
#' @param ngroups Number of subsequent, equal-sized groups. `NROW(x)` must be a
#'   multiple of `ngroups`.
#' @param w Optional vector of case weights of length `NROW(x) / ngroups`.
#'   The same weights are applied within every group.
#' @returns Matrix with one column and `ngroups` rows. The column name of `x`
#'   (if any) is kept; row names are dropped.
wrowmean_vector <- function(x, ngroups = 1L, w = NULL) {
  if (!is.matrix(x) || ncol(x) != 1L) {
    stop("x must have a single column")
  }
  n <- length(x)
  if (n %% ngroups != 0L) {
    stop("NROW(x) must be a multiple of ngroups")
  }
  group_size <- n %/% ngroups
  # Without this check, a `w` of wrong length would be silently recycled by
  # `x * w` below and produce wrong means instead of an error.
  if (!is.null(w) && length(w) != group_size) {
    stop("w must have length NROW(x) / ngroups")
  }
  nm <- colnames(x)
  # Reshape so that each group becomes one column. Assigning `dim` also drops
  # the dimnames, which is why the column name was saved above.
  dim(x) <- c(group_size, ngroups)
  # `x * w` recycles `w` down the rows of every column, i.e., within each group
  out <- as.matrix(if (is.null(w)) colMeans(x) else colSums(x * w) / sum(w))
  if (!is.null(nm)) {
    colnames(out) <- nm
  }
  out
}

#' Basic Input Checks
#'
#' @noRd
Expand Down
19 changes: 9 additions & 10 deletions R/utils_permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,17 @@ permshap_one <- function(x, v1, object, pred_fun, bg_w, v0, precalc, ...) {
#' @param vz Named vector of vz values.
#' @returns SHAP values organized as (p x K) matrix.
shapley_formula <- function(Z, vz) {
out <- matrix(
nrow = ncol(Z), ncol = ncol(vz), dimnames = list(colnames(Z), colnames(vz))
)
for (v in colnames(Z)) {
s1 <- Z[, v] == 1L
p <- ncol(Z)
out <- matrix(nrow = p, ncol = ncol(vz), dimnames = list(colnames(Z), colnames(vz)))
for (j in seq_len(p)) {
s1 <- Z[, j] == 1L
vz1 <- vz[s1, , drop = FALSE]
Z0 <- Z[s1, , drop = FALSE]
Z0[, v] <- 0L
s0 <- rowpaste(Z0)
L <- rowSums(Z[s1, -j, drop = FALSE]) # how many players are playing with j?
s0 <- rownames(vz1)
substr(s0, j, j) <- "0"
vz0 <- vz[s0, , drop = FALSE]
w <- shapley_weights(ncol(Z), rowSums(Z0))
out[v, ] <- wcolMeans(vz1 - vz0, w = w)
w <- shapley_weights(p, L)
out[j, ] <- wcolMeans(vz1 - vz0, w = w)
}
out
}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The package contains two workhorses to calculate SHAP values for any model:

### Kernel SHAP or permutation SHAP?

Kernel SHAP was introduced as an approximation of permutation SHAP. For up to $8-10$ features, exact calculations are feasible for both algorithms and take the same amount of time. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, permutation SHAP should be preferred in this case, even if they agree for most of the models. A situation where the two approaches give different results: The model has interactions of order three or higher *and* correlated features.
Kernel SHAP was introduced as an approximation of permutation SHAP. For up to 8–10 features, exact calculations are feasible for both algorithms and take the same amount of time. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, permutation SHAP should be preferred in this case. The two approaches give different results when the model has interactions of order three or higher *and* correlated features.

### Typical workflow to explain any model

Expand Down
22 changes: 0 additions & 22 deletions tests/testthat/test-permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ test_that("Matrix input gives error with inconsistent feature_names", {
)
})


## Now with case weights
fit <- lm(
Sepal.Length ~ poly(Petal.Width, degree = 2L) * Species, data = iris,
Expand Down Expand Up @@ -138,27 +137,6 @@ test_that("Decomposing a single row works with case weights", {
expect_equal(rowSums(s$S) + s$baseline, preds[1L])
})

fit <- lm(
Sepal.Length ~ poly(Petal.Width, degree = 2L),
data = iris,
weights = Petal.Length
)
x <- "Petal.Width"
preds <- unname(predict(fit, iris))

test_that("Special case p = 1 works with case weights", {
s <- permshap(
fit,
iris[1:5, x, drop = FALSE],
bg_X = iris,
bg_w = iris$Petal.Length,
verbose = FALSE
)

expect_equal(s$baseline, weighted.mean(iris$Sepal.Length, iris$Petal.Length))
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
})

fit <- lm(
Sepal.Length ~ . , data = iris[c(1L, 3L, 4L)], weights = iris$Sepal.Width
)
Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,25 @@ test_that("fdummy() respects factor level order", {
expect_equal(d1, d2[, colnames(d1)])
expect_equal(colnames(d1), rev(colnames(d2)))
})

test_that("wrowmean_vector() works for 1D matrices", {
  # Input: single-column matrix holding 6, 5, ..., 1 in column "a"
  m <- cbind(a = 6:1)

  # Reference for two groups of three rows, computed via rowsum()
  reference <- rowsum(m, group = rep(1:2, each = 3)) / 3
  rownames(reference) <- NULL
  unweighted <- wrowmean_vector(m, ngroups = 2L)

  # More than one column is rejected
  expect_error(wrowmean_vector(matrix(1:4, ncol = 2L)))

  # Unweighted means for two and for three groups
  expect_equal(unweighted, reference)
  expect_equal(
    wrowmean_vector(m, ngroups = 3L),
    cbind(a = c(5.5, 3.5, 1.5))
  )

  # Constant weights have no effect
  expect_equal(wrowmean_vector(m, ngroups = 2L, w = c(1, 1, 1)), unweighted)
  expect_equal(wrowmean_vector(m, ngroups = 2L, w = c(4, 4, 4)), unweighted)

  # Non-constant weights: compare against weighted.mean() per group
  weighted <- wrowmean_vector(m, ngroups = 2L, w = 1:3)
  first_group <- weighted.mean(6:4, 1:3)
  second_group <- weighted.mean(3:1, 1:3)
  expect_equal(weighted, cbind(a = c(first_group, second_group)))
})

0 comments on commit 8fb3104

Please sign in to comment.