Skip to content

Commit

Permalink
updates aggregation method and finishes work on bar plot
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Jan 10, 2025
1 parent 4370dc5 commit 56e3ef6
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 47 deletions.
57 changes: 55 additions & 2 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import os
import pickle
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
from warnings import warn
Expand Down Expand Up @@ -630,6 +631,25 @@ def to_dict(self) -> dict:
"baseline_value": self.baseline_value,
}

def aggregate(
self, others: Sequence["InteractionValues"], aggregation: str = "mean"
) -> "InteractionValues":
"""Aggregates InteractionValues objects using a specific aggregation method.
Args:
others: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.
Returns:
The aggregated InteractionValues object.
Note:
For documentation on the aggregation methods, see the ``aggregate_interaction_values()``
function.
"""
return aggregate_interaction_values([self, *others], aggregation)

def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Visualize InteractionValues on a graph.
Expand Down Expand Up @@ -772,7 +792,7 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]:


def aggregate_interaction_values(
interaction_values: list[InteractionValues],
interaction_values: Sequence[InteractionValues],
aggregation: str = "mean",
) -> InteractionValues:
"""Aggregates InteractionValues objects using a specific aggregation method.
Expand All @@ -785,6 +805,37 @@ def aggregate_interaction_values(
Returns:
The aggregated InteractionValues object.
Example:
>>> iv1 = InteractionValues(
... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=0.0,
... )
>>> iv2 = InteractionValues(
... values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), # this iv is missing the (1, 2) value
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, # no (1, 2)
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=1.0,
... )
>>> aggregate_interaction_values([iv1, iv2], "mean")
InteractionValues(
index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None,
n_players=3, baseline_value=0.5,
Top 10 interactions:
(1, 2): 0.60
(0, 2): 0.35
(0, 1): 0.25
(0,): 0.15
(1,): 0.25
(2,): 0.35
)
Note:
The index of the aggregated InteractionValues object is set to the index of the first
InteractionValues object in the list.
Expand Down Expand Up @@ -812,6 +863,7 @@ def _aggregate(vals: list[float], method: str) -> float:
all_keys = set()
for iv in interaction_values:
all_keys.update(iv.interaction_lookup.keys())
all_keys = sorted(all_keys)

# aggregate the values
new_values = np.zeros(len(all_keys), dtype=float)
Expand All @@ -824,6 +876,7 @@ def _aggregate(vals: list[float], method: str) -> float:
max_order = max([iv.max_order for iv in interaction_values])
min_order = min([iv.min_order for iv in interaction_values])
n_players = max([iv.n_players for iv in interaction_values])
baseline_value = _aggregate([iv.baseline_value for iv in interaction_values], aggregation)

return InteractionValues(
values=new_values,
Expand All @@ -834,5 +887,5 @@ def _aggregate(vals: list[float], method: str) -> float:
interaction_lookup=new_lookup,
estimated=True,
estimation_budget=None,
baseline_value=_aggregate([iv.baseline_value for iv in interaction_values], aggregation),
baseline_value=baseline_value,
)
55 changes: 15 additions & 40 deletions shapiq/plot/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,16 @@
__all__ = ["bar_plot"]


def _bar(values, feature_names, max_display=10, ax=None):
def _bar(
values: np.ndarray,
feature_names: np.ndarray,
max_display: Optional[int] = 10,
ax: Optional[plt.Axes] = None,
) -> plt.Axes:
"""Create a bar plot of a set of SHAP values.
Parameters
----------
shap_values : shap.Explanation or shap.Cohorts or dictionary of shap.Explanation objects
Passing a multi-row :class:`.Explanation` object creates a global
feature importance plot.
Passing a single row of an explanation (i.e. ``shap_values[0]``) creates
a local feature importance plot.
Passing a dictionary of Explanation objects will create a multiple-bar
plot with one bar type for each of the cohorts represented by the
explanation objects.
max_display : int
How many top features to include in the bar plot (default is 10).
ax: matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.
Returns
-------
ax: matplotlib Axes
Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``.
Examples
--------
See `bar plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/bar.html>`_.
This is a modified version of the bar plot from the SHAP package. The original code can be found
at https://github.com/shap/shap.
"""
# determine how many top features we will plot
num_features = len(values[0])
Expand Down Expand Up @@ -199,7 +180,8 @@ def bar_plot(
``None``, all features are displayed.
global_plot: Weather to aggregate the values of the different InteractionValues objects
into a global explanation (``True``) or to plot them as separate bars (``False``).
Defaults to ``True``.
Defaults to ``True``. If only one InteractionValues object is provided, this parameter
is ignored.
"""
n_players = list_of_interaction_values[0].n_players

Expand All @@ -208,13 +190,14 @@ def bar_plot(
feature_names = abbreviate_feature_names(feature_names)
feature_mapping = {i: feature_names[i] for i in range(n_players)}
else:
feature_mapping = {i: str(i) for i in range(n_players)}
feature_mapping = {i: "F" + str(i) for i in range(n_players)}

# aggregate the interaction values if global_plot is True
if global_plot:
global_values = aggregate_interaction_values(list_of_interaction_values)
values = np.expand_dims(global_values.values, axis=0)
interaction_list = global_values.interaction_lookup.keys()
else:
else: # plot the interaction values separately (also includes the case of a single object)
all_interactions = set()
for iv in list_of_interaction_values:
all_interactions.update(iv.interaction_lookup.keys())
Expand All @@ -226,16 +209,8 @@ def bar_plot(
for i, iv in enumerate(list_of_interaction_values):
values[i, j] = iv[interaction]

# TODO: update this to be correct with the order of labels

labels = np.array(
list(
map(
lambda x: format_labels(feature_mapping, x),
interaction_list,
)
)
)
# format the labels
labels = [format_labels(feature_mapping, interaction) for interaction in interaction_list]

ax = _bar(values=values, feature_names=labels, max_display=max_display)
if not show:
Expand Down
49 changes: 44 additions & 5 deletions shapiq/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,30 @@

import re
from collections.abc import Iterable
from typing import Union

__all__ = ["abbreviate_feature_names", "format_value", "format_labels"]


def format_value(s, format_str):
"""Strips trailing zeros and uses a unicode minus sign."""
def format_value(
s: Union[float, str],
format_str: str = "%.2f",
) -> str:
"""Strips trailing zeros and uses a unicode minus sign.
Args:
s: The value to be formatted.
format_str: The format string to be used. Defaults to "%.2f".
Returns:
str: The formatted value.
Examples:
>>> format_value(1.0)
"1"
>>> format_value(1.234)
"1.23"
"""
if not issubclass(type(s), str):
s = format_str % s
s = re.sub(r"\.?0+$", "", s)
Expand All @@ -16,13 +34,34 @@ def format_value(s, format_str):
return s


def format_labels(feature_mapping, feature_tuple):
def format_labels(
feature_mapping: dict[int, str],
feature_tuple: tuple[int, ...],
) -> str:
"""Formats the feature labels for the plots.
Args:
feature_mapping: A dictionary mapping feature indices to feature names.
feature_tuple: The feature tuple to be formatted.
Returns:
str: The formatted feature tuple.
Example:
>>> feature_mapping = {0: "A", 1: "B", 2: "C"}
>>> format_labels(feature_mapping, (0, 1))
"A x B"
>>> format_labels(feature_mapping, (0,))
"A"
>>> format_labels(feature_mapping, ())
"Base Value"
"""
if len(feature_tuple) == 0:
return "Baseval."
return "Base Value"
elif len(feature_tuple) == 1:
return str(feature_mapping[feature_tuple[0]])
else:
return " x ".join([feature_mapping[f] for f in feature_tuple])
return " x ".join([str(feature_mapping[f]) for f in feature_tuple])


def abbreviate_feature_names(feature_names: Iterable[str]) -> list[str]:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_base_interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,50 @@ def test_aggregation(aggregation):
elif aggregation == "min":
expected_value = np.min(aggregated_value)
assert aggregated_interaction_values[interaction] == expected_value

# test aggregate from InteractionValues object
aggregated_from_object = interaction_values_list[0].aggregate(
aggregation=aggregation, others=interaction_values_list[1:]
)
assert isinstance(aggregated_from_object, InteractionValues)
assert aggregated_from_object == aggregated_interaction_values # same values
assert aggregated_from_object is not aggregated_interaction_values # but different objects


def test_docs_aggregation_function():
"""Tests the aggregation function in the InteractionValues dataclass like in the docs."""

iv1 = InteractionValues(
values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
index="SII",
n_players=3,
min_order=1,
max_order=2,
interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
baseline_value=0.0,
)

# this does not contain the (1, 2) interaction (i.e. is 0)
iv2 = InteractionValues(
values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),
index="SII",
n_players=3,
min_order=1,
max_order=2,
interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},
baseline_value=1.0,
)

# test sum
aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="sum")
assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.3
assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.5
assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.6
assert pytest.approx(aggregated_interaction_values.baseline_value) == 1.0

# test mean
aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="mean")
assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.15
assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.25
assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.3
assert pytest.approx(aggregated_interaction_values.baseline_value) == 0.5
35 changes: 35 additions & 0 deletions tests/tests_plots/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""This test module tests all plotting utilities."""

from shapiq.plot.utils import abbreviate_feature_names, format_labels, format_value


def test_format_value():
"""Test the format_value function."""
assert format_value(1.0) == "1"
assert format_value(1.234) == "1.23"
assert format_value(-1.234) == "\u22121.23"
assert format_value("1.234") == "1.234"


def test_format_labels():
"""Test the format_labels function."""
feature_mapping = {0: "A", 1: "B", 2: "C"}
assert format_labels(feature_mapping, (0, 1)) == "A x B"
assert format_labels(feature_mapping, (0,)) == "A"
assert format_labels(feature_mapping, ()) == "Base Value"
assert format_labels(feature_mapping, (0, 1, 2)) == "A x B x C"


def test_abbreviate_feature_names():
"""Tests the abbreviate_feature_names function."""
# check for splitting characters
feature_names = ["feature-0", "feature_1", "feature 2", "feature.3"]
assert abbreviate_feature_names(feature_names) == ["F0", "F1", "F2", "F3"]

# check for long names
feature_names = ["longfeaturenamethatisnotshort", "stilllong"]
assert abbreviate_feature_names(feature_names) == ["lon.", "sti."]

# check for abbreviation with capital letters
feature_names = ["LongFeatureName", "Short"]
assert abbreviate_feature_names(feature_names) == ["LFN", "Sho."]

0 comments on commit 56e3ef6

Please sign in to comment.