Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

183 make games be also callable with a tuplelist of tuples #242

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 108 additions & 5 deletions shapiq/games/base.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love the examples!

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickle
import warnings
from abc import ABC
from typing import Optional
from typing import Optional, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -94,6 +94,7 @@ def __init__(
normalization_value: Optional[float] = None,
path_to_values: Optional[str] = None,
verbose: bool = False,
player_names: Optional[list[str]] = None,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -137,6 +138,11 @@ def __init__(
self._empty_coalition_value_property = None
self._grand_coalition_value_property = None

# define player_names
self.player_name_lookup: dict[str, int] = (
{name: i for i, name in enumerate(player_names)} if player_names is not None else None
)

self.verbose = verbose

@property
Expand All @@ -159,7 +165,105 @@ def is_normalized(self) -> bool:
"""Checks if the game is normalized/centered."""
return self(self.empty_coalition) == 0

def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
def _check_coalitions(
self,
coalitions: Union[np.ndarray, list[Union[tuple[int], tuple[str]]]],
) -> np.ndarray:
"""
Check if the coalitions are in the correct format and convert them to one-hot encoding.
The format may either be a numpy array containg the coalitions in one-hot encoding or a list of tuples with integers or strings.
Args:
coalitions: The coalitions to convert to one-hot encoding.
Returns:
np.ndarray: The coalitions in the correct format
Raises:
TypeError: If the coalitions are not in the correct format.
Examples:
>>> coalitions = np.asarray([[1, 0, 0, 0], [0, 1, 1, 0]])
>>> coalitions = [(0, 1), (1, 2)]
>>> coalitions = [()]
>>> coalitions = [(0, 1), (1, 2), (0, 1, 2)]
if player_name_lookup is not None:
>>> coalitions = [("Alice", "Bob"), ("Bob", "Charlie")]
Wrong format:
>>> coalitions = [1, 0, 0, 0]
>>> coalitions = [(1,"Alice")]
>>> coalitions = np.array([1,-1,2])


"""
error_message = (
"List may only contain tuples of integers or strings."
"The tuples are not allowed to have heterogeneous types."
"Reconcile the docs for correct format of coalitions."
)

if isinstance(coalitions, np.ndarray):

# Check that coalition is contained in array
if len(coalitions) == 0:
raise TypeError("The array of coalitions is empty.")

# Check if single coalition is correctly given
if coalitions.ndim == 1:
if len(coalitions) < self.n_players or len(coalitions) > self.n_players:
raise TypeError(
"The array of coalitions is not correctly formatted."
f"It should have a length of {self.n_players}"
)
coalitions = coalitions.reshape((1, self.n_players))

# Check that all coalitions have the correct number of players
if coalitions.shape[1] != self.n_players:
raise TypeError(
f"The number of players in the coalitions ({coalitions.shape[1]}) does not match "
f"the number of players in the game ({self.n_players})."
)

# Check that values of numpy array are either 0 or 1
if not np.all(np.logical_or(coalitions == 0, coalitions == 1)):
raise TypeError("The values in the array of coalitions are not binary.")

return coalitions

# We now assume to work with list of tuples
if isinstance(coalitions, tuple):
# if by any chance a tuple was given wrap into a list
coalitions = [coalitions]

try:
# convert list of tuples to one-hot encoding
coalitions = transform_coalitions_to_array(coalitions, self.n_players)

return coalitions
except Exception as err:
# It may either be the tuples contain strings or wrong format
if self.player_name_lookup is not None:
# We now assume the tuples to contain strings
try:
coalitions = [
(
tuple(self.player_name_lookup[player] for player in coalition)
if coalition != tuple()
else tuple()
)
for coalition in coalitions
]
coalitions = transform_coalitions_to_array(coalitions, self.n_players)

return coalitions
except Exception as err:
raise TypeError(error_message) from err

raise TypeError(error_message) from err

def __call__(
self,
coalitions: Union[
np.ndarray, list[Union[tuple[int], tuple[str]]], tuple[Union[int, str]], str
],
verbose: bool = False,
) -> np.ndarray:
"""Calls the game's value function with the given coalitions and returns the output of the
value function.

Expand All @@ -170,9 +274,8 @@ def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
Returns:
The values of the coalitions.
"""
# check if coalitions are correct dimensions
if coalitions.ndim == 1:
coalitions = coalitions.reshape((1, self.n_players))
# check if coalitions are correct format
coalitions = self._check_coalitions(coalitions)

verbose = verbose or self.verbose

Expand Down
86 changes: 86 additions & 0 deletions tests/tests_games/test_base_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,96 @@
import numpy as np
import pytest

from shapiq.games.base import Game
from shapiq.games.benchmark import DummyGame # used to test the base class
from shapiq.utils.sets import powerset, transform_coalitions_to_array


def test_call():
"""This test tests the call function of the base game class."""

class TestGame(Game):
"""This is a test game class that inherits from the base game class.
Its value function is the amount of players divided by the number of players.
"""

def __init__(self, n, **kwargs):
super().__init__(n_players=n, normalization_value=0, **kwargs)

def value_function(self, coalition):
return np.sum(coalition) / self.n_players

n_players = 6
test_game = TestGame(
n=n_players, player_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]
)

# assert that player names are correctly stored
assert test_game.player_name_lookup == {
"Alice": 0,
"Bob": 1,
"Charlie": 2,
"David": 3,
"Eve": 4,
"Frank": 5,
}

assert test_game([]) == 0.0

# test coalition calls with wrong datatype
with pytest.raises(TypeError):
assert test_game([(0, 1), "Alice", "Charlie"])
with pytest.raises(TypeError):
assert test_game([(0, 1), ("Alice",), ("Bob",)])
with pytest.raises(TypeError):
assert test_game(("Alice", 1))

# test wrong coalition size in call
with pytest.raises(TypeError):
assert test_game(np.array([True, False, True])) == 0.0
with pytest.raises(TypeError):
assert test_game(np.array([])) == 0.0

# test wrong method for numpy array values
with pytest.raises(TypeError):
assert test_game(np.array([1, 2, 3, 4, 5, 6])) == 0.0

# test wrong coalition size in shape[1]
with pytest.raises(TypeError):
assert test_game(np.array([[True, False, True]])) == 0.0

# test with empty coalition all call variants
test_coalition = test_game.empty_coalition
assert test_game(test_coalition) == 0.0
assert test_game(()) == 0.0
assert test_game([()]) == 0.0

# test with grand coalition all call variants
test_coalition = test_game.grand_coalition
assert test_game(test_coalition) == 1.0
assert test_game(tuple(range(0, test_game.n_players))) == 1.0
assert test_game([tuple(range(0, test_game.n_players))]) == 1.0
assert test_game(tuple(test_game.player_name_lookup.values())) == 1.0
assert test_game([tuple(test_game.player_name_lookup.values())]) == 1.0

# test with single player coalition all call variants
test_coalition = np.array([True] + [False for _ in range(test_game.n_players - 1)])
assert test_game(test_coalition) - 1 / 6 < 10e-7
assert test_game((0,)) - 1 / 6 < 10e-7
assert test_game([(0,)]) - 1 / 6 < 10e-7
assert test_game(("Alice",)) - 1 / 6 < 10e-7
assert test_game([("Alice",)]) - 1 / 6 < 10e-7

# test string calls with missing player names
test_game2 = TestGame(n=n_players)
with pytest.raises(TypeError):
assert test_game2("Alice") == 0.0
with pytest.raises(TypeError):
assert test_game2(("Bob",)) == 0.0
with pytest.raises(TypeError):
assert test_game2([("Charlie",)]) == 0.0


def test_precompute():
"""This test tests the precompute function of the base game class"""
n_players = 6
Expand Down