Skip to content

Commit

Permalink
Merge pull request #132 from ModelOriented/additive-shap
Browse files Browse the repository at this point in the history
Additive shap
  • Loading branch information
mayer79 authored May 26, 2024
2 parents 559ba5c + d1c6958 commit 9808196
Show file tree
Hide file tree
Showing 23 changed files with 387 additions and 176 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.4.2
Version: 0.5.0
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
person("David", "Watson", , "[email protected]", role = "aut"),
Expand All @@ -19,7 +19,7 @@ Depends:
R (>= 3.2.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Imports:
foreach,
stats,
Expand Down
4 changes: 1 addition & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ S3method(kernelshap,ranger)
S3method(permshap,default)
S3method(permshap,ranger)
S3method(print,kernelshap)
S3method(print,permshap)
S3method(summary,kernelshap)
S3method(summary,permshap)
export(additive_shap)
export(is.kernelshap)
export(is.permshap)
export(kernelshap)
export(permshap)
importFrom(foreach,"%dopar%")
23 changes: 23 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
# kernelshap 0.5.0

## New features

New additive explainer `additive_shap()` that works for models fitted via

- `lm()`,
- `glm()`,
- `mgcv::gam()`,
- `mgcv::bam()`,
- `gam::gam()`,
- `survival::coxph()`,
- `survival::survreg()`.

The explainer uses `predict(..., type = "terms")`, a beautiful trick
used in `fastshap::explain.lm()`. The results will be identical to those returned by `kernelshap()` and `permshap()`, but exponentially faster to compute. Thanks David Watson for the great idea discussed in [#130](https://github.com/ModelOriented/kernelshap/issues/130).

## User visible changes

- `permshap()` now returns an object of class "kernelshap" to reduce the number of redundant methods.
- To distinguish which algorithm has generated the "kernelshap" object, the outputs of `kernelshap()`, `permshap()` (and `additive_shap()`) got an element "algorithm".
- `is.permshap()` has been removed.

# kernelshap 0.4.2

## API
Expand Down
95 changes: 95 additions & 0 deletions R/additive_shap.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#' Additive SHAP
#'
#' Exact additive SHAP assuming feature independence. The implementation
#' works for models fitted via
#' - [lm()],
#' - [glm()],
#' - [mgcv::gam()],
#' - [mgcv::bam()],
#' - [gam::gam()],
#' - [survival::coxph()], and
#' - [survival::survreg()].
#'
#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`.
#' Models with interactions (specified via `:` or `*`), or with terms of
#' multiple features like `log(x1/x2)` are not supported.
#'
#' @inheritParams kernelshap
#' @param X Dataframe with rows to be explained. Will be used like
#' `predict(object, newdata = X, type = "terms")`.
#' @param ... Currently unused.
#' @returns
#' An object of class "kernelshap" with the following components:
#' - `S`: \eqn{(n \times p)} matrix with SHAP values.
#' - `X`: Same as input argument `X`.
#' - `baseline`: The baseline.
#' - `exact`: `TRUE`.
#' - `txt`: Summary text.
#' - `predictions`: Vector with predictions of `X` on the scale of "terms".
#' - `algorithm`: "additive_shap".
#' @export
#' @examples
#' # MODEL ONE: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- additive_shap(fit, head(iris))
#' s
#'
#' # MODEL TWO: More complicated (but not very clever) formula
#' fit <- lm(
#' Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
#' data = iris
#' )
#' s <- additive_shap(fit, head(iris))
#' s
# Exact additive SHAP via predict(..., type = "terms").
# One SHAP value per feature and row of X; requires a purely additive model
# (no interactions, no terms combining multiple features).
additive_shap <- function(object, X, verbose = TRUE, ...) {
# Only model classes whose predict() supports type = "terms" are allowed.
stopifnot(
inherits(object, c("lm", "glm", "gam", "bam", "Gam", "coxph", "survreg"))
)
# Any term of order > 1 is an interaction -> per-feature attribution breaks down.
if (any(attr(stats::terms(object), "order") > 1)) {
stop("Additive SHAP not appropriate for models with interactions.")
}

txt <- "Exact additive SHAP via predict(..., type = 'terms')"
if (verbose) {
message(txt)
}

# Matrix of per-term contributions; one column per model term.
S <- stats::predict(object, newdata = X, type = "terms")
rownames(S) <- NULL

# Baseline value
# predict(..., type = "terms") attaches the additive offset as attribute
# "constant"; it can be missing for some classes (hence the 0 fallback).
b <- as.vector(attr(S, "constant"))
if (is.null(b)) {
b <- 0
}

# Which columns of X are used in each column of S?
# all.vars(reformulate(term_label)) recovers the variable(s) inside each
# term, e.g. "poly(x, 2)" -> "x". Terms touching > 1 feature are rejected.
s_names <- colnames(S)
cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
if (any(lengths(cols_used) > 1L)) {
stop("The formula contains terms with multiple features (not supported).")
}

# Collapse all columns in S using the same column in X and rename accordingly
# levels = colnames(X) fixes the column order of the result to match X;
# drop = TRUE removes features of X not present in the model.
mapping <- split(
s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE
)
# Sum all term columns belonging to the same feature; the names of
# `mapping` become the column names of the collapsed matrix.
S <- do.call(
cbind,
lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
)

# Same structure as kernelshap()/permshap() output, so all "kernelshap"
# methods apply. predictions reconstructs baseline + sum of SHAP values.
structure(
list(
S = S,
X = X,
baseline = b,
exact = TRUE,
txt = txt,
predictions = b + rowSums(S),
algorithm = "additive_shap"
),
class = "kernelshap"
)
}
4 changes: 3 additions & 1 deletion R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
#' - `exact`: Logical flag indicating whether calculations are exact or not.
#' - `txt`: Summary text.
#' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
#' - `algorithm`: "kernelshap".
#' @references
#' 1. Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model
#' predictions. Proceedings of the 31st International Conference on Neural
Expand Down Expand Up @@ -318,7 +319,8 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
prop_exact = prop_exact,
exact = exact || trunc(p / 2) == hybrid_degree,
txt = txt,
predictions = v1
predictions = v1,
algorithm = "kernelshap"
)
class(out) <- "kernelshap"
out
Expand Down
52 changes: 4 additions & 48 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@ print.kernelshap <- function(x, n = 2L, ...) {
invisible(x)
}

#' Prints "permshap" Object
#'
#' @param x An object of class "permshap".
#' @inheritParams print.kernelshap
#' @inherit print.kernelshap return
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[, -1])
#' s
#' @seealso [permshap()]
# S3 print method for "permshap" objects: delegates straight to the
# "kernelshap" printer, as both objects carry the same components.
print.permshap <- function(x, n = 2L, ...) {
print.kernelshap(x, n = n, ...)
}

