Skip to content

Commit

Permalink
Merge pull request #112 from ModelOriented/speedup-permshap
Browse files Browse the repository at this point in the history
Slight speed-up of permshap()
  • Loading branch information
mayer79 authored Nov 11, 2023
2 parents 64c65d7 + c0bc2e1 commit 086b6e9
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.4.0
Version: 0.4.1
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
person("David", "Watson", , "[email protected]", role = "aut"),
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# kernelshap 0.4.1

## Other changes

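- Slight speed-up of `permshap()` by skipping the model evaluations for the two special rows of `Z` (all 0 and all 1): their values equal the average prediction on the background data and the prediction on the row being explained, both of which are calculated anyway.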
- Consequently, the `m_exact` component in the output is reduced by 2.

# kernelshap 0.4.0

## Major changes
Expand Down
10 changes: 7 additions & 3 deletions R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
message(txt)
}

# Baseline and predictions on explanation data (latter not required in algo)
# Baseline and predictions on explanation data
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
Expand All @@ -81,10 +81,10 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,

# Precalculations that are identical for each row to be explained
Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
m_exact <- nrow(Z)
m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row
precalc <- list(
Z = Z,
Z_code = rowpaste(Z),
Z_code = rowpaste(Z),
bg_X_rep = bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
)

Expand All @@ -97,9 +97,11 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
parallel_args <- c(list(i = seq_len(n)), parallel_args)
res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
v0 = v0,
precalc = precalc,
...
)
Expand All @@ -111,9 +113,11 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
for (i in seq_len(n)) {
res[[i]] <- permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
v0 = v0,
precalc = precalc,
...
)
Expand Down
16 changes: 10 additions & 6 deletions R/utils_permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,25 @@ shapley_weights <- function(p, ell) {
#' @keywords internal
#'
#' @inheritParams permshap
#' @param v1 Prediction of `x`.
#' @param v0 Average prediction on background data.
#' @param x A single row to be explained.
#' @param precalc A list with precalculated values that are identical for all rows.
#' @return A (p x K) matrix of SHAP values.
permshap_one <- function(x, object, pred_fun, bg_w, precalc, ...) {
vz <- get_vz( # (m_ex x K)
X = x[rep(1L, times = nrow(precalc[["bg_X_rep"]])), , drop = FALSE], # (m_ex*n_bg x p)
bg = precalc[["bg_X_rep"]], # (m_ex*n_bg x p)
Z = precalc[["Z"]], # (m_ex x p)
permshap_one <- function(x, v1, object, pred_fun, bg_w, v0, precalc, ...) {
Z <- precalc[["Z"]] # ((m_ex+2) x p)
vz <- get_vz( # (m_ex x K)
X = x[rep(1L, times = nrow(precalc[["bg_X_rep"]])), , drop = FALSE], # (m_ex*n_bg x p)
bg = precalc[["bg_X_rep"]], # (m_ex*n_bg x p)
Z = Z[2:(nrow(Z) - 1L), , drop = FALSE], # (m_ex x p)
object = object,
pred_fun = pred_fun,
w = bg_w,
...
)
vz <- rbind(v0, vz, v1) # we add the cheaply calculated v0 and v1
rownames(vz) <- precalc[["Z_code"]]
shapley_formula(precalc[["Z"]], vz)
shapley_formula(Z, vz = vz)
}

#' Shapley's formula
Expand Down
2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Kernel SHAP",
Version = "0.4.0",
Version = "0.4.1",
Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017),
and Covert and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>.
Furthermore, for up to 14 features, exact permutation SHAP values can be calculated.
Expand Down

0 comments on commit 086b6e9

Please sign in to comment.