From d1279fdbf4e9ec0c0f57d00e9b9b94ce5de6f811 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 8 Jan 2025 16:30:16 +0100 Subject: [PATCH] updated bar plot and added aggregation of InteractionValues object --- shapiq/interaction_values.py | 67 ++++++++++++++ shapiq/plot/bar.py | 126 +++++++++++++------------- shapiq/plot/utils.py | 2 +- tests/test_base_interaction_values.py | 55 ++++++++++- tests/tests_plots/test_bar.py | 2 +- 5 files changed, 188 insertions(+), 64 deletions(-) diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index 83b5d782..a83bcd34 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -769,3 +769,70 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]: from shapiq.plot.upset import upset_plot return upset_plot(self, show=show, **kwargs) + + +def aggregate_interaction_values( + interaction_values: list[InteractionValues], + aggregation: str = "mean", +) -> InteractionValues: + """Aggregates InteractionValues objects using a specific aggregation method. + + Args: + interaction_values: A list of InteractionValues objects to aggregate. + aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are + ``"median"``, ``"sum"``, ``"max"``, and ``"min"``. + + Returns: + The aggregated InteractionValues object. + + Note: + The index of the aggregated InteractionValues object is set to the index of the first + InteractionValues object in the list. + + Raises: + ValueError: If the aggregation method is not supported. + """ + + def _aggregate(vals: list[float], method: str) -> float: + """Does the actual aggregation of the values.""" + if method == "mean": + return np.mean(vals) + elif method == "median": + return np.median(vals) + elif method == "sum": + return np.sum(vals) + elif method == "max": + return np.max(vals) + elif method == "min": + return np.min(vals) + else: + raise ValueError(f"Aggregation method {method} is not supported.") + + # get all keys from all InteractionValues objects + all_keys = set() + for iv in interaction_values: + all_keys.update(iv.interaction_lookup.keys()) + + # aggregate the values + new_values = np.zeros(len(all_keys), dtype=float) + new_lookup = {} + for i, key in enumerate(all_keys): + new_lookup[key] = i + values = [iv[key] for iv in interaction_values] + new_values[i] = _aggregate(values, aggregation) + + max_order = max([iv.max_order for iv in interaction_values]) + min_order = min([iv.min_order for iv in interaction_values]) + n_players = max([iv.n_players for iv in interaction_values]) + + return InteractionValues( + values=new_values, + index=interaction_values[0].index, + max_order=max_order, + n_players=n_players, + min_order=min_order, + interaction_lookup=new_lookup, + estimated=True, + estimation_budget=None, + baseline_value=_aggregate([iv.baseline_value for iv in interaction_values], aggregation), + ) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index bf9283d1..e41c4a27 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -5,14 +5,14 @@ import matplotlib.pyplot as plt import numpy as np -from ..interaction_values import InteractionValues +from ..interaction_values import InteractionValues, aggregate_interaction_values from ._config import BLUE, RED from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["bar_plot"] -def _bar(values, feature_names, max_display=10, ax=None, show=True): +def _bar(values, feature_names, max_display=10, ax=None): """Create a bar plot of a set of SHAP values. Parameters @@ -29,24 +29,8 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): explanation objects. max_display : int How many top features to include in the bar plot (default is 10). - order : OpChain or numpy.ndarray - A function that returns a sort ordering given a matrix of SHAP values - and an axis, or a direct sample ordering given as a ``numpy.ndarray``. - - By default, take the absolute value. - clustering: np.ndarray or None - A partition tree, as returned by ``shap.utils.hclust`` - clustering_cutoff: float - Controls how much of the clustering structure is displayed. - show_data: bool or str - Controls if data values are shown as part of the y tick labels. If - "auto", we show the data only when there are no transforms. ax: matplotlib Axes Axes object to draw the plot onto, otherwise uses the current Axes. - show : bool - Whether ``matplotlib.pyplot.show()`` is called before returning. - Setting this to ``False`` allows the plot - to be customized further after it has been created. Returns ------- @@ -58,36 +42,45 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): See `bar plot examples `_. """ - # assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" - xlabel = "Shapley value" - # determine how many top features we will plot + num_features = len(values[0]) if max_display is None: - max_display = len(feature_names) - num_features = min(max_display, len(values[0])) + max_display = num_features max_display = min(max_display, num_features) + num_cut = max(num_features - max_display, 0) # number of features that are not displayed - # Make it descending order + # get order of features in descending order feature_order = np.argsort(np.mean(values, axis=0))[::-1] - y_pos = np.arange(len(feature_order), 0, -1) - - # build our y-tick labels - yticklabels = [feature_names[i] for i in feature_order] - + # if there are more features than we are displaying then we aggregate the features not shown + if num_cut > 0: + cut_feature_values = values[:, feature_order[max_display:]] + sum_of_remaining = np.sum(cut_feature_values, axis=None) + index_of_last = feature_order[max_display] + values = np.insert(values, index_of_last, sum_of_remaining, axis=1) + max_display += 1 # include the sum of the remaining in the display + + # get the top features and their names + feature_inds = feature_order[:max_display] + y_pos = np.arange(len(feature_inds), 0, -1) + yticklabels = [feature_names[i] for i in feature_inds] + if num_cut > 0: + yticklabels[-1] = f"Sum of {int(num_cut)} other features" + + # create a figure if one was not provided if ax is None: ax = plt.gca() - # Only modify the figure size if ax was not passed in + # only modify the figure size if ax was not passed in # compute our figure size based on how many features we are showing fig = plt.gcf() row_height = 0.5 fig.set_size_inches( - 8 + 0.3 * max([len(f) for f in feature_names]), - num_features * row_height * np.sqrt(len(values)) + 1.5, + 8 + 0.3 * max([len(feature_name) for feature_name in feature_names]), + max_display * row_height * np.sqrt(len(values)) + 1.5, ) - # if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us... - negative_values_present = np.sum(values[:, feature_order[:num_features]] < 0) > 0 + # if negative values are present, we draw a vertical line to mark 0 + negative_values_present = np.sum(values[:, feature_order[:max_display]] < 0) > 0 if negative_values_present: ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1) @@ -99,15 +92,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) ax.barh( y_pos + ypos_offset, - values[i, feature_order], + values[i, feature_inds], bar_width, align="center", color=[ - BLUE.hex if values[i, feature_order[j]] <= 0 else RED.hex for j in range(len(y_pos)) + BLUE.hex if values[i, feature_inds[j]] <= 0 else RED.hex for j in range(len(y_pos)) ], hatch=patterns[i], edgecolor=(1, 1, 1, 0.8), - label="Model " + str(i), + label="Group " + str(i + 1), ) # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks) @@ -118,19 +111,19 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ) xlen = ax.get_xlim()[1] - ax.get_xlim()[0] - # xticks = ax.get_xticks() bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) width = bbox.width bbox_to_xscale = xlen / width + # draw the bar labels as text next to the bars for i in range(len(values)): ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) for j in range(len(y_pos)): - ind = feature_order[j] + ind = feature_inds[j] if values[i, ind] < 0: ax.text( values[i, ind] - (5 / 72) * bbox_to_xscale, - y_pos[j] + ypos_offset, + float(y_pos[j] + ypos_offset), format_value(values[i, ind], "%+0.02f"), horizontalalignment="right", verticalalignment="center", @@ -140,7 +133,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): else: ax.text( values[i, ind] + (5 / 72) * bbox_to_xscale, - y_pos[j] + ypos_offset, + float(y_pos[j] + ypos_offset), format_value(values[i, ind], "%+0.02f"), horizontalalignment="left", verticalalignment="center", @@ -149,9 +142,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ) # put horizontal lines for each feature row - for i in range(num_features): + for i in range(max_display): ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1) + # remove plot frame and y-axis ticks ax.xaxis.set_ticks_position("bottom") ax.yaxis.set_ticks_position("none") ax.spines["right"].set_visible(False) @@ -160,19 +154,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ax.spines["left"].set_visible(False) ax.tick_params("x", labelsize=11) + # set the x-axis limits to cover the data xmin, xmax = ax.get_xlim() - ymin, ymax = ax.get_ylim() x_buffer = (xmax - xmin) * 0.05 - if negative_values_present: ax.set_xlim(xmin - x_buffer, xmax + x_buffer) else: ax.set_xlim(xmin, xmax + x_buffer) - # if features is None: - # pl.xlabel(labels["GLOBAL_VALUE"], fontsize=13) - # else: - ax.set_xlabel(xlabel, fontsize=13) + ax.set_xlabel("Attribution", fontsize=13) if len(values) > 1: ax.legend(fontsize=12, loc="lower right") @@ -180,13 +170,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): # color the y tick labels that have the feature values as gray # (these fall behind the black ones with just the feature name) tick_labels = ax.yaxis.get_majorticklabels() - for i in range(num_features): + for i in range(max_display): tick_labels[i].set_color("#999999") - if show: - plt.show() - else: - return ax + return ax def bar_plot( @@ -195,6 +182,7 @@ def bar_plot( show: bool = False, abbreviate: bool = True, max_display: Optional[int] = 10, + global_plot: bool = True, ) -> Optional[plt.Axes]: """Draws interaction values on a bar plot. @@ -207,9 +195,12 @@ def bar_plot( show: Whether ``matplotlib.pyplot.show()`` is called before returning. Default is ``True``. Setting this to ``False`` allows the plot to be customized further after it has been created. abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - **kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``. + max_display: The maximum number of features to display. Defaults to ``10``. If set to + ``None``, all features are displayed. + global_plot: Weather to aggregate the values of the different InteractionValues objects + into a global explanation (``True``) or to plot them as separate bars (``False``). + Defaults to ``True``. """ - n_players = list_of_interaction_values[0].n_players if feature_names is not None: @@ -219,21 +210,34 @@ def bar_plot( else: feature_mapping = {i: str(i) for i in range(n_players)} - assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 - - values = np.stack([iv.values for iv in list_of_interaction_values]) + if global_plot: + global_values = aggregate_interaction_values(list_of_interaction_values) + values = np.expand_dims(global_values.values, axis=0) + interaction_list = global_values.interaction_lookup.keys() + else: + all_interactions = set() + for iv in list_of_interaction_values: + all_interactions.update(iv.interaction_lookup.keys()) + all_interactions = sorted(all_interactions) + interaction_list = [] + values = np.zeros((len(list_of_interaction_values), len(all_interactions))) + for j, interaction in enumerate(all_interactions): + interaction_list.append(interaction) + for i, iv in enumerate(list_of_interaction_values): + values[i, j] = iv[interaction] + + # TODO: update this to be correct with the order of labels labels = np.array( list( map( lambda x: format_labels(feature_mapping, x), - list_of_interaction_values[0].dict_values.keys(), + interaction_list, ) ) ) - ax = _bar(values=values, feature_names=labels, show=False, max_display=max_display) - ax.set_xlabel("Shapley value") + ax = _bar(values=values, feature_names=labels, max_display=max_display) if not show: return ax plt.show() diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index 3758bbfb..2c45d65f 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -3,7 +3,7 @@ import re from collections.abc import Iterable -__all__ = ["abbreviate_feature_names", "format_value"] +__all__ = ["abbreviate_feature_names", "format_value", "format_labels"] def format_value(s, format_str): diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index 4e4e81bc..62e4bd9e 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from shapiq.interaction_values import InteractionValues +from shapiq.interaction_values import InteractionValues, aggregate_interaction_values from shapiq.utils import powerset @@ -626,3 +626,56 @@ def test_subset(): assert subset_interaction_values.estimated == interaction_values.estimated assert subset_interaction_values.estimation_budget == interaction_values.estimation_budget assert subset_interaction_values.index == interaction_values.index + + +@pytest.mark.parametrize("aggregation", ["sum", "mean", "median", "max", "min"]) +def test_aggregation(aggregation): + + n_objects = 3 + + n, min_order, max_order = 5, 1, 3 + interaction_values_list = [] + for _ in range(n_objects): + values = np.random.rand(2**n - 1) + interaction_lookup = { + interaction: i for i, interaction in enumerate(powerset(range(n), min_order, max_order)) + } + interaction_values = InteractionValues( + values=values, + index="SII", + max_order=max_order, + n_players=n, + min_order=min_order, + interaction_lookup=interaction_lookup, + estimated=False, + estimation_budget=0, + baseline_value=0.0, + ) + interaction_values_list.append(interaction_values) + + aggregated_interaction_values = aggregate_interaction_values( + interaction_values_list, aggregation=aggregation + ) + + assert isinstance(aggregated_interaction_values, InteractionValues) + assert aggregated_interaction_values.index == "SII" + assert aggregated_interaction_values.n_players == n + assert aggregated_interaction_values.min_order == min_order + assert aggregated_interaction_values.max_order == max_order + + # check that all interactions are equal to the expected value + for interaction in powerset(range(n), 1, n): + aggregated_value = np.array( + [interaction_values[interaction] for interaction_values in interaction_values_list] + ) + if aggregation == "sum": + expected_value = np.sum(aggregated_value) + elif aggregation == "mean": + expected_value = np.mean(aggregated_value) + elif aggregation == "median": + expected_value = np.median(aggregated_value) + elif aggregation == "max": + expected_value = np.max(aggregated_value) + elif aggregation == "min": + expected_value = np.min(aggregated_value) + assert aggregated_interaction_values[interaction] == expected_value diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index a03d3945..befff43c 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -8,7 +8,7 @@ from shapiq.plot import bar_plot -def test_bar_concret(): +def test_bar_concrete(): class CookingGame(shapiq.Game): def __init__(self): self.characteristic_function = {