Skip to content

Commit

Permalink
183 make games be also callable with a tuplelist of tuples (#242)
Browse files Browse the repository at this point in the history
* added possibility of tuple/list of tuple in BaseGame __call__

* Added possibility for to Game coalitions with given player_names

* Added possibility to call game based on given player_names

* removed player_names property as it caused conflicts

* made __call__ annotations for python3.9

* refactor game call with strings

* refactor game call of tuple and list

* Reimplement _check_coalitions to fewer lines.

---------

Co-authored-by: Maximilian <[email protected]>
  • Loading branch information
Advueu963 and mmschlk authored Oct 25, 2024
1 parent 96bd7c7 commit 3c3dfc2
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 5 deletions.
113 changes: 108 additions & 5 deletions shapiq/games/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickle
import warnings
from abc import ABC
from typing import Optional
from typing import Optional, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -94,6 +94,7 @@ def __init__(
normalization_value: Optional[float] = None,
path_to_values: Optional[str] = None,
verbose: bool = False,
player_names: Optional[list[str]] = None,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -137,6 +138,11 @@ def __init__(
self._empty_coalition_value_property = None
self._grand_coalition_value_property = None

# define player_names
self.player_name_lookup: dict[str, int] = (
{name: i for i, name in enumerate(player_names)} if player_names is not None else None
)

self.verbose = verbose

@property
Expand All @@ -159,7 +165,105 @@ def is_normalized(self) -> bool:
"""Checks if the game is normalized/centered."""
return self(self.empty_coalition) == 0

def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
def _check_coalitions(
self,
coalitions: Union[np.ndarray, list[Union[tuple[int], tuple[str]]]],
) -> np.ndarray:
"""
Check if the coalitions are in the correct format and convert them to one-hot encoding.
The format may either be a numpy array containg the coalitions in one-hot encoding or a list of tuples with integers or strings.
Args:
coalitions: The coalitions to convert to one-hot encoding.
Returns:
np.ndarray: The coalitions in the correct format
Raises:
TypeError: If the coalitions are not in the correct format.
Examples:
>>> coalitions = np.asarray([[1, 0, 0, 0], [0, 1, 1, 0]])
>>> coalitions = [(0, 1), (1, 2)]
>>> coalitions = [()]
>>> coalitions = [(0, 1), (1, 2), (0, 1, 2)]
if player_name_lookup is not None:
>>> coalitions = [("Alice", "Bob"), ("Bob", "Charlie")]
Wrong format:
>>> coalitions = [1, 0, 0, 0]
>>> coalitions = [(1,"Alice")]
>>> coalitions = np.array([1,-1,2])
"""
error_message = (
"List may only contain tuples of integers or strings."
"The tuples are not allowed to have heterogeneous types."
"Reconcile the docs for correct format of coalitions."
)

if isinstance(coalitions, np.ndarray):

# Check that coalition is contained in array
if len(coalitions) == 0:
raise TypeError("The array of coalitions is empty.")

# Check if single coalition is correctly given
if coalitions.ndim == 1:
if len(coalitions) < self.n_players or len(coalitions) > self.n_players:
raise TypeError(
"The array of coalitions is not correctly formatted."
f"It should have a length of {self.n_players}"
)
coalitions = coalitions.reshape((1, self.n_players))

# Check that all coalitions have the correct number of players
if coalitions.shape[1] != self.n_players:
raise TypeError(
f"The number of players in the coalitions ({coalitions.shape[1]}) does not match "
f"the number of players in the game ({self.n_players})."
)

# Check that values of numpy array are either 0 or 1
if not np.all(np.logical_or(coalitions == 0, coalitions == 1)):
raise TypeError("The values in the array of coalitions are not binary.")

return coalitions

# We now assume to work with list of tuples
if isinstance(coalitions, tuple):
# if by any chance a tuple was given wrap into a list
coalitions = [coalitions]

try:
# convert list of tuples to one-hot encoding
coalitions = transform_coalitions_to_array(coalitions, self.n_players)

return coalitions
except Exception as err:
# It may either be the tuples contain strings or wrong format
if self.player_name_lookup is not None:
# We now assume the tuples to contain strings
try:
coalitions = [
(
tuple(self.player_name_lookup[player] for player in coalition)
if coalition != tuple()
else tuple()
)
for coalition in coalitions
]
coalitions = transform_coalitions_to_array(coalitions, self.n_players)

return coalitions
except Exception as err:
raise TypeError(error_message) from err

raise TypeError(error_message) from err

def __call__(
self,
coalitions: Union[
np.ndarray, list[Union[tuple[int], tuple[str]]], tuple[Union[int, str]], str
],
verbose: bool = False,
) -> np.ndarray:
"""Calls the game's value function with the given coalitions and returns the output of the
value function.
Expand All @@ -170,9 +274,8 @@ def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
Returns:
The values of the coalitions.
"""
# check if coalitions are correct dimensions
if coalitions.ndim == 1:
coalitions = coalitions.reshape((1, self.n_players))
# check if coalitions are correct format
coalitions = self._check_coalitions(coalitions)

verbose = verbose or self.verbose

Expand Down
86 changes: 86 additions & 0 deletions tests/tests_games/test_base_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,96 @@
import numpy as np
import pytest

from shapiq.games.base import Game
from shapiq.games.benchmark import DummyGame # used to test the base class
from shapiq.utils.sets import powerset, transform_coalitions_to_array


def test_call():
"""This test tests the call function of the base game class."""

class TestGame(Game):
"""This is a test game class that inherits from the base game class.
Its value function is the amount of players divided by the number of players.
"""

def __init__(self, n, **kwargs):
super().__init__(n_players=n, normalization_value=0, **kwargs)

def value_function(self, coalition):
return np.sum(coalition) / self.n_players

n_players = 6
test_game = TestGame(
n=n_players, player_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]
)

# assert that player names are correctly stored
assert test_game.player_name_lookup == {
"Alice": 0,
"Bob": 1,
"Charlie": 2,
"David": 3,
"Eve": 4,
"Frank": 5,
}

assert test_game([]) == 0.0

# test coalition calls with wrong datatype
with pytest.raises(TypeError):
assert test_game([(0, 1), "Alice", "Charlie"])
with pytest.raises(TypeError):
assert test_game([(0, 1), ("Alice",), ("Bob",)])
with pytest.raises(TypeError):
assert test_game(("Alice", 1))

# test wrong coalition size in call
with pytest.raises(TypeError):
assert test_game(np.array([True, False, True])) == 0.0
with pytest.raises(TypeError):
assert test_game(np.array([])) == 0.0

# test wrong method for numpy array values
with pytest.raises(TypeError):
assert test_game(np.array([1, 2, 3, 4, 5, 6])) == 0.0

# test wrong coalition size in shape[1]
with pytest.raises(TypeError):
assert test_game(np.array([[True, False, True]])) == 0.0

# test with empty coalition all call variants
test_coalition = test_game.empty_coalition
assert test_game(test_coalition) == 0.0
assert test_game(()) == 0.0
assert test_game([()]) == 0.0

# test with grand coalition all call variants
test_coalition = test_game.grand_coalition
assert test_game(test_coalition) == 1.0
assert test_game(tuple(range(0, test_game.n_players))) == 1.0
assert test_game([tuple(range(0, test_game.n_players))]) == 1.0
assert test_game(tuple(test_game.player_name_lookup.values())) == 1.0
assert test_game([tuple(test_game.player_name_lookup.values())]) == 1.0

# test with single player coalition all call variants
test_coalition = np.array([True] + [False for _ in range(test_game.n_players - 1)])
assert test_game(test_coalition) - 1 / 6 < 10e-7
assert test_game((0,)) - 1 / 6 < 10e-7
assert test_game([(0,)]) - 1 / 6 < 10e-7
assert test_game(("Alice",)) - 1 / 6 < 10e-7
assert test_game([("Alice",)]) - 1 / 6 < 10e-7

# test string calls with missing player names
test_game2 = TestGame(n=n_players)
with pytest.raises(TypeError):
assert test_game2("Alice") == 0.0
with pytest.raises(TypeError):
assert test_game2(("Bob",)) == 0.0
with pytest.raises(TypeError):
assert test_game2([("Charlie",)]) == 0.0


def test_precompute():
"""This test tests the precompute function of the base game class"""
n_players = 6
Expand Down

0 comments on commit 3c3dfc2

Please sign in to comment.