Skip to content

Commit

Permalink
adds test to joint marginal prediciton and closes #261
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Oct 29, 2024
1 parent 60e8f79 commit 21d781c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
- `get_n_order` now has optional lower/upper limits for the order
- computing metrics now tries to resolve not-matching interaction indices and will throw a warning instead of a ValueError [#179](https://github.com/mmschlk/shapiq/issues/179)
- removed the `sample_replacements` parameter from `MarginalImputer` which is now handled by the `BaselineImputer`. Added a DeprecationWarning for the parameter, which will be removed in the next release.
- adds `BaselineImputer` [#107](https://github.com/mmschlk/shapiq/issues/107)
- adds `joint_marginal_distribution` parameter to `MarginalImputer` [#261](https://github.com/mmschlk/shapiq/issues/261)

### v1.0.1 (2024-06-05)

Expand Down
41 changes: 41 additions & 0 deletions tests/tests_imputer/test_marginal_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,47 @@ def model(x: np.ndarray) -> np.ndarray:
assert len(imputed_values) == 2


def test_joint_marginal_distribution():
"""Test weather the marginal imputer correctly samples replacement values."""

def model(x: np.ndarray) -> np.ndarray:
return np.sum(x, axis=1)

data = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]
data_as_tuples = [tuple(row) for row in data]
data = np.array(data)
x = np.array([1, 1, 1])

imputer = MarginalImputer(
model=model,
data=data,
x=x,
sample_size=3,
random_state=42,
joint_marginal_distribution=False,
)
replacement_data_independent = imputer._sample_replacement_values(3)
print(replacement_data_independent)

imputer = MarginalImputer(
model=model,
data=data,
x=x,
sample_size=3,
random_state=42,
joint_marginal_distribution=True,
)
replacement_data_joint = imputer._sample_replacement_values(3)
for i in range(3):
assert tuple(replacement_data_joint[i]) in data_as_tuples
# the below only works because of the random seed (might break in future)
assert tuple(replacement_data_independent[i]) not in data_as_tuples


def test_raise_warning():

def model(x: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit 21d781c

Please sign in to comment.