Skip to content

Commit

Permalink
updated bar plot and added aggregation of InteractionValues object
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Jan 8, 2025
1 parent de60c64 commit d1279fd
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 64 deletions.
67 changes: 67 additions & 0 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,70 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]:
from shapiq.plot.upset import upset_plot

return upset_plot(self, show=show, **kwargs)


def aggregate_interaction_values(
interaction_values: list[InteractionValues],
aggregation: str = "mean",
) -> InteractionValues:
"""Aggregates InteractionValues objects using a specific aggregation method.
Args:
interaction_values: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.
Returns:
The aggregated InteractionValues object.
Note:
The index of the aggregated InteractionValues object is set to the index of the first
InteractionValues object in the list.
Raises:
ValueError: If the aggregation method is not supported.
"""

def _aggregate(vals: list[float], method: str) -> float:
"""Does the actual aggregation of the values."""
if method == "mean":
return np.mean(vals)
elif method == "median":
return np.median(vals)
elif method == "sum":
return np.sum(vals)
elif method == "max":
return np.max(vals)
elif method == "min":
return np.min(vals)
else:
raise ValueError(f"Aggregation method {method} is not supported.")

# get all keys from all InteractionValues objects
all_keys = set()
for iv in interaction_values:
all_keys.update(iv.interaction_lookup.keys())

# aggregate the values
new_values = np.zeros(len(all_keys), dtype=float)
new_lookup = {}
for i, key in enumerate(all_keys):
new_lookup[key] = i
values = [iv[key] for iv in interaction_values]
new_values[i] = _aggregate(values, aggregation)

max_order = max([iv.max_order for iv in interaction_values])
min_order = min([iv.min_order for iv in interaction_values])
n_players = max([iv.n_players for iv in interaction_values])

return InteractionValues(
values=new_values,
index=interaction_values[0].index,
max_order=max_order,
n_players=n_players,
min_order=min_order,
interaction_lookup=new_lookup,
estimated=True,
estimation_budget=None,
baseline_value=_aggregate([iv.baseline_value for iv in interaction_values], aggregation),
)
126 changes: 65 additions & 61 deletions shapiq/plot/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import matplotlib.pyplot as plt
import numpy as np

from ..interaction_values import InteractionValues
from ..interaction_values import InteractionValues, aggregate_interaction_values
from ._config import BLUE, RED
from .utils import abbreviate_feature_names, format_labels, format_value

__all__ = ["bar_plot"]


def _bar(values, feature_names, max_display=10, ax=None, show=True):
def _bar(values, feature_names, max_display=10, ax=None):
"""Create a bar plot of a set of SHAP values.
Parameters
Expand All @@ -29,24 +29,8 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
explanation objects.
max_display : int
How many top features to include in the bar plot (default is 10).
order : OpChain or numpy.ndarray
A function that returns a sort ordering given a matrix of SHAP values
and an axis, or a direct sample ordering given as a ``numpy.ndarray``.
By default, take the absolute value.
clustering: np.ndarray or None
A partition tree, as returned by ``shap.utils.hclust``
clustering_cutoff: float
Controls how much of the clustering structure is displayed.
show_data: bool or str
Controls if data values are shown as part of the y tick labels. If
"auto", we show the data only when there are no transforms.
ax: matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.
show : bool
Whether ``matplotlib.pyplot.show()`` is called before returning.
Setting this to ``False`` allows the plot
to be customized further after it has been created.
Returns
-------
Expand All @@ -58,36 +42,45 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
See `bar plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/bar.html>`_.
"""
# assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!"
xlabel = "Shapley value"

# determine how many top features we will plot
num_features = len(values[0])
if max_display is None:
max_display = len(feature_names)
num_features = min(max_display, len(values[0]))
max_display = num_features
max_display = min(max_display, num_features)
num_cut = max(num_features - max_display, 0) # number of features that are not displayed

# Make it descending order
# get order of features in descending order
feature_order = np.argsort(np.mean(values, axis=0))[::-1]

y_pos = np.arange(len(feature_order), 0, -1)

# build our y-tick labels
yticklabels = [feature_names[i] for i in feature_order]

# if there are more features than we are displaying then we aggregate the features not shown
if num_cut > 0:
cut_feature_values = values[:, feature_order[max_display:]]
sum_of_remaining = np.sum(cut_feature_values, axis=None)
index_of_last = feature_order[max_display]
values = np.insert(values, index_of_last, sum_of_remaining, axis=1)
max_display += 1 # include the sum of the remaining in the display

# get the top features and their names
feature_inds = feature_order[:max_display]
y_pos = np.arange(len(feature_inds), 0, -1)
yticklabels = [feature_names[i] for i in feature_inds]
if num_cut > 0:
yticklabels[-1] = f"Sum of {int(num_cut)} other features"

# create a figure if one was not provided
if ax is None:
ax = plt.gca()
# Only modify the figure size if ax was not passed in
# only modify the figure size if ax was not passed in
# compute our figure size based on how many features we are showing
fig = plt.gcf()
row_height = 0.5
fig.set_size_inches(
8 + 0.3 * max([len(f) for f in feature_names]),
num_features * row_height * np.sqrt(len(values)) + 1.5,
8 + 0.3 * max([len(feature_name) for feature_name in feature_names]),
max_display * row_height * np.sqrt(len(values)) + 1.5,
)

# if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us...
negative_values_present = np.sum(values[:, feature_order[:num_features]] < 0) > 0
# if negative values are present, we draw a vertical line to mark 0
negative_values_present = np.sum(values[:, feature_order[:max_display]] < 0) > 0
if negative_values_present:
ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1)