#' Summarizes "kernelshap" Object
#'
#' @param object An object of class "kernelshap".
Expand Down Expand Up @@ -67,7 +52,10 @@ summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
"\n - m/iter:", getElement(object, "m")
)
}
cat("\n - m_exact:", getElement(object, "m_exact"))
m_exact <- getElement(object, "m_exact")
if (!is.null(m_exact)) {
cat("\n - m_exact:", m_exact)
}
if (!compact) {
cat("\n\nSHAP values of first observations:\n")
print(head_list(S, n = n))
Expand All @@ -79,21 +67,6 @@ summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
invisible(object)
}

#' Summarizes "permshap" Object
#'
#' @param object An object of class "permshap".
#' @inheritParams summary.kernelshap
#' @inherit summary.kernelshap return
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[, -1])
#' summary(s)
#' @seealso [permshap()]
# S3 summary method for "permshap" objects: delegates to the "kernelshap"
# summary, as both objects carry the same components.
summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
summary.kernelshap(object, compact = compact, n = n, ...)
}

#' Check for kernelshap
#'
#' Is object of class "kernelshap"?
Expand All @@ -110,20 +83,3 @@ summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
# Predicate: does `object` inherit from class "kernelshap"?
is.kernelshap <- function(object) {
  inherits(object, what = "kernelshap")
}

#' Check for permshap
#'
#' Is object of class "permshap"?
#'
#' @param object An R object.
#' @returns `TRUE` if `object` is of class "permshap", and `FALSE` otherwise.
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- permshap(fit, iris[1:2, -1], bg_X = iris[, -1])
#' is.permshap(s)
#' is.permshap("a")
#' @seealso [kernelshap()]
# Predicate: does `object` inherit from class "permshap"?
is.permshap <- function(object) {
  inherits(object, what = "permshap")
}
8 changes: 5 additions & 3 deletions R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#'
#' @inheritParams kernelshap
#' @returns
#' An object of class "permshap" with the following components:
#' An object of class "kernelshap" with the following components:
#' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has
#' dimension \eqn{K > 1}, a list of \eqn{K} such matrices.
#' - `X`: Same as input argument `X`.
Expand All @@ -16,6 +16,7 @@
#' (currently `TRUE`).
#' - `txt`: Summary text.
#' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
#' - `algorithm`: "permshap".
#' @references
#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
#' predictions with feature contributions. Knowledge and Information Systems 41, 2014.
Expand Down Expand Up @@ -141,9 +142,10 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
m_exact = m_exact,
exact = TRUE,
txt = txt,
predictions = v1
predictions = v1,
algorithm = "permshap"
)
class(out) <- "permshap"
class(out) <- "kernelshap"
out
}

Expand Down
3 changes: 2 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ case_p1 <- function(n, feature_names, v0, v1, X, verbose) {
prop_exact = 1,
exact = TRUE,
txt = txt,
predictions = v1
predictions = v1,
algorithm = "kernelshap"
)
class(out) <- "kernelshap"
out
Expand Down
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

## Overview

The package contains two workhorses to calculate SHAP values for any model:
The package contains two workhorses to calculate SHAP values for *any* model:

- `permshap()`: Exact permutation SHAP algorithm of [1]. Available for up to $p=14$ features.
- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. By default, exact Kernel SHAP is used for up to $p=8$ features, and an almost exact hybrid algorithm otherwise.

Furthermore, the function `additive_shap()` produces SHAP values for additive models fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`,
`survival::coxph()`, or `survival::survreg()`. It is exponentially faster than `permshap()` and `kernelshap()`, with identical results.

### Kernel SHAP or permutation SHAP?

Kernel SHAP has been introduced in [2] as an approximation of permutation SHAP [1]. For up to ten features, exact calculations are realistic for both algorithms. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, the latter should be preferred in this case, even if the results are often very similar.
Expand All @@ -38,6 +41,7 @@ If the training data is small, use the full training data. In cases with a natur
- Factor-valued predictions are automatically turned into one-hot-encoded columns.
- Case weights are supported via the argument `bg_w`.
- By changing the defaults in `kernelshap()`, the iterative pure sampling approach in [3] can be enforced.
- The `additive_shap()` explainer is easier to use: Only the model and `X` are required.

## Installation

Expand Down Expand Up @@ -215,6 +219,18 @@ sv_dependence(ps, xvars)

![](man/figures/README-nn-dep.svg)

### Additive SHAP

The additive explainer extracts the additive contribution of each feature from a model of suitable class.

```r
fit <- lm(log(price) ~ log(carat) + color + clarity + cut, data = diamonds)
shap_values <- additive_shap(fit, diamonds) |>
shapviz()
sv_importance(shap_values)
sv_dependence(shap_values, v = "carat", color_var = NULL)
```

### Multi-output models

{kernelshap} supports multivariate predictions like:
Expand Down
67 changes: 67 additions & 0 deletions backlog/test_additive_shap.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Some tests that need contributed packages
# NOTE(review): assumes additive_shap() is already in scope, e.g. via
# devtools::load_all() — the kernelshap package itself is not attached here.

library(mgcv)
library(gam)
library(survival)
library(splines)
library(testthat)

# Additive formulas the explainer must handle (single feature per term).
formulas_ok <- list(
Sepal.Length ~ Sepal.Width + Petal.Width + Species,
Sepal.Length ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2)
)

# Formulas that must be rejected: interactions, or terms mixing features.
formulas_bad <- list(
Sepal.Length ~ Species * Petal.Length,
Sepal.Length ~ Species + Petal.Length + Species:Petal.Length,
Sepal.Length ~ log(Petal.Length / Petal.Width)
)

# GAM fitters from two packages (mgcv and gam) plus mgcv::bam.
models <- list(mgcv::gam, mgcv::bam, gam::gam)

X <- head(iris)
# Good formulas: SHAP decomposition must reproduce the model predictions.
for (formula in formulas_ok) {
for (model in models) {
fit <- model(formula, data = iris)
s <- additive_shap(fit, X = X, verbose = FALSE)
expect_equal(s$predictions, as.vector(predict(fit, newdata = X)))
}
}

# Bad formulas: the explainer must raise an error.
for (formula in formulas_bad) {
for (model in models) {
fit <- model(formula, data = iris)
expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
}
}

# Survival
# Add an all-ones status column so every observation counts as an event.
iris$s <- rep(1, nrow(iris))
formulas_ok <- list(
Surv(Sepal.Length, s) ~ Sepal.Width + Petal.Width + Species,
Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Sepal.Width, 2)
)

formulas_bad <- list(
Surv(Sepal.Length, s) ~ Species * Petal.Length,
Surv(Sepal.Length, s) ~ Species + Petal.Length + Species:Petal.Length,
Surv(Sepal.Length, s) ~ log(Petal.Length / Petal.Width)
)

models <- list(survival::coxph, survival::survreg)

# Good survival formulas: only checks that the explainer runs without error
# (no prediction identity is asserted for these model classes).
for (formula in formulas_ok) {
for (model in models) {
fit <- model(formula, data = iris)
s <- additive_shap(fit, X = X, verbose = FALSE)
}
}

# Bad survival formulas: the explainer must raise an error.
for (formula in formulas_bad) {
for (model in models) {
fit <- model(formula, data = iris)
expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
}
}
Loading

0 comments on commit 9808196

Please sign in to comment.