Skip to content

Commit

Permalink
Adds plot for language models (#284)
Browse files Browse the repository at this point in the history
* adds a plot to visualize word-level attributions

* makes abbreviation in plots optional and closes #281

* Bump pypa/gh-action-pypi-publish from 1.11.0 to 1.12.2 (#283)

Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.11.0 to 1.12.2.
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases)
- [Commits](pypa/gh-action-pypi-publish@fb13cb3...15c56db)

---
updated-dependencies:
- dependency-name: pypa/gh-action-pypi-publish
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Maximilian <[email protected]>

* adds an optional parameter to manually adjust scale of explanations

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
mmschlk and dependabot[bot] authored Dec 4, 2024
1 parent c0c51f7 commit fadf90d
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
## Changelog

### Development
-
- adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format
- makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281)

### v1.1.1 (2024-11-13)

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/sentence_plot_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 25 additions & 2 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,7 @@ def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure

if self.max_order > 1:
return network_plot(
first_order_values=self.get_n_order_values(1),
second_order_values=self.get_n_order_values(2),
interaction_values=self,
show=show,
**kwargs,
)
Expand Down Expand Up @@ -635,6 +634,7 @@ def plot_force(
feature_values: Optional[np.ndarray] = None,
matplotlib=True,
show: bool = True,
abbreviate: bool = True,
**kwargs,
) -> Optional[plt.Figure]:
"""Visualize InteractionValues on a force plot.
Expand All @@ -649,6 +649,7 @@ def plot_force(
feature_values: The feature values used for plotting. Defaults to ``None``.
matplotlib: Whether to return a ``matplotlib`` figure. Defaults to ``True``.
show: Whether to show the plot. Defaults to ``True``.
abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
**kwargs: Keyword arguments passed to ``shap.plots.force()``.
Returns:
Expand All @@ -662,6 +663,7 @@ def plot_force(
feature_names=feature_names,
matplotlib=matplotlib,
show=show,
abbreviate=abbreviate,
**kwargs,
)

Expand All @@ -670,6 +672,7 @@ def plot_waterfall(
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
show: bool = True,
abbreviate: bool = True,
max_display: int = 10,
) -> Optional[plt.Axes]:
"""Draws interaction values on a waterfall plot.
Expand All @@ -682,6 +685,7 @@ def plot_waterfall(
feature indices are used instead. Defaults to ``None``.
feature_values: The feature values used for plotting. Defaults to ``None``.
show: Whether to show the plot. Defaults to ``True``.
abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
max_display: The maximum number of interactions to display. Defaults to ``10``.
"""
from shapiq import waterfall_plot
Expand All @@ -691,5 +695,24 @@ def plot_waterfall(
feature_values=feature_values,
feature_names=feature_names,
show=show,
abbreviate=abbreviate,
max_display=max_display,
)

def plot_sentence(
    self,
    words: list[str],
    show: bool = True,
    **kwargs,
) -> Optional[tuple[plt.Figure, plt.Axes]]:
    """Draws the first-order attributions of these interaction values as a sentence plot.

    This is a thin convenience wrapper: it passes this object, ``words``, ``show`` and all
    remaining keyword arguments on to ``shapiq.plot.sentence.sentence_plot()``. See that
    function for the full list of supported arguments.

    Args:
        words: The words of the sentence or paragraph to plot.
        show: Whether to show the plot. Defaults to ``True``.
        **kwargs: Additional keyword arguments forwarded to ``sentence_plot()``.

    Returns:
        If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
        with the figure and the axis of the plot.
    """
    # imported lazily inside the method, as in the other ``plot_*`` wrappers of this class
    from shapiq.plot.sentence import sentence_plot as _draw_sentence

    figure_and_axis = _draw_sentence(self, words, show=show, **kwargs)
    return figure_and_axis
2 changes: 2 additions & 0 deletions shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .bar import bar_plot
from .force import force_plot
from .network import network_plot
from .sentence import sentence_plot
from .si_graph import si_graph_plot
from .stacked_bar import stacked_bar_plot
from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names
Expand All @@ -15,6 +16,7 @@
"force_plot",
"bar_plot",
"waterfall_plot",
"sentence_plot",
# utils
"abbreviate_feature_names",
"get_interaction_values_and_feature_names",
Expand Down
6 changes: 5 additions & 1 deletion shapiq/plot/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def bar_plot(
list_of_interaction_values: list[InteractionValues],
feature_names: Optional[np.ndarray] = None,
show: bool = False,
abbreviate: bool = True,
**kwargs,
) -> Optional[plt.Axes]:
"""Draws interaction values on a bar plot.
Expand All @@ -28,6 +29,7 @@ def bar_plot(
feature indices are used instead. Defaults to ``None``.
show: Whether ``matplotlib.pyplot.show()`` is called before returning. Default is ``False``.
Setting this to ``False`` allows the plot to be customized further after it has been created.
abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
**kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``.
"""
check_import_module("shap")
Expand All @@ -41,7 +43,9 @@ def bar_plot(
_first_iv = True
for iv in list_of_interaction_values:

_shap_values, _names = get_interaction_values_and_feature_names(iv, feature_names, None)
_shap_values, _names = get_interaction_values_and_feature_names(
iv, feature_names, None, abbreviate=abbreviate
)
if _first_iv:
_labels = _names
_first_iv = False
Expand Down
4 changes: 3 additions & 1 deletion shapiq/plot/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def force_plot(
feature_values: Optional[np.ndarray] = None,
matplotlib: bool = True,
show: bool = False,
abbreviate: bool = True,
**kwargs,
) -> Optional[plt.Figure]:
"""Draws interaction values on a force plot.
Expand All @@ -32,13 +33,14 @@ def force_plot(
feature_values: The feature values used for plotting. Defaults to ``None``.
matplotlib: Whether to return a ``matplotlib`` figure. Defaults to ``True``.
show: Whether to show the plot. Defaults to ``False``.
abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
**kwargs: Keyword arguments passed to ``shap.plots.force()``.
"""
check_import_module("shap")
import shap

_shap_values, _labels = get_interaction_values_and_feature_names(
interaction_values, feature_names, feature_values
interaction_values, feature_names, feature_values, abbreviate=abbreviate
)

return shap.plots.force(
Expand Down
193 changes: 193 additions & 0 deletions shapiq/plot/sentence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""This module contains the sentence plot."""

from collections.abc import Sequence
from typing import Optional

from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.patches import FancyBboxPatch, PathPatch
from matplotlib.textpath import TextPath

from ..interaction_values import InteractionValues
from ._config import BLUE, RED


def _get_color_and_alpha(max_value: float, value: float) -> tuple[str, float]:
    """Gets the color and alpha value for an interaction value.

    Args:
        max_value: The attribution magnitude that corresponds to full opacity.
        value: The attribution to translate into a color and an alpha value.

    Returns:
        A tuple of the hex color (``RED`` for non-negative values, ``BLUE`` for negative values)
        and the alpha value, which is ``abs(value / max_value)`` clipped to at most ``1.0``.
    """
    color = RED.hex if value >= 0 else BLUE.hex
    if max_value == 0:
        # A zero scale (e.g. all attributions are zero or ``max_score=0``) would divide by zero.
        # Use the clipped limit instead: nothing to highlight for a zero value, full opacity
        # for any non-zero value.
        ratio = 0.0 if value == 0 else 1.0
    else:
        ratio = min(abs(value / max_value), 1.0)  # make ratio at most 1
    return color, ratio


def sentence_plot(
    interaction_values: InteractionValues,
    words: Sequence[str],
    connected_words: Optional[Sequence[tuple[str, str]]] = None,
    chars_per_line: int = 35,
    font_family: str = "sans-serif",
    show: bool = False,
    max_score: Optional[float] = None,
) -> Optional[tuple[plt.Figure, plt.Axes]]:
    """Plots the first order effects (attributions) of a sentence or paragraph.

    An example of the plot is shown below.

    .. image:: /_static/sentence_plot_example.png
        :width: 300
        :align: center

    Args:
        interaction_values: The interaction values as an interaction object.
        words: The words of the sentence or a paragraph of text. Must contain at least one word.
        connected_words: A list of tuples with connected words. Defaults to ``None``. If two 'words'
            are connected, the plot will not add a space between them (e.g., the parts "enjoy" and
            "able" would be connected to "enjoyable" with potentially different attributions for
            each part).
        chars_per_line: The maximum number of characters per line. Defaults to ``35`` after which
            the text will be wrapped to the next line. Connected words receive a '-' in front of
            them.
        font_family: The font family used for the plot. Defaults to ``sans-serif``. For a list of
            available font families, see the matplotlib documentation of
            ``matplotlib.font_manager.FontProperties``. Note the plot is optimized for sans-serif.
        show: Whether to show the plot. Defaults to ``False``.
        max_score: The maximum score for the attributions to scale the colors and alpha values. This
            is useful if you want to compare the attributions of different sentences and both plots
            should have the same color scale. Defaults to ``None``, in which case the largest
            absolute attribution of the plotted words is used.

    Returns:
        If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple with
        the figure and the axis of the plot.

    Raises:
        ValueError: If ``words`` is empty.

    Example:
        >>> import numpy as np
        >>> from shapiq.plot import sentence_plot
        >>> iv = InteractionValues(
        ...    values=np.array([0.45, 0.01, 0.67, -0.2, -0.05, 0.7, 0.1, -0.04, 0.56, 0.7]),
        ...    index="SV",
        ...    n_players=10,
        ...    min_order=1,
        ...    max_order=1,
        ...    estimated=False,
        ...    baseline_value=0.0,
        ... )
        >>> words = ["I", "really", "enjoy", "working", "with", "Shapley", "values", "in", "Python", "!"]
        >>> connected_words = [("Shapley", "values")]
        >>> fig, ax = sentence_plot(iv, words, connected_words, show=False, chars_per_line=100)
        >>> plt.show()

    .. image:: /_static/sentence_plot_connected_example.png
        :width: 300
        :align: center
    """

    # set all the size parameters
    fontsize = 20
    word_spacing = 15
    line_spacing = 10
    height_padding = 5
    width_padding = 5

    # clean the input
    connected_words = [] if connected_words is None else connected_words
    words = [word.strip() for word in words]
    if not words:
        # fail early with a clear message instead of an opaque ``max()`` error further down
        raise ValueError("At least one word is required to draw a sentence plot.")
    n_words = len(words)
    attributions = [interaction_values[(i,)] for i in range(n_words)]

    # get the maximum score
    max_abs_attribution = max_score
    if max_abs_attribution is None:
        max_abs_attribution = max(abs(value) for value in attributions)

    # create plot
    fig, ax = plt.subplots()

    # the font is the same for every word, so it is created once outside the loop
    fp = FontProperties(family=font_family, style="normal", size=fontsize, weight="normal")

    max_x_pos = 0
    x_pos, y_pos = word_spacing, 0
    chars_in_line = 0
    for i, (word, attribution) in enumerate(zip(words, attributions)):

        # check if the word is connected to its predecessor ("second" part) or to its successor
        # ("first" part); the explicit bounds checks keep the first word from being paired with
        # the last word through the negative index ``words[-1]``
        is_word_connected_second = i > 0 and (words[i - 1], word) in connected_words
        is_word_connected_first = i + 1 < n_words and (word, words[i + 1]) in connected_words

        # check if the line is too long and needs to be wrapped
        # NOTE: the counter restarts at 0, i.e. the wrapped word itself is not counted towards
        # the new line (kept as-is to preserve the existing layout)
        chars_in_line += len(word)
        if chars_in_line > chars_per_line:
            chars_in_line = 0
            x_pos = word_spacing
            y_pos -= fontsize + line_spacing
            if is_word_connected_second:
                word = "-" + word

        # adjust the x position for connected words
        if is_word_connected_second:
            x_pos += 2

        # set the position of the word in the plot
        position = (x_pos, y_pos)

        # get the color and alpha value
        color, alpha = _get_color_and_alpha(max_abs_attribution, attribution)

        # get the text (dark backgrounds get white text for readability)
        text_color = "black" if alpha < 2 / 3 else "white"
        text_path = TextPath(position, word, prop=fp)
        text_path = PathPatch(text_path, facecolor=text_color, edgecolor="none")
        # measured before the patch is added to the axis; this width drives the layout below
        width_of_text = text_path.get_window_extent().width

        # get dimensions for the explanation patch
        height_patch = fontsize + height_padding
        width_patch = width_of_text + 1
        y_pos_patch = y_pos - height_padding
        x_pos_patch = x_pos + 1
        if is_word_connected_first:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding / 2
        elif is_word_connected_second:
            width_patch += width_padding / 2
        else:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding

        # create the explanation patch
        patch = FancyBboxPatch(
            xy=(x_pos_patch, y_pos_patch),
            width=width_patch,
            height=height_patch,
            color=color,
            alpha=alpha,
            zorder=-1,
            boxstyle="Round, pad=0, rounding_size=3",
        )

        # draw elements for the word
        ax.add_patch(patch)
        ax.add_artist(text_path)

        # update the x position
        x_pos += width_of_text + word_spacing
        max_x_pos = max(max_x_pos, x_pos)
        if is_word_connected_first:
            # no gap between this word and the connected word that follows it
            x_pos -= word_spacing

    # fix up the dimensions of the plot
    ax.set_xlim(0, max_x_pos)
    ax.set_ylim(y_pos - fontsize / 2, fontsize + fontsize / 2)
    width = max_x_pos
    height = fontsize + fontsize / 2 + abs(y_pos - fontsize / 2)
    fig.set_size_inches(width / 100, height / 100)

    # clean up the plot
    ax.axis("off")
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # draw the plot
    if not show:
        return fig, ax
    plt.show()
    return None
4 changes: 3 additions & 1 deletion shapiq/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def get_interaction_values_and_feature_names(
interaction_values: InteractionValues,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
abbreviate: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""Converts higher-order interaction values to SHAP-like vectors with associated labels.
Expand All @@ -24,12 +25,13 @@ def get_interaction_values_and_feature_names(
feature_names: The feature names used for plotting. If no feature names are provided, the
feature indices are used instead. Defaults to ``None``.
feature_values: The feature values used for plotting. Defaults to ``None``.
abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
Returns:
A tuple containing the SHAP values and the corresponding labels.
"""
feature_names = copy.deepcopy(feature_names)
if feature_names is not None:
if feature_names is not None and abbreviate:
feature_names = abbreviate_feature_names(feature_names)
_values_dict = {}
for i in range(1, interaction_values.max_order + 1):
Expand Down
4 changes: 3 additions & 1 deletion shapiq/plot/watefall.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def waterfall_plot(
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
show: bool = False,
abbreviate: bool = True,
max_display: int = 10,
) -> Optional[plt.Axes]:
"""Draws interaction values on a waterfall plot.
Expand All @@ -31,6 +32,7 @@ def waterfall_plot(
feature indices are used instead. Defaults to ``None``.
feature_values: The feature values used for plotting. Defaults to ``None``.
show: Whether to show the plot. Defaults to ``False``.
abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
max_display: The maximum number of interactions to display. Defaults to ``10``.
"""
check_import_module("shap")
Expand All @@ -45,7 +47,7 @@ def waterfall_plot(
)
else:
_shap_values, _labels = get_interaction_values_and_feature_names(
interaction_values, feature_names, feature_values
interaction_values, feature_names, feature_values, abbreviate=abbreviate
)

shap_explanation = shap.Explanation(
Expand Down
Loading

0 comments on commit fadf90d

Please sign in to comment.