-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #132 from ModelOriented/additive-shap
Additive shap
- Loading branch information
Showing
23 changed files
with
387 additions
and
176 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: kernelshap | ||
Title: Kernel SHAP | ||
Version: 0.4.2 | ||
Version: 0.5.0 | ||
Authors@R: c( | ||
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")), | ||
person("David", "Watson", , "[email protected]", role = "aut"), | ||
|
@@ -19,7 +19,7 @@ Depends: | |
R (>= 3.2.0) | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.2.3 | ||
RoxygenNote: 7.3.1 | ||
Imports: | ||
foreach, | ||
stats, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#' Additive SHAP
#'
#' Exact additive SHAP assuming feature independence. The implementation
#' works for models fitted via
#' - [lm()],
#' - [glm()],
#' - [mgcv::gam()],
#' - [mgcv::bam()],
#' - [gam::gam()],
#' - [survival::coxph()], and
#' - [survival::survreg()].
#'
#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`.
#' Models with interactions (specified via `:` or `*`), or with terms of
#' multiple features like `log(x1/x2)` are not supported. The same holds for
#' terms that cannot be attributed to exactly one column of `X`, e.g., terms
#' built from variables that live outside of `X`. Such terms raise an error
#' instead of being dropped silently.
#'
#' All model terms using the same column of `X` (e.g., `log(x)` and
#' `poly(x, 2)`) are summed up into a single SHAP column named after that
#' column of `X`.
#'
#' @inheritParams kernelshap
#' @param X Dataframe with rows to be explained. Will be used like
#'   `predict(object, newdata = X, type = "terms")`.
#' @param ... Currently unused.
#' @returns
#'   An object of class "kernelshap" with the following components:
#'   - `S`: \eqn{(n \times p)} matrix with SHAP values. Its columns follow the
#'     column order of `X` and cover only the features used by the model.
#'   - `X`: Same as input argument `X`.
#'   - `baseline`: The baseline, i.e., the "constant" attribute of the terms
#'     prediction, or 0 if there is no such attribute.
#'   - `exact`: `TRUE`.
#'   - `txt`: Summary text.
#'   - `predictions`: Vector with predictions of `X` on the scale of "terms",
#'     calculated as baseline plus row sums of `S`.
#'   - `algorithm`: "additive_shap".
#' @export
#' @examples
#' # MODEL ONE: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- additive_shap(fit, head(iris))
#' s
#'
#' # MODEL TWO: More complicated (but not very clever) formula
#' fit <- lm(
#'   Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
#'   data = iris
#' )
#' s <- additive_shap(fit, head(iris))
#' s
additive_shap <- function(object, X, verbose = TRUE, ...) {
  stopifnot(
    inherits(object, c("lm", "glm", "gam", "bam", "Gam", "coxph", "survreg"))
  )
  if (any(attr(stats::terms(object), "order") > 1)) {
    stop("Additive SHAP not appropriate for models with interactions.")
  }

  txt <- "Exact additive SHAP via predict(..., type = 'terms')"
  if (verbose) {
    message(txt)
  }

  S <- stats::predict(object, newdata = X, type = "terms")
  rownames(S) <- NULL

  # Baseline value: not every predict() method attaches a "constant"
  b <- as.vector(attr(S, "constant"))
  if (is.null(b)) {
    b <- 0
  }

  # Which columns of X are used in each column of S?
  s_names <- colnames(S)
  cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
  n_used <- lengths(cols_used)
  if (any(n_used > 1L)) {
    stop("The formula contains terms with multiple features (not supported).")
  }
  # A term without variables would shorten unlist(cols_used) and misalign the
  # grouping factor used in split() below
  if (any(n_used == 0L)) {
    stop(
      "The formula contains terms without any feature (not supported): ",
      paste(s_names[n_used == 0L], collapse = ", ")
    )
  }
  features <- unlist(cols_used)
  # A variable missing in X would turn into an NA level and split() would
  # silently drop its contribution
  unknown <- setdiff(features, colnames(X))
  if (length(unknown) > 0L) {
    stop(
      "Terms use variables that are not columns of 'X': ",
      paste(unknown, collapse = ", ")
    )
  }

  # Collapse all columns in S using the same column in X and rename accordingly
  mapping <- split(
    s_names, factor(features, levels = colnames(X)), drop = TRUE
  )
  S <- do.call(
    cbind,
    lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
  )

  structure(
    list(
      S = S,
      X = X,
      baseline = b,
      exact = TRUE,
      txt = txt,
      predictions = b + rowSums(S),
      algorithm = "additive_shap"
    ),
    class = "kernelshap"
  )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Some tests that need contributed packages | ||
|
||
library(mgcv) | ||
library(gam) | ||
library(survival) | ||
library(splines) | ||
library(testthat) | ||
|
||
# Formulas that are additive in the features: each term uses exactly one
# column, possibly the same column in several terms
formulas_ok <- list(
  Sepal.Length ~ Sepal.Width + Petal.Width + Species,
  Sepal.Length ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
  Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2)
)

# Formulas that must be rejected: interactions or terms of multiple features
formulas_bad <- list(
  Sepal.Length ~ Species * Petal.Length,
  Sepal.Length ~ Species + Petal.Length + Species:Petal.Length,
  Sepal.Length ~ log(Petal.Length / Petal.Width)
)

models <- list(mgcv::gam, mgcv::bam, gam::gam)

X <- head(iris)

# Baseline plus SHAP values must reproduce the model predictions
for (formula in formulas_ok) {
  for (model in models) {
    fit <- model(formula, data = iris)
    s <- additive_shap(fit, X = X, verbose = FALSE)
    expect_equal(s$predictions, as.vector(predict(fit, newdata = X)))
  }
}

for (formula in formulas_bad) {
  for (model in models) {
    fit <- model(formula, data = iris)
    expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
  }
}

# Survival
# Status column for Surv(): every observation is treated as an event
iris$s <- rep(1, nrow(iris))
formulas_ok <- list(
  Surv(Sepal.Length, s) ~ Sepal.Width + Petal.Width + Species,
  Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
  Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Sepal.Width, 2)
)

formulas_bad <- list(
  Surv(Sepal.Length, s) ~ Species * Petal.Length,
  Surv(Sepal.Length, s) ~ Species + Petal.Length + Species:Petal.Length,
  Surv(Sepal.Length, s) ~ log(Petal.Length / Petal.Width)
)

models <- list(survival::coxph, survival::survreg)

# No numeric comparison against predict() here; only the structure of the
# result is checked so that the loop does not pass without any expectation
for (formula in formulas_ok) {
  for (model in models) {
    fit <- model(formula, data = iris)
    s <- additive_shap(fit, X = X, verbose = FALSE)
    expect_s3_class(s, "kernelshap")
    expect_equal(nrow(s$S), nrow(X))
    expect_true(all(colnames(s$S) %in% colnames(X)))
  }
}

for (formula in formulas_bad) {
  for (model in models) {
    fit <- model(formula, data = iris)
    expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
  }
}
Oops, something went wrong.