Expand All @@ -99,15 +92,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2)
ax.barh(
y_pos + ypos_offset,
values[i, feature_order],
values[i, feature_inds],
bar_width,
align="center",
color=[
BLUE.hex if values[i, feature_order[j]] <= 0 else RED.hex for j in range(len(y_pos))
BLUE.hex if values[i, feature_inds[j]] <= 0 else RED.hex for j in range(len(y_pos))
],
hatch=patterns[i],
edgecolor=(1, 1, 1, 0.8),
label="Model " + str(i),
label="Group " + str(i + 1),
)

# draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks)
Expand All @@ -118,19 +111,19 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
)

xlen = ax.get_xlim()[1] - ax.get_xlim()[0]
# xticks = ax.get_xticks()
bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted())
width = bbox.width
bbox_to_xscale = xlen / width

# draw the bar labels as text next to the bars
for i in range(len(values)):
ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2)
for j in range(len(y_pos)):
ind = feature_order[j]
ind = feature_inds[j]
if values[i, ind] < 0:
ax.text(
values[i, ind] - (5 / 72) * bbox_to_xscale,
y_pos[j] + ypos_offset,
float(y_pos[j] + ypos_offset),
format_value(values[i, ind], "%+0.02f"),
horizontalalignment="right",
verticalalignment="center",
Expand All @@ -140,7 +133,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
else:
ax.text(
values[i, ind] + (5 / 72) * bbox_to_xscale,
y_pos[j] + ypos_offset,
float(y_pos[j] + ypos_offset),
format_value(values[i, ind], "%+0.02f"),
horizontalalignment="left",
verticalalignment="center",
Expand All @@ -149,9 +142,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
)

# put horizontal lines for each feature row
for i in range(num_features):
for i in range(max_display):
ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1)

# remove plot frame and y-axis ticks
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("none")
ax.spines["right"].set_visible(False)
Expand All @@ -160,33 +154,26 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True):
ax.spines["left"].set_visible(False)
ax.tick_params("x", labelsize=11)

# set the x-axis limits to cover the data
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
x_buffer = (xmax - xmin) * 0.05

if negative_values_present:
ax.set_xlim(xmin - x_buffer, xmax + x_buffer)
else:
ax.set_xlim(xmin, xmax + x_buffer)

# if features is None:
# pl.xlabel(labels["GLOBAL_VALUE"], fontsize=13)
# else:
ax.set_xlabel(xlabel, fontsize=13)
ax.set_xlabel("Attribution", fontsize=13)

if len(values) > 1:
ax.legend(fontsize=12, loc="lower right")

# color the y tick labels that have the feature values as gray
# (these fall behind the black ones with just the feature name)
tick_labels = ax.yaxis.get_majorticklabels()
for i in range(num_features):
for i in range(max_display):
tick_labels[i].set_color("#999999")

if show:
plt.show()
else:
return ax
return ax


def bar_plot(
Expand All @@ -195,6 +182,7 @@ def bar_plot(
show: bool = False,
abbreviate: bool = True,
max_display: Optional[int] = 10,
global_plot: bool = True,
) -> Optional[plt.Axes]:
"""Draws interaction values on a bar plot.
Expand All @@ -207,9 +195,12 @@ def bar_plot(
show: Whether ``matplotlib.pyplot.show()`` is called before returning. Default is ``True``.
Setting this to ``False`` allows the plot to be customized further after it has been created.
abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
**kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``.
max_display: The maximum number of features to display. Defaults to ``10``. If set to
``None``, all features are displayed.
global_plot: Weather to aggregate the values of the different InteractionValues objects
into a global explanation (``True``) or to plot them as separate bars (``False``).
Defaults to ``True``.
"""

n_players = list_of_interaction_values[0].n_players

if feature_names is not None:
Expand All @@ -219,21 +210,34 @@ def bar_plot(
else:
feature_mapping = {i: str(i) for i in range(n_players)}

assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1

values = np.stack([iv.values for iv in list_of_interaction_values])
if global_plot:
global_values = aggregate_interaction_values(list_of_interaction_values)
values = np.expand_dims(global_values.values, axis=0)
interaction_list = global_values.interaction_lookup.keys()
else:
all_interactions = set()
for iv in list_of_interaction_values:
all_interactions.update(iv.interaction_lookup.keys())
all_interactions = sorted(all_interactions)
interaction_list = []
values = np.zeros((len(list_of_interaction_values), len(all_interactions)))
for j, interaction in enumerate(all_interactions):
interaction_list.append(interaction)
for i, iv in enumerate(list_of_interaction_values):
values[i, j] = iv[interaction]

# TODO: update this to be correct with the order of labels

labels = np.array(
list(
map(
lambda x: format_labels(feature_mapping, x),
list_of_interaction_values[0].dict_values.keys(),
interaction_list,
)
)
)

ax = _bar(values=values, feature_names=labels, show=False, max_display=max_display)
ax.set_xlabel("Shapley value")
ax = _bar(values=values, feature_names=labels, max_display=max_display)
if not show:
return ax
plt.show()
2 changes: 1 addition & 1 deletion shapiq/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from collections.abc import Iterable

__all__ = ["abbreviate_feature_names", "format_value"]
__all__ = ["abbreviate_feature_names", "format_value", "format_labels"]


def format_value(s, format_str):
Expand Down
Loading

0 comments on commit d1279fd

Please sign in to comment.