Skip to content

Commit

Permalink
increased test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Nov 30, 2023
1 parent a88bc43 commit fb1a6a6
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 68 deletions.
22 changes: 1 addition & 21 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def __str__(self) -> str:
def __eq__(self, other: object) -> bool:
"""Checks if two Approximator objects are equal."""
if not isinstance(other, Approximator):
raise NotImplementedError("Cannot compare Approximator with other types.")
raise ValueError("Cannot compare Approximator with other types.")
if (
self.n != other.n
or self.max_order != other.max_order
Expand All @@ -336,26 +336,6 @@ def __hash__(self) -> int:
"""Returns the hash of the Approximator object."""
return hash((self.n, self.max_order, self.index, self.top_order, self._random_state))

def __copy__(self) -> "Approximator":
"""Returns a copy of the Approximator object."""
return self.__class__(
n=self.n,
max_order=self.max_order,
index=self.index,
top_order=self.top_order,
random_state=self._random_state,
)

def __deepcopy__(self, memo) -> "Approximator":
"""Returns a deep copy of the Approximator object."""
return self.__class__(
n=self.n,
max_order=self.max_order,
index=self.index,
top_order=self.top_order,
random_state=self._random_state,
)


class ShapleySamplingMixin:
"""Mixin class for the computation of Shapley weights.
Expand Down
6 changes: 0 additions & 6 deletions shapiq/approximator/permutation/sii.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,3 @@ def approximate(
result = np.divide(result, counts, out=result, where=counts != 0)

return self._finalize_result(result, budget=used_budget, estimated=True)


if __name__ == "__main__":
import doctest

doctest.testmod()
6 changes: 0 additions & 6 deletions shapiq/approximator/permutation/sti.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,3 @@ def _compute_lower_order_sti(
interaction_index = self._interaction_lookup[subset]
result[interaction_index] += update
return result


if __name__ == "__main__":
import doctest

doctest.testmod()
6 changes: 0 additions & 6 deletions shapiq/approximator/regression/fsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,3 @@ def _get_fsi_subset_representation(
):
regression_subsets[:, interaction_index] = all_subsets[:, interaction].all(axis=1)
return regression_subsets, num_players


if __name__ == "__main__":
import doctest

doctest.testmod()
6 changes: 0 additions & 6 deletions shapiq/approximator/shapiq/shapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,3 @@ def _init_discrete_derivative_weights(self) -> dict[int, np.ndarray[float]]:
for k in range(max(0, order + t - self.n), min(order, t) + 1):
weights[order][t, k] = (-1) ** (order - k) * self._weight_kernel(t - k, order)
return weights


if __name__ == "__main__":
import doctest

doctest.testmod()
23 changes: 0 additions & 23 deletions shapiq/games/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,3 @@ def __repr__(self):

def __str__(self):
return f"DummyGame(n={self.n}, interaction={self.interaction})"

def __eq__(self, other):
return self.n == other.n and self.interaction == other.interaction

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
return hash((self.n, self.interaction))

def __copy__(self):
return DummyGame(n=self.n, interaction=self.interaction)

def __deepcopy__(self, memo):
return DummyGame(n=self.n, interaction=self.interaction)

def __getstate__(self):
return {"n": self.n, "interaction": self.interaction}

def __setstate__(self, state):
self.n = state["n"]
self.interaction = state["interaction"]
self.N = set(range(self.n))
14 changes: 14 additions & 0 deletions tests/test_approximator_permutation_sii.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This test module contains all tests regarding the SII permutation sampling approximator."""
from copy import copy, deepcopy

import numpy as np
import pytest

Expand Down Expand Up @@ -27,6 +29,18 @@ def test_initialization(n, max_order, top_order, expected):
assert approximator.iteration_cost == expected
assert approximator.index == "SII"

approximator_copy = copy(approximator)
approximator_deepcopy = deepcopy(approximator)
approximator_deepcopy.index = "something"
assert approximator_copy == approximator # check that the copy is equal
assert approximator_deepcopy != approximator # check that the deepcopy is not equal
approximator_string = str(approximator)
assert repr(approximator) == approximator_string
assert hash(approximator) == hash(approximator_copy)
assert hash(approximator) != hash(approximator_deepcopy)
with pytest.raises(ValueError):
_ = approximator == 1


@pytest.mark.parametrize(
"n, max_order, top_order, budget, batch_size",
Expand Down
27 changes: 27 additions & 0 deletions tests/test_approximator_permutation_sti.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This test module contains all tests regarding the STI permutation sampling approximator."""
from copy import copy, deepcopy

import numpy as np
import pytest

Expand Down Expand Up @@ -28,6 +30,18 @@ def test_initialization(n, max_order, iteration_cost):
assert approximator.iteration_cost == iteration_cost
assert approximator.index == "STI"

approximator_copy = copy(approximator)
approximator_deepcopy = deepcopy(approximator)
approximator_deepcopy.index = "something"
assert approximator_copy == approximator # check that the copy is equal
assert approximator_deepcopy != approximator # check that the deepcopy is not equal
approximator_string = str(approximator)
assert repr(approximator) == approximator_string
assert hash(approximator) == hash(approximator_copy)
assert hash(approximator) != hash(approximator_deepcopy)
with pytest.raises(ValueError):
_ = approximator == 1


@pytest.mark.parametrize(
"n, max_order, budget, batch_size",
Expand Down Expand Up @@ -59,3 +73,16 @@ def test_approximate(n, max_order, budget, batch_size):
# check efficiency
efficiency = np.sum(sti_estimates.values)
assert efficiency == pytest.approx(2.0, 0.01)


def test_small_budget_warning():
"""Tests that a warning is raised if the budget is too small."""
n, max_order = 10, 3
interaction = (1, 2)
game = DummyGame(n, interaction)
approximator = PermutationSamplingSTI(n, max_order, random_state=42)
# lower_order_cost is 55
with pytest.warns(UserWarning):
_ = approximator.approximate(1, game) # not even lower_order_cost
with pytest.warns(UserWarning):
_ = approximator.approximate(56, game) # lower_order_cost but no iteration
14 changes: 14 additions & 0 deletions tests/test_approximator_regression_fsi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This test module contains all tests regarding the FSI regression approximator."""
from copy import deepcopy, copy

import numpy as np
import pytest

Expand Down Expand Up @@ -28,6 +30,18 @@ def test_initialization(n, max_order):
assert approximator.iteration_cost == 1
assert approximator.index == "FSI"

approximator_copy = copy(approximator)
approximator_deepcopy = deepcopy(approximator)
approximator_deepcopy.index = "something"
assert approximator_copy == approximator # check that the copy is equal
assert approximator_deepcopy != approximator # check that the deepcopy is not equal
approximator_string = str(approximator)
assert repr(approximator) == approximator_string
assert hash(approximator) == hash(approximator_copy)
assert hash(approximator) != hash(approximator_deepcopy)
with pytest.raises(ValueError):
_ = approximator == 1


@pytest.mark.parametrize(
"n, max_order, budget, batch_size", [(7, 2, 380, 100), (7, 2, 380, None), (7, 2, 100, None)]
Expand Down
16 changes: 16 additions & 0 deletions tests/test_approximator_shapiq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This test module contains all tests regarding the shapiq approximator."""
from copy import copy, deepcopy

import numpy as np
import pytest
from approximator._base import InteractionValues
Expand Down Expand Up @@ -32,6 +34,20 @@ def test_initialization(n, max_order, index, top_order):
assert approximator.iteration_cost == 1
assert approximator.index == index

approximator_copy = copy(approximator)
approximator_deepcopy = deepcopy(approximator)
approximator_deepcopy.index = "something"
assert approximator_copy == approximator # check that the copy is equal
assert approximator_deepcopy != approximator # check that the deepcopy is not equal
approximator_string = str(approximator)
assert repr(approximator) == approximator_string
assert hash(approximator) == hash(approximator_copy)
assert hash(approximator) != hash(approximator_deepcopy)
with pytest.raises(ValueError):
_ = approximator == 1
with pytest.raises(ValueError):
approximator_deepcopy._weight_kernel(3, 2)


@pytest.mark.parametrize("n, max_order, budget, batch_size", [(7, 2, 100, None), (7, 2, 100, 10)])
def test_approximate_fsi(n, max_order, budget, batch_size):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_games_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_dummy_game(n, interaction, expected):
x_input[list(coalition)] = True
assert game(x_input)[0] == expected[coalition]

string_game = str(game)
assert repr(game) == string_game


def test_dummy_game_access_counts():
"""Test how often the game was called."""
Expand Down

0 comments on commit fb1a6a6

Please sign in to comment.