From c2eb87990712a9cee5407e8198dca8142d2613a2 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 9 Jan 2025 14:24:51 +0100 Subject: [PATCH 01/16] adds tabpfn example --- .../tabular_notebooks/explaining_tabpfn.ipynb | 928 ++++++++++++++++++ 1 file changed, 928 insertions(+) create mode 100644 docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb new file mode 100644 index 00000000..8b70886f --- /dev/null +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -0,0 +1,928 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Explaining TabPFN\n", + "\n", + "TabPFN is a foundation model for tabular data that uses in-context learning to solve classification and regression tasks.\n", + "TabPFN outperforms traditional models like Random Forest and Gradient Boosting on small datasets and raises the state of the art for tabular data!\n", + "Recently, a major update was released, which includes a new architecture and an updated API.\n", + "\n", + "For more information about TabPFN, check the [GitHub repository](https://github.com/PriorLabs/TabPFN) and the associated papers ([TabPFN](https://openreview.net/forum?id=eu9fVjVasr4), [TabPFNv2](https://www.nature.com/articles/s41586-024-08328-6)).\n", + "\n", + "In this tutorial, we see how we can **use shapiq to explain the predictions of TabPFNv2**.
\n", + "We will use the California housing dataset and train a TabPFN model to predict the house prices.\n", + "Many explanation methods show that models tend to learn interactions between the latitude and longitude features, which together encode the exact location of a house.\n", + "We want to see if TabPFN also learns the interactions between latitude and longitude.\n", + "\n", + "First, let's import the libraries (tabpfn and shapiq) and check their versions.\n", + "Note that this tutorial uses the latest version of TabPFN (>= 2.0.0) and will not necessarily work with older versions." + ], + "id": "af7fe5c630d43e1d" + }, + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-09T13:20:56.150989Z", + "start_time": "2025-01-09T13:20:56.137996Z" + } + }, + "source": [ + "from importlib.metadata import version\n", + "\n", + "import shapiq\n", + "import tabpfn\n", + "\n", + "print(\"shapiq version: \", shapiq.__version__, \"tabpfn version: \", version(\"tabpfn\"))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shapiq version: 1.1.1 tabpfn version: 2.0.0\n" + ] + } + ], + "execution_count": 9 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Get the California Housing Data\n", + "Now let's load the California housing dataset and inspect the data."
+ ], + "id": "229e7c0478fc1c96" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:20:56.243644Z", + "start_time": "2025-01-09T13:20:56.199647Z" + } + }, + "cell_type": "code", + "source": [ + "x_data, y_data = shapiq.datasets.load_california_housing()\n", + "feature_names = x_data.columns\n", + "\n", + "# copy the data to make sure we don't modify the original data\n", + "dataset = x_data.copy()\n", + "dataset[\"HousePrice\"] = y_data\n", + "display(dataset.head())\n", + "display(dataset[\"HousePrice\"].describe())" + ], + "id": "af75a5d50fbb096e", + "outputs": [ + { + "data": { + "text/plain": [ + " MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude \\\n", + "0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 \n", + "1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 \n", + "2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 \n", + "3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 \n", + "4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 \n", + "\n", + " Longitude HousePrice \n", + "0 -122.23 4.526 \n", + "1 -122.22 3.585 \n", + "2 -122.24 3.521 \n", + "3 -122.25 3.413 \n", + "4 -122.25 3.422 " + ], + "text/html": [ + "
<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n",
+ " <tr><th></th><th>MedInc</th><th>HouseAge</th><th>AveRooms</th><th>AveBedrms</th><th>Population</th><th>AveOccup</th><th>Latitude</th><th>Longitude</th><th>HousePrice</th></tr>\n",
+ " </thead>\n", + " <tbody>\n",
+ " <tr><th>0</th><td>8.3252</td><td>41.0</td><td>6.984127</td><td>1.023810</td><td>322.0</td><td>2.555556</td><td>37.88</td><td>-122.23</td><td>4.526</td></tr>\n",
+ " <tr><th>1</th><td>8.3014</td><td>21.0</td><td>6.238137</td><td>0.971880</td><td>2401.0</td><td>2.109842</td><td>37.86</td><td>-122.22</td><td>3.585</td></tr>\n",
+ " <tr><th>2</th><td>7.2574</td><td>52.0</td><td>8.288136</td><td>1.073446</td><td>496.0</td><td>2.802260</td><td>37.85</td><td>-122.24</td><td>3.521</td></tr>\n",
+ " <tr><th>3</th><td>5.6431</td><td>52.0</td><td>5.817352</td><td>1.073059</td><td>558.0</td><td>2.547945</td><td>37.85</td><td>-122.25</td><td>3.413</td></tr>\n",
+ " <tr><th>4</th><td>3.8462</td><td>52.0</td><td>6.281853</td><td>1.081081</td><td>565.0</td><td>2.181467</td><td>37.85</td><td>-122.25</td><td>3.422</td></tr>\n",
+ " </tbody>\n", + "</table>
\n", + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "count 20640.000000\n", + "mean 2.068558\n", + "std 1.153956\n", + "min 0.149990\n", + "25% 1.196000\n", + "50% 1.797000\n", + "75% 2.647250\n", + "max 5.000010\n", + "Name: HousePrice, dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 10 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "Now we have loaded the data.\n", + "**HousePrice** is the target variable we want to predict.\n", + "The target ranges from 0.15 to 5.0.\n", + "\n", + "In order to use TabPFN, we need to split the data into a training and a testing set.\n", + "Note that TabPFN works best for **small datasets** (fewer than 10k samples).\n", + "So let's select a small training subset, well below that limit." + ], + "id": "2d3e6649c1ae8450" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:20:56.259634Z", + "start_time": "2025-01-09T13:20:56.245636Z" + } + }, + "cell_type": "code", + "source": [ + "# split the data into training and testing sets\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "x_train, x_test, y_train, y_test = train_test_split(\n", + " x_data.values, y_data.values, train_size=200, random_state=42\n", + ")\n", + "print(\"Train data shape: \", x_train.shape, y_train.shape)\n", + "print(\"Test data shape: \", x_test.shape, y_test.shape)" + ], + "id": "e77933887d0a119f", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data shape: (200, 8) (200,)\n", + "Test data shape: (20440, 8) (20440,)\n" + ] + } + ], + "execution_count": 11 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Fit TabPFN\n", + "Now that we have the data, we can fit TabPFN to the training data and make it ready for predictions."
+ ], + "id": "8be176b5890b9eaf" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:20:56.401760Z", + "start_time": "2025-01-09T13:20:56.261658Z" + } + }, + "cell_type": "code", + "source": [ + "model = tabpfn.TabPFNRegressor(n_estimators=4, n_jobs=4)\n", + "model.fit(x_train, y_train)" + ], + "id": "a1100c73d7b0867e", + "outputs": [ + { + "data": { + "text/plain": [ + "TabPFNRegressor(n_estimators=4, n_jobs=4)" + ], + "text/html": [ + "
TabPFNRegressor(n_estimators=4, n_jobs=4)
" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 12 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "When we have the \"trained\" model, we can use it to predict the house prices.\n", + "These predictions are very competitive with the state-of-the-art models." + ], + "id": "25603c1d4540f2c5" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:21:12.781778Z", + "start_time": "2025-01-09T13:20:56.403762Z" + } + }, + "cell_type": "code", + "source": [ + "from sklearn.metrics import mean_squared_error\n", + "import numpy as np\n", + "\n", + "predictions = model.predict(x_test[:1000])\n", + "mse = mean_squared_error(y_test[:1000], predictions)\n", + "print(mse)\n", + "\n", + "average_prediction = np.mean(predictions)\n", + "print(\"Average prediction: \", average_prediction)" + ], + "id": "d36110af9fa1b058", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.37466335287753105\n", + "Average prediction: 2.129479\n" + ] + } + ], + "execution_count": 13 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Explain TabPFN with shapiq\n", + "Now that we see how TabPFN performs, we can use shapiq to explain the predictions." 
+ ], + "id": "464ced0bcf3760ea" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:21:12.797285Z", + "start_time": "2025-01-09T13:21:12.783772Z" + } + }, + "cell_type": "code", + "source": [ + "# explainer = shapiq.Explainer(model, data=x_test[:1000], index=\"FSII\", max_order=2, imputer=\"baseline\")\n", + "# explainer._imputer.verbose = True\n", + "#\n", + "# x_explain = x_test[0]\n", + "#\n", + "# sv = explainer.explain(x_explain)\n", + "# sv.plot_force(feature_names=feature_names)" + ], + "id": "41314e231db2e986", + "outputs": [], + "execution_count": 14 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Explaining TabPFN with Remove-and-\"Retrain\"\n", + "\n", + "Since TabPFN is a foundation model, it uses in-context learning to solve the classification and regression tasks.\n", + "This means that \"retraining\" the model is quite inexpensive, because we only need to provide the new data points with whatever features we want to remove." + ], + "id": "cdba7867ce6fbbb0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:21:12.813285Z", + "start_time": "2025-01-09T13:21:12.799277Z" + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "\n", + "class TabPFNGame(shapiq.Game):\n", + " \"\"\"The TabPFN Game class implementation a remove-and-\"retrain\" strategy to explain the predictions of TabPFN.\"\"\"\n", + "\n", + " def __init__(self, model, x_train, y_train, x_explain, normalization_value):\n", + " self.model = model\n", + " self.x_train = x_train\n", + " self.y_train = y_train\n", + " self.x_explain = x_explain\n", + "\n", + " print(\"Initializing TabPFN Game\")\n", + " print(\"Train data shape: \", x_train.shape, y_train.shape)\n", + " print(\"Explain data shape: \", x_explain.shape)\n", + "\n", + " super().__init__(n_players=x_train.shape[1], normalization_value=normalization_value)\n", + "\n", + " def value_function(self, coalitions: np.ndarray) -> np.ndarray:\n", + 
" \"\"\"The value function performs the remove-and-\"retrain\" strategy for TabPFN.\"\"\"\n", + " output = np.zeros(len(coalitions), dtype=float)\n", + " for i, coalition in enumerate(coalitions):\n", + " if sum(coalition) == 0:\n", + " output[i] = 0.0\n", + " continue\n", + " x_train_coal = self.x_train[:, coalition]\n", + " x_explain_coal = self.x_explain[coalition].reshape(1, -1)\n", + " self.model.fit(x_train_coal, self.y_train)\n", + " prediction = float(self.model.predict(x_explain_coal)[0])\n", + " output[i] = prediction\n", + " return output" + ], + "id": "37a977c5f4a88aee", + "outputs": [], + "execution_count": 15 + }, + { + "metadata": { + "jupyter": { + "is_executing": true + }, + "ExecuteTime": { + "start_time": "2025-01-09T13:21:12.814278Z" + } + }, + "cell_type": "code", + "source": [ + "x_explain = x_test[0]\n", + "game = TabPFNGame(model, x_train, y_train, x_explain, normalization_value=average_prediction)\n", + "game.verbose = True\n", + "game.precompute()\n", + "game.save_values(\"tabpfn_values.npz\")" + ], + "id": "7b2606969b5bab0", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing TabPFN Game\n", + "Train data shape: (200, 8) (200,)\n", + "Explain data shape: (8,)\n" + ] + }, + { + "data": { + "text/plain": [ + "Evaluating game: 0%| | 0/256 [00:00 Date: Thu, 9 Jan 2025 14:38:42 +0100 Subject: [PATCH 02/16] updated average_prediction --- .../tabular_notebooks/explaining_tabpfn.ipynb | 259 ++++++++++-------- 1 file changed, 146 insertions(+), 113 deletions(-) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index 8b70886f..e53b2fe1 100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -28,8 +28,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-01-09T13:20:56.150989Z", - "start_time": 
"2025-01-09T13:20:56.137996Z" + "end_time": "2025-01-09T13:37:42.902399Z", + "start_time": "2025-01-09T13:37:42.888410Z" } }, "source": [ @@ -49,7 +49,7 @@ ] } ], - "execution_count": 9 + "execution_count": 28 }, { "metadata": {}, @@ -63,8 +63,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:20:56.243644Z", - "start_time": "2025-01-09T13:20:56.199647Z" + "end_time": "2025-01-09T13:37:42.980940Z", + "start_time": "2025-01-09T13:37:42.932418Z" } }, "cell_type": "code", @@ -214,7 +214,7 @@ "output_type": "display_data" } ], - "execution_count": 10 + "execution_count": 29 }, { "metadata": {}, @@ -233,8 +233,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:20:56.259634Z", - "start_time": "2025-01-09T13:20:56.245636Z" + "end_time": "2025-01-09T13:37:42.996941Z", + "start_time": "2025-01-09T13:37:42.982942Z" } }, "cell_type": "code", @@ -243,7 +243,7 @@ "from sklearn.model_selection import train_test_split\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", - " x_data.values, y_data.values, train_size=200, random_state=42\n", + " x_data.values, y_data.values, train_size=300, random_state=1\n", ")\n", "print(\"Train data shape: \", x_train.shape, y_train.shape)\n", "print(\"Test data shape: \", x_test.shape, y_test.shape)" @@ -254,12 +254,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train data shape: (200, 8) (200,)\n", - "Test data shape: (20440, 8) (20440,)\n" + "Train data shape: (300, 8) (300,)\n", + "Test data shape: (20340, 8) (20340,)\n" ] } ], - "execution_count": 11 + "execution_count": 30 }, { "metadata": {}, @@ -273,8 +273,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:20:56.401760Z", - "start_time": "2025-01-09T13:20:56.261658Z" + "end_time": "2025-01-09T13:37:43.152927Z", + "start_time": "2025-01-09T13:37:42.997939Z" } }, "cell_type": "code", @@ -290,7 +290,7 @@ "TabPFNRegressor(n_estimators=4, n_jobs=4)" ], "text/html": [ - "
TabPFNRegressor(n_estimators=4, n_jobs=4)
" + "
TabPFNRegressor(n_estimators=4, n_jobs=4)
" ] }, - "execution_count": 12, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 12 + "execution_count": 31 }, { "metadata": {}, @@ -716,34 +716,47 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:21:12.781778Z", - "start_time": "2025-01-09T13:20:56.403762Z" + "end_time": "2025-01-09T13:38:21.998978Z", + "start_time": "2025-01-09T13:37:43.154928Z" + } + }, + "cell_type": "code", + "source": "predictions = model.predict(x_test[:2000])", + "id": "d36110af9fa1b058", + "outputs": [], + "execution_count": 32 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:38:22.014134Z", + "start_time": "2025-01-09T13:38:21.999971Z" } }, "cell_type": "code", "source": [ - "from sklearn.metrics import mean_squared_error\n", + "from sklearn.metrics import mean_squared_error, r2_score\n", "import numpy as np\n", "\n", - "predictions = model.predict(x_test[:1000])\n", - "mse = mean_squared_error(y_test[:1000], predictions)\n", - "print(mse)\n", + "mse = mean_squared_error(y_test[:2000], predictions)\n", + "r2 = r2_score(y_test[:2000], predictions)\n", + "print(\"MSE: \", mse, \"R2: \", r2)\n", "\n", "average_prediction = np.mean(predictions)\n", "print(\"Average prediction: \", average_prediction)" ], - "id": "d36110af9fa1b058", + "id": "fdd1896b91cfbd4a", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.37466335287753105\n", - "Average prediction: 2.129479\n" + "MSE: 0.32530222052366137 R2: 0.753441412304244\n", + "Average prediction: 2.0711837\n" ] } ], - "execution_count": 13 + "execution_count": 33 }, { "metadata": {}, @@ -757,8 +770,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:21:12.797285Z", - "start_time": "2025-01-09T13:21:12.783772Z" + "end_time": "2025-01-09T13:38:22.029136Z", + "start_time": "2025-01-09T13:38:22.015135Z" } }, "cell_type": "code", @@ -773,7 +786,7 @@ ], "id": "41314e231db2e986", "outputs": [], - "execution_count": 14 + 
"execution_count": 34 }, { "metadata": {}, @@ -789,8 +802,42 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:21:12.813285Z", - "start_time": "2025-01-09T13:21:12.799277Z" + "end_time": "2025-01-09T13:38:27.522752Z", + "start_time": "2025-01-09T13:38:22.032140Z" + } + }, + "cell_type": "code", + "source": [ + "x_explain = x_test[0]\n", + "y_explain = y_test[0]\n", + "\n", + "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", + "print(\n", + " \"Prediction: \",\n", + " prediction,\n", + " \"True value: \",\n", + " y_explain,\n", + " \"Average prediction: \",\n", + " average_prediction,\n", + ")" + ], + "id": "19b2cf3dd8a8d751", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: 3.4585295 True value: 3.55 Average prediction: 2.0711837\n" + ] + } + ], + "execution_count": 35 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T13:38:27.538825Z", + "start_time": "2025-01-09T13:38:27.523744Z" } }, "cell_type": "code", @@ -801,24 +848,25 @@ "class TabPFNGame(shapiq.Game):\n", " \"\"\"The TabPFN Game class implementation a remove-and-\"retrain\" strategy to explain the predictions of TabPFN.\"\"\"\n", "\n", - " def __init__(self, model, x_train, y_train, x_explain, normalization_value):\n", + " def __init__(self, model, x_train, y_train, x_explain, average_prediction):\n", " self.model = model\n", " self.x_train = x_train\n", " self.y_train = y_train\n", " self.x_explain = x_explain\n", + " self.average_prediction = average_prediction\n", "\n", " print(\"Initializing TabPFN Game\")\n", " print(\"Train data shape: \", x_train.shape, y_train.shape)\n", " print(\"Explain data shape: \", x_explain.shape)\n", "\n", - " super().__init__(n_players=x_train.shape[1], normalization_value=normalization_value)\n", + " super().__init__(n_players=x_train.shape[1], normalization_value=self.average_prediction)\n", "\n", " def value_function(self, coalitions: np.ndarray) -> np.ndarray:\n", " \"\"\"The 
value function performs the remove-and-\"retrain\" strategy for TabPFN.\"\"\"\n", " output = np.zeros(len(coalitions), dtype=float)\n", " for i, coalition in enumerate(coalitions):\n", " if sum(coalition) == 0:\n", - " output[i] = 0.0\n", + " output[i] = self.average_prediction\n", " continue\n", " x_train_coal = self.x_train[:, coalition]\n", " x_explain_coal = self.x_explain[coalition].reshape(1, -1)\n", @@ -829,21 +877,18 @@ ], "id": "37a977c5f4a88aee", "outputs": [], - "execution_count": 15 + "execution_count": 36 }, { "metadata": { - "jupyter": { - "is_executing": true - }, "ExecuteTime": { - "start_time": "2025-01-09T13:21:12.814278Z" + "end_time": "2025-01-09T13:38:27.570490Z", + "start_time": "2025-01-09T13:38:27.539815Z" } }, "cell_type": "code", "source": [ - "x_explain = x_test[0]\n", - "game = TabPFNGame(model, x_train, y_train, x_explain, normalization_value=average_prediction)\n", + "game = TabPFNGame(model, x_train, y_train, x_explain, average_prediction)\n", "game.verbose = True\n", "game.precompute()\n", "game.save_values(\"tabpfn_values.npz\")" @@ -851,35 +896,24 @@ "id": "7b2606969b5bab0", "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Initializing TabPFN Game\n", - "Train data shape: (200, 8) (200,)\n", - "Explain data shape: (8,)\n" + "ename": "TypeError", + "evalue": "TabPFNGame.__init__() got an unexpected keyword argument 'normalization_value'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[37], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m game \u001b[38;5;241m=\u001b[39m \u001b[43mTabPFNGame\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mx_explain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnormalization_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_prediction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2\u001b[0m game\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 3\u001b[0m game\u001b[38;5;241m.\u001b[39mprecompute()\n", + "\u001b[1;31mTypeError\u001b[0m: TabPFNGame.__init__() got an unexpected keyword argument 'normalization_value'" ] - }, - { - "data": { - "text/plain": [ - "Evaluating game: 0%| | 0/256 [00:00 Date: Thu, 9 Jan 2025 15:01:11 +0100 Subject: [PATCH 03/16] updated average_prediction --- .../tabular_notebooks/explaining_tabpfn.ipynb | 115 ++++++++++++------ 1 file changed, 79 insertions(+), 36 deletions(-) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index e53b2fe1..6c526bf2 100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -28,8 +28,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-01-09T13:37:42.902399Z", - "start_time": "2025-01-09T13:37:42.888410Z" + "end_time": "2025-01-09T13:54:21.989586Z", + "start_time": "2025-01-09T13:54:21.975580Z" } }, "source": [ @@ -37,19 +37,24 @@ "\n", "import shapiq\n", "import tabpfn\n", + "import torch\n", "\n", - "print(\"shapiq version: \", shapiq.__version__, \"tabpfn version: \", version(\"tabpfn\"))" + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "print(\"shapiq version: \", shapiq.__version__, \"tabpfn version: \", version(\"tabpfn\"))\n", + "print(\"Device: \", device)" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "shapiq version: 1.1.1 tabpfn version: 2.0.0\n" + "shapiq version: 1.1.1 tabpfn version: 2.0.0\n", + "Device: cpu\n" ] 
} ], - "execution_count": 28 + "execution_count": 39 }, { "metadata": {}, @@ -279,7 +284,7 @@ }, "cell_type": "code", "source": [ - "model = tabpfn.TabPFNRegressor(n_estimators=4, n_jobs=4)\n", + "model = tabpfn.TabPFNRegressor(n_estimators=4, n_jobs=4, device=\"cuda\")\n", "model.fit(x_train, y_train)" ], "id": "a1100c73d7b0867e", @@ -812,14 +817,9 @@ "y_explain = y_test[0]\n", "\n", "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", - "print(\n", - " \"Prediction: \",\n", - " prediction,\n", - " \"True value: \",\n", - " y_explain,\n", - " \"Average prediction: \",\n", - " average_prediction,\n", - ")" + "print(\"Prediction: \", prediction)\n", + "print(\"True value: \", y_explain)\n", + "print(\"Average prediction: \", average_prediction)" ], "id": "19b2cf3dd8a8d751", "outputs": [ @@ -842,9 +842,6 @@ }, "cell_type": "code", "source": [ - "import numpy as np\n", - "\n", - "\n", "class TabPFNGame(shapiq.Game):\n", " \"\"\"The TabPFN Game class implementation a remove-and-\"retrain\" strategy to explain the predictions of TabPFN.\"\"\"\n", "\n", @@ -871,8 +868,8 @@ " x_train_coal = self.x_train[:, coalition]\n", " x_explain_coal = self.x_explain[coalition].reshape(1, -1)\n", " self.model.fit(x_train_coal, self.y_train)\n", - " prediction = float(self.model.predict(x_explain_coal)[0])\n", - " output[i] = prediction\n", + " pred = float(self.model.predict(x_explain_coal)[0])\n", + " output[i] = pred\n", " return output" ], "id": "37a977c5f4a88aee", @@ -882,8 +879,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:27.570490Z", - "start_time": "2025-01-09T13:38:27.539815Z" + "end_time": "2025-01-09T13:54:21.973587Z", + "start_time": "2025-01-09T13:38:59.204920Z" } }, "cell_type": "code", @@ -896,24 +893,36 @@ "id": "7b2606969b5bab0", "outputs": [ { - "ename": "TypeError", - "evalue": "TabPFNGame.__init__() got an unexpected keyword argument 'normalization_value'", - "output_type": "error", - "traceback": [ - 
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[37], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m game \u001b[38;5;241m=\u001b[39m \u001b[43mTabPFNGame\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_explain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnormalization_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_prediction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2\u001b[0m game\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 3\u001b[0m game\u001b[38;5;241m.\u001b[39mprecompute()\n", - "\u001b[1;31mTypeError\u001b[0m: TabPFNGame.__init__() got an unexpected keyword argument 'normalization_value'" + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing TabPFN Game\n", + "Train data shape: (300, 8) (300,)\n", + "Explain data shape: (8,)\n" ] + }, + { + "data": { + "text/plain": [ + "Evaluating game: 0%| | 0/256 [00:00" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{(): 0.0,\n", + " (0,): -0.19919682040048847,\n", + " (1,): 0.1938878673232311,\n", + " (2,): 0.03612450051107798,\n", + " (3,): 0.012084188219684957,\n", + " (4,): 0.009044129786639231,\n", + " (5,): 1.1875691359177223,\n", + " (6,): 0.013825137695848005,\n", + " (7,): 0.1340076537387331}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 41 } ], "metadata": { From 00958999964efd02e51a03c55ce25769d662b025 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 9 Jan 2025 15:56:54 +0100 Subject: [PATCH 04/16] 
updated notebook --- .../tabular_notebooks/explaining_tabpfn.ipynb | 328 +++++++++++------- 1 file changed, 198 insertions(+), 130 deletions(-) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index 6c526bf2..6a08a095 100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -28,8 +28,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-01-09T13:54:21.989586Z", - "start_time": "2025-01-09T13:54:21.975580Z" + "end_time": "2025-01-09T14:23:43.643915Z", + "start_time": "2025-01-09T14:23:39.475714Z" } }, "source": [ @@ -41,7 +41,8 @@ "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", - "print(\"shapiq version: \", shapiq.__version__, \"tabpfn version: \", version(\"tabpfn\"))\n", + "print(\"shapiq version: \", shapiq.__version__)\n", + "print(\"tabpfn version: \", version(\"tabpfn\"))\n", "print(\"Device: \", device)" ], "outputs": [ @@ -54,13 +55,13 @@ ] } ], - "execution_count": 39 + "execution_count": 3 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "### Get the California Housing Data\n", + "## Get the California Housing Data\n", "Now let's load the California housing dataset and inspect the data." 
], "id": "229e7c0478fc1c96" @@ -68,8 +69,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:37:42.980940Z", - "start_time": "2025-01-09T13:37:42.932418Z" + "end_time": "2025-01-09T14:23:43.707421Z", + "start_time": "2025-01-09T14:23:43.645913Z" } }, "cell_type": "code", @@ -219,7 +220,7 @@ "output_type": "display_data" } ], - "execution_count": 29 + "execution_count": 4 }, { "metadata": {}, @@ -238,8 +239,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:37:42.996941Z", - "start_time": "2025-01-09T13:37:42.982942Z" + "end_time": "2025-01-09T14:23:43.723421Z", + "start_time": "2025-01-09T14:23:43.709411Z" } }, "cell_type": "code", @@ -248,7 +249,7 @@ "from sklearn.model_selection import train_test_split\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", - " x_data.values, y_data.values, train_size=300, random_state=1\n", + " x_data.values, y_data.values, train_size=500, random_state=42\n", ")\n", "print(\"Train data shape: \", x_train.shape, y_train.shape)\n", "print(\"Test data shape: \", x_test.shape, y_test.shape)" @@ -259,18 +260,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train data shape: (300, 8) (300,)\n", - "Test data shape: (20340, 8) (20340,)\n" + "Train data shape: (500, 8) (500,)\n", + "Test data shape: (20140, 8) (20140,)\n" ] } ], - "execution_count": 30 + "execution_count": 5 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "### Fit TabPFN\n", + "## Fit TabPFN\n", "Now that we have the data, we can fit TabPFN to the training data and make it ready for predictions." 
], "id": "8be176b5890b9eaf" @@ -278,13 +279,13 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:37:43.152927Z", - "start_time": "2025-01-09T13:37:42.997939Z" + "end_time": "2025-01-09T14:23:46.264555Z", + "start_time": "2025-01-09T14:23:46.027950Z" } }, "cell_type": "code", "source": [ - "model = tabpfn.TabPFNRegressor(n_estimators=4, n_jobs=4, device=\"cuda\")\n", + "model = tabpfn.TabPFNRegressor(n_jobs=8, device=device)\n", "model.fit(x_train, y_train)" ], "id": "a1100c73d7b0867e", @@ -292,10 +293,10 @@ { "data": { "text/plain": [ - "TabPFNRegressor(n_estimators=4, n_jobs=4)" + "TabPFNRegressor(device=device(type='cpu'), n_jobs=8)" ], "text/html": [ - "
TabPFNRegressor(n_estimators=4, n_jobs=4)
" + "
TabPFNRegressor(device=device(type='cpu'), n_jobs=8)
" ] }, - "execution_count": 31, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 31 + "execution_count": 6 }, { "metadata": {}, "cell_type": "markdown", "source": [ "When we have the \"trained\" model, we can use it to predict the house prices.\n", - "These predictions are very competitive with the state-of-the-art models." + "These predictions are very competitive with the state-of-the-art models.\n", + "Note that TabPFN at the end of the day is still quite a big transformer model, which needs a GPU to run efficiently." ], "id": "25603c1d4540f2c5" }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Evaluate TabPFN", + "id": "db580ea3627edae2" + }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:21.998978Z", - "start_time": "2025-01-09T13:37:43.154928Z" + "end_time": "2025-01-09T14:26:06.514631Z", + "start_time": "2025-01-09T14:23:49.733597Z" } }, "cell_type": "code", - "source": "predictions = model.predict(x_test[:2000])", + "source": [ + "# we downsample the test data for more efficient inference\n", + "x_test, y_test = x_test[:2000], y_test[:2000]\n", + "predictions = model.predict(x_test)" + ], "id": "d36110af9fa1b058", "outputs": [], - "execution_count": 32 + "execution_count": 7 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:22.014134Z", - "start_time": "2025-01-09T13:38:21.999971Z" + "end_time": "2025-01-09T14:26:20.308666Z", + "start_time": "2025-01-09T14:26:20.288667Z" } }, "cell_type": "code", @@ -743,8 +755,8 @@ "from sklearn.metrics import mean_squared_error, r2_score\n", "import numpy as np\n", "\n", - "mse = mean_squared_error(y_test[:2000], predictions)\n", - "r2 = r2_score(y_test[:2000], predictions)\n", + "mse = mean_squared_error(y_test, predictions)\n", + "r2 = r2_score(y_test, predictions)\n", "print(\"MSE: \", mse, \"R2: \", r2)\n", "\n", "average_prediction = np.mean(predictions)\n", @@ -756,88 +768,104 @@ "name": "stdout", "output_type": 
"stream", "text": [ - "MSE: 0.32530222052366137 R2: 0.753441412304244\n", - "Average prediction: 2.0711837\n" + "MSE: 0.313356474514811 R2: 0.7624955351447615\n", + "Average prediction: 2.0460324\n" ] } ], - "execution_count": 33 + "execution_count": 8 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "# Explain TabPFN with shapiq\n", - "Now that we see how TabPFN performs, we can use shapiq to explain the predictions." + "## Explain TabPFN with shapiq\n", + "Now that we see how TabPFN performs, we can use shapiq to explain the predictions.\n", + "First, we will use the KernelSHAP method to explain the predictions." ], - "id": "464ced0bcf3760ea" + "id": "85a7dadbec463d65" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:22.029136Z", - "start_time": "2025-01-09T13:38:22.015135Z" + "end_time": "2025-01-09T14:42:00.521476Z", + "start_time": "2025-01-09T14:41:38.073888Z" } }, "cell_type": "code", "source": [ - "# explainer = shapiq.Explainer(model, data=x_test[:1000], index=\"FSII\", max_order=2, imputer=\"baseline\")\n", - "# explainer._imputer.verbose = True\n", - "#\n", - "# x_explain = x_test[0]\n", - "#\n", - "# sv = explainer.explain(x_explain)\n", - "# sv.plot_force(feature_names=feature_names)" + "x_explain = x_test[10]\n", + "y_explain = y_test[10]\n", + "\n", + "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", + "print(\"Prediction: \", prediction)\n", + "print(\"True value: \", y_explain)\n", + "print(\"Average prediction: \", average_prediction)" ], - "id": "41314e231db2e986", - "outputs": [], - "execution_count": 34 + "id": "15e30787bb74905", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: 2.774252\n", + "True value: 2.938\n", + "Average prediction: 2.0460324\n" + ] + } + ], + "execution_count": 12 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "# Explaining TabPFN with Remove-and-\"Retrain\"\n", + "### Traditional Explanation with Baseline Imputation\n", + 
"The traditional way to explain any black-box model trained on tabular data is by using imputation strategies for feature removal (excellent [paper by Covert et al.](https://jmlr.csail.mit.edu/papers/volume22/20-1316/20-1316.pdf)).\n", + "During explanations, the model is queried multiple times with different subsets of features removed.\n", + "Removed features are imputed using different strategies, such as the baseline imputation.\n", + "Baseline imputation replaces the removed features with the mean/mode of the training data.\n", "\n", - "Since TabPFN is a foundation model, it uses in-context learning to solve the classification and regression tasks.\n", - "This means that \"retraining\" the model is quite inexpensive, because we only need to provide the new data points with whatever features we want to remove." + "We can natively use the shapiq.Explainer (specifically shapiq.TabularExplainer) to explain the TabPFN model:" ], - "id": "cdba7867ce6fbbb0" + "id": "b225c897c1181eee" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:27.522752Z", - "start_time": "2025-01-09T13:38:22.032140Z" + "end_time": "2025-01-09T13:38:22.029136Z", + "start_time": "2025-01-09T13:38:22.015135Z" } }, "cell_type": "code", "source": [ - "x_explain = x_test[0]\n", - "y_explain = y_test[0]\n", + "explainer = shapiq.Explainer(model, data=x_test, index=\"SV\", max_order=1, imputer=\"baseline\")\n", + "explainer._imputer.verbose = True # see the explanation progress\n", "\n", - "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", - "print(\"Prediction: \", prediction)\n", - "print(\"True value: \", y_explain)\n", - "print(\"Average prediction: \", average_prediction)" + "sv = explainer.explain(x_explain)\n", + "sv.plot_force(feature_names=feature_names)" ], - "id": "19b2cf3dd8a8d751", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Prediction: 3.4585295 True value: 3.55 Average prediction: 2.0711837\n" - ] - } + "id": 
"41314e231db2e986", + "outputs": [], + "execution_count": 34 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Explaining TabPFN with Remove-and-\"Retrain\"\n", + "\n", + "Since TabPFN is a foundation model, it uses in-context learning to solve the classification and regression tasks.\n", + "This means that \"retraining\" the model is quite inexpensive, because we only need to provide the new data points with whatever features we want to remove.\n", + "A nice paper by [Rundel et al.](https://arxiv.org/pdf/2403.10923) shows that this strategy is very effective for explaining models like TabPFN." ], - "execution_count": 35 + "id": "cdba7867ce6fbbb0" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:38:27.538825Z", - "start_time": "2025-01-09T13:38:27.523744Z" + "end_time": "2025-01-09T14:42:33.870941Z", + "start_time": "2025-01-09T14:42:33.859953Z" } }, "cell_type": "code", @@ -874,13 +902,15 @@ ], "id": "37a977c5f4a88aee", "outputs": [], - "execution_count": 36 + "execution_count": 13 }, { "metadata": { + "jupyter": { + "is_executing": true + }, "ExecuteTime": { - "end_time": "2025-01-09T13:54:21.973587Z", - "start_time": "2025-01-09T13:38:59.204920Z" + "start_time": "2025-01-09T14:42:43.034242Z" } }, "cell_type": "code", @@ -897,7 +927,7 @@ "output_type": "stream", "text": [ "Initializing TabPFN Game\n", - "Train data shape: (300, 8) (300,)\n", + "Train data shape: (500, 8) (500,)\n", "Explain data shape: (8,)\n" ] }, @@ -909,36 +939,74 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "a456a025086c4209b11f144a6badd7b0" + "model_id": "56abb4bfbed1432d9b1fec1ed8229322" } }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 38 + "execution_count": null + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T14:15:53.944934Z", + "start_time": "2025-01-09T14:15:53.934923Z" + } + }, + "cell_type": "code", + "source": 
"game.load_values(\"tabpfn_values.npz\")", + "id": "a96e3795ea1df8a0", + "outputs": [], + "execution_count": 55 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:54:33.282191Z", - "start_time": "2025-01-09T13:54:33.246177Z" + "end_time": "2025-01-09T14:15:55.311868Z", + "start_time": "2025-01-09T14:15:55.286621Z" } }, "cell_type": "code", "source": [ "approximator = shapiq.KernelSHAP(n=game.n_players, random_state=42)\n", - "sv = approximator.approximate(budget=2**game.n_players, game=game)" + "sv = approximator.approximate(budget=2**game.n_players, game=game)\n", + "sv.baseline_value = average_prediction" ], "id": "7203ae35139cc10a", "outputs": [], - "execution_count": 40 + "execution_count": 56 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T14:15:56.587429Z", + "start_time": "2025-01-09T14:15:56.575416Z" + } + }, + "cell_type": "code", + "source": "sv.baseline_value", + "id": "5258964a22c66031", + "outputs": [ + { + "data": { + "text/plain": [ + "2.0711837" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 57 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T13:54:35.424675Z", - "start_time": "2025-01-09T13:54:35.150372Z" + "end_time": "2025-01-09T14:15:58.680889Z", + "start_time": "2025-01-09T14:15:58.169856Z" } }, "cell_type": "code", @@ -953,7 +1021,7 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" @@ -972,12 +1040,12 @@ " (7,): 0.1340076537387331}" ] }, - "execution_count": 41, + "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 41 + "execution_count": 58 } ], "metadata": { From 9eea826e5981796d73e5c4e425cd58bdcf94c9a0 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 9 Jan 2025 16:15:48 +0100 Subject: [PATCH 05/16] updated notebook --- .../tabular_notebooks/explaining_tabpfn.ipynb | 79 ++++++++++++++++--- 1 file changed, 70 insertions(+), 9 deletions(-) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index 6a08a095..a3f2d820 100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -857,7 +857,11 @@ "\n", "Since TabPFN is a foundation model, it uses in-context learning to solve the classification and regression tasks.\n", "This means that \"retraining\" the model is quite inexpensive, because we only need to provide the new data points with whatever features we want to remove.\n", - "A nice paper by [Rundel et al.](https://arxiv.org/pdf/2403.10923) shows that this strategy is very effective for explaining models like TabPFN." + "A nice paper by [Rundel et al.](https://arxiv.org/pdf/2403.10923) shows that this strategy is very effective for explaining models like TabPFN.\n", + "\n", + "Because of ``shapiq``'s notion of cooperative games, we can easily implement the remove-and-\"retrain\" strategy for TabPFN as game.\n", + "The game takes the model, the training data, the explanation data, and the average prediction as input.\n", + "The value function of the game performs the remove-and-\"retrain\" strategy for TabPFN and returns the predictions for the coalitions." 
], "id": "cdba7867ce6fbbb0" }, @@ -871,7 +875,15 @@ "cell_type": "code", "source": [ "class TabPFNGame(shapiq.Game):\n", - " \"\"\"The TabPFN Game class implementation a remove-and-\"retrain\" strategy to explain the predictions of TabPFN.\"\"\"\n", + " \"\"\"The TabPFN Game class implementation a remove-and-\"retrain\" strategy to explain the predictions of TabPFN.\n", + "\n", + " Args:\n", + " model: The TabPFN model.\n", + " x_train: The training data.\n", + " y_train: The training labels.\n", + " x_explain: The data point to explain.\n", + " average_prediction: The average prediction of the model.\n", + " \"\"\"\n", "\n", " def __init__(self, model, x_train, y_train, x_explain, average_prediction):\n", " self.model = model\n", @@ -904,6 +916,12 @@ "outputs": [], "execution_count": 13 }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "With this game implementation we can now use helper functions from ``shapiq.Game`` like ``precompute`` to precompute the values of the game to speed up the explanation process.", + "id": "c8b473a6c67a54a2" + }, { "metadata": { "jupyter": { @@ -915,10 +933,10 @@ }, "cell_type": "code", "source": [ - "game = TabPFNGame(model, x_train, y_train, x_explain, average_prediction)\n", - "game.verbose = True\n", - "game.precompute()\n", - "game.save_values(\"tabpfn_values.npz\")" + "tabpfn_game = TabPFNGame(model, x_train, y_train, x_explain, average_prediction)\n", + "tabpfn_game.verbose = True # see the pre-computation progress\n", + "tabpfn_game.precompute()\n", + "tabpfn_game.save_values(\"tabpfn_values.npz\") # save values for later" ], "id": "7b2606969b5bab0", "outputs": [ @@ -956,11 +974,54 @@ } }, "cell_type": "code", - "source": "game.load_values(\"tabpfn_values.npz\")", + "source": [ + "# re-load the game\n", + "tabpfn_game = shapiq.Game(path_to_values=\"tabpfn_values.npz\", normalize=False)" + ], "id": "a96e3795ea1df8a0", "outputs": [], "execution_count": 55 }, + { + "metadata": {}, + "cell_type": "markdown", + 
"source": "Now that we have evaluated all $2^d$ coalitions, we can use ``shapiq.ExactComputer`` to compute any kind of game-theoretic explanation.", + "id": "1d7f391c3f67721e" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "exact_computer = shapiq.ExactComputer(n_players=tabpfn_game.n_players, game_fun=tabpfn_game)\n", + "sv = exact_computer(index=\"SV\", order=1) # compute the Shapley values\n", + "fsii = exact_computer(index=\"FSII\", order=2) # compute Faithful Shapley Interaction values" + ], + "id": "1887b05e6bd7cda8" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "sv.plot_force(feature_names=feature_names)", + "id": "7bfdd3a9e1ff6b1d" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "fsii.plot_force(feature_names=feature_names)", + "id": "7df6eae3201659ab" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "We can also approximate the game using KernelSHAP:", + "id": "baf13f27f8b50652" + }, { "metadata": { "ExecuteTime": { @@ -970,8 +1031,8 @@ }, "cell_type": "code", "source": [ - "approximator = shapiq.KernelSHAP(n=game.n_players, random_state=42)\n", - "sv = approximator.approximate(budget=2**game.n_players, game=game)\n", + "approximator = shapiq.KernelSHAP(n=tabpfn_game.n_players, random_state=42)\n", + "sv = approximator.approximate(budget=100, game=tabpfn_game)\n", "sv.baseline_value = average_prediction" ], "id": "7203ae35139cc10a", From 0c639db490552d0ea854a4a2c4f07d9c56d88dab Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 9 Jan 2025 16:56:57 +0100 Subject: [PATCH 06/16] updated notebook --- .../tabular_notebooks/explaining_tabpfn.ipynb | 115 ++++++++++++------ 1 file changed, 75 insertions(+), 40 deletions(-) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb 
index a3f2d820..f69f02c5 100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -924,10 +924,8 @@ }, { "metadata": { - "jupyter": { - "is_executing": true - }, "ExecuteTime": { + "end_time": "2025-01-09T15:43:32.323302Z", "start_time": "2025-01-09T14:42:43.034242Z" } }, @@ -964,13 +962,13 @@ "output_type": "display_data" } ], - "execution_count": null + "execution_count": 14 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:15:53.944934Z", - "start_time": "2025-01-09T14:15:53.934923Z" + "end_time": "2025-01-09T15:49:05.221527Z", + "start_time": "2025-01-09T15:49:05.210527Z" } }, "cell_type": "code", @@ -980,7 +978,7 @@ ], "id": "a96e3795ea1df8a0", "outputs": [], - "execution_count": 55 + "execution_count": 15 }, { "metadata": {}, @@ -989,32 +987,69 @@ "id": "1d7f391c3f67721e" }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T15:49:07.000073Z", + "start_time": "2025-01-09T15:49:06.965071Z" + } + }, "cell_type": "code", - "outputs": [], - "execution_count": null, "source": [ "exact_computer = shapiq.ExactComputer(n_players=tabpfn_game.n_players, game_fun=tabpfn_game)\n", "sv = exact_computer(index=\"SV\", order=1) # compute the Shapley values\n", "fsii = exact_computer(index=\"FSII\", order=2) # compute Faithful Shapley Interaction values" ], - "id": "1887b05e6bd7cda8" + "id": "1887b05e6bd7cda8", + "outputs": [], + "execution_count": 16 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T15:49:10.264683Z", + "start_time": "2025-01-09T15:49:09.377875Z" + } + }, "cell_type": "code", - "outputs": [], - "execution_count": null, "source": "sv.plot_force(feature_names=feature_names)", - "id": "7bfdd3a9e1ff6b1d" + "id": "7bfdd3a9e1ff6b1d", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 17 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T15:49:10.671694Z", + "start_time": "2025-01-09T15:49:10.265674Z" + } + }, "cell_type": "code", - "outputs": [], - "execution_count": null, "source": "fsii.plot_force(feature_names=feature_names)", - "id": "7df6eae3201659ab" + "id": "7df6eae3201659ab", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 18 }, { "metadata": {}, @@ -1025,8 +1060,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:15:55.311868Z", - "start_time": "2025-01-09T14:15:55.286621Z" + "end_time": "2025-01-09T15:49:12.460836Z", + "start_time": "2025-01-09T15:49:12.441828Z" } }, "cell_type": "code", @@ -1037,13 +1072,13 @@ ], "id": "7203ae35139cc10a", "outputs": [], - "execution_count": 56 + "execution_count": 19 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:15:56.587429Z", - "start_time": "2025-01-09T14:15:56.575416Z" + "end_time": "2025-01-09T15:49:14.514229Z", + "start_time": "2025-01-09T15:49:14.499231Z" } }, "cell_type": "code", @@ -1053,21 +1088,21 @@ { "data": { "text/plain": [ - "2.0711837" + "2.0460324" ] }, - "execution_count": 57, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 57 + "execution_count": 20 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:15:58.680889Z", - "start_time": "2025-01-09T14:15:58.169856Z" + "end_time": "2025-01-09T15:49:17.275190Z", + "start_time": "2025-01-09T15:49:16.992815Z" } }, "cell_type": "code", @@ -1082,7 +1117,7 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" @@ -1091,22 +1126,22 @@ "data": { "text/plain": [ "{(): 0.0,\n", - " (0,): -0.19919682040048847,\n", - " (1,): 0.1938878673232311,\n", - " (2,): 0.03612450051107798,\n", - " (3,): 0.012084188219684957,\n", - " (4,): 0.009044129786639231,\n", - " (5,): 1.1875691359177223,\n", - " (6,): 0.013825137695848005,\n", - " (7,): 0.1340076537387331}" + " (0,): -0.11543297276783156,\n", + " (1,): -0.01829630308819613,\n", + " (2,): -0.02245987510734033,\n", + " (3,): -0.019211207803047345,\n", + " (4,): -0.16020457476047298,\n", + " (5,): 0.6133794761867309,\n", + " (6,): -0.06019559477215868,\n", + " (7,): 0.5106405646027524}" ] }, - "execution_count": 58, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 58 + "execution_count": 22 } ], "metadata": { From 67f461a7ca418f60de613adb642c5970bc5ed1d4 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 9 Jan 2025 16:59:30 +0100 Subject: [PATCH 07/16] added data --- .../tabular_notebooks/tabpfn_values.npz | Bin 0 -> 2202 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/source/notebooks/tabular_notebooks/tabpfn_values.npz diff --git a/docs/source/notebooks/tabular_notebooks/tabpfn_values.npz b/docs/source/notebooks/tabular_notebooks/tabpfn_values.npz new file mode 100644 index 0000000000000000000000000000000000000000..a14d6354cfdaa6c3cba195bb6a79f0b44b9133f9 GIT binary patch literal 2202 zcma)7e>_urA0HdChNWF2CE}_ivip)BiIS#Nv+F7Dqqem)@@w-WQZ^R}3B!0sevIfT zs_o{I3|nsUBR}THl!x+bE~T{hZhLoYy(${XXY&KA-pbywN9s zAQ%7u*q~TO0K<&mqw+uiK&;{v41fV5@j(%Ua0_C{RV4s)?WzI>SVgQx;TgI6nTV|` zSY~~nMLH;pf_J1RrYYOMYv`V+-iUws&?jH_LU2NE&#Cm-lwDi= zA07K(o<=B%l`6YMl)=k$(PwTv#5-3-8CiPnoTtnvTJR z2qSKZ8i(_}k|*DWcesES+S`|AAtI&Up&GnRlZ`*lsp{li8(qdiWv#rHQr2aG^{d%> zsg0N@G%~Ow1T=e9dX`XE@UW0pwcb#ss(B*ow1Fztr?u-T!N21`jyt(iC^$!rY1$^m zlopBREn7bn61tBEzuz&#?n|b-I>((@`X@E?3aT-s9z8zLkZk*kt1_aS`)SGnRft5q 
zcdzzfjA7pyYio-z>eyJ}%LnSMUVBGw|5(G0Uk1uAwP;y!yML2fr^hE%Sp_P`ywTv( z3gepWGLyphhHaWE&^|KkD!EhCgHfmOX*du1jlx#?R=U;pP>%)V^01Td{7;Q`)>3U! zluVQ-W5(Vg$wF>M^o&xsCo0W@|rtZozv#F zT1h)2(5iMLb@bZTjlXv+nI}@rQN!)ljA91IPLDG;X%Zo>$@uw|iYRBK`=^oK-=vVH z$JpCBv-N|ToB;O?OJ|wIOpmI#MK_a~Q-8O6sBK&4?zoL#GMwYUGqZ)qROjBB>-QYL z{o$7VOR@QI%@aEL{8A?$AGlT0ej2zutv+*oq1|jkFK2r=Pm6=%sL{ssKkXlvSRBP% z=wKTuc<=NUsrHDLb+W_sflwz%J8WseW%dn4 zMO?KYY_r^*N1-#Xr^}i)&2%0ylYWdTu&w;AS-y`tUr|)TeXfq1b&BU)3x`U`@9)*O ze5OY+dIv=qE{A5n;Mh@{4G}dQt1uGI#LNUf+&<1v4HZ8?+^^ZWxtnL5lAEP1`_YPfInxvpi3FLg1Vw-+kosMTSXJF=wA z%jrE_op>2iRM%ly+tN@Gt+M#y(IIjDC4r+0A)L{&mF%SB3!QqREVhki$S1WPcU5zr z;V_6snRJCbWdJJXlm$&XnjpgrTJ}IKB~1K?%L*%*zw@7_#NDfQuFmdR>c|{rn$lt~&_7bWvi6sgeP5qGYnhNLJI4S~t){VF&n; zeN{g2?}in%?07Tr1%O^R)Y@EU_xLIOXQ*i^hMErvALM`8pG5q>!M2K8jk_KPP)+-w z^MQ6p*|FKGCeU%fWy39CGuVJ3xakfk_}37yMF%pfNRS87yyex?EXb3DbC@I+aR#~{ zf0o4_d(5(|kO{_~5-Vi>Z@H3BALMLYgpa(EqQ}NiqKG5xw=1HFxrgOl)Z5d)4u)?_ zcq-z-51D0bK9=NQUKvQ13P#7{E|dI^1v=0) z?x~zZBL;6c-%`flk?5))4I|YxG6y078c+cmxLXHnXsDCmZP;f80>@2YMl@lGiQ7S2LKNFoYW$BjN_TXwTOvyr1RaNyHt#o6=$$umPUrmA(7OYUq zHTha3LIbbSAEI;HL>H5KV7DZ>;m{&PoO`|h<}?3z^n=g;5r){FjYUcbV< y*Z%_Fa@Tj_Yc5$8u)Z(fQrFk-YpRl(g8JvxM4wPmJQn}}DYiMqF+%|b0R9b7Rk Date: Fri, 10 Jan 2025 11:20:32 +0100 Subject: [PATCH 08/16] removes warning when class_index is not provided --- CHANGELOG.md | 1 + shapiq/explainer/utils.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ab32c69..1f1bdf3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## Changelog ### Development +- removes warning when class_index is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298) - adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format - makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281) - adds the `upset_plot` function to the 
`plot` module to visualize the interactions of higher-order [#290](https://github.com/mmschlk/shapiq/issues/290) diff --git a/shapiq/explainer/utils.py b/shapiq/explainer/utils.py index 65b1ea07..576a4518 100644 --- a/shapiq/explainer/utils.py +++ b/shapiq/explainer/utils.py @@ -1,7 +1,6 @@ """This module contains utility functions for the explainer module.""" import re -import warnings from typing import Any, Callable, Optional, TypeVar import numpy as np @@ -118,7 +117,6 @@ def get_predict_function_and_model_type( ) if class_index is None: - warnings.warn(WARNING_NO_CLASS_INDEX) class_index = 1 def _predict_function_with_class_index(model: ModelType, data: np.ndarray) -> np.ndarray: From 0cc6cd8b11c31524fa2ca615656c389e1c440931 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 11:23:10 +0100 Subject: [PATCH 09/16] adds TODO for xgboost tests --- shapiq/explainer/tree/validation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index 79e7f170..db6455b4 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -29,6 +29,7 @@ "lightgbm.sklearn.LGBMRegressor", "lightgbm.sklearn.LGBMClassifier", "lightgbm.basic.Booster", + # TODO: add xgboost to the list of supported models and check if all tests pass # xboost? 
} From f2b7ed079c37d450c063c423b4b45d9397cdfc3d Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 11:24:25 +0100 Subject: [PATCH 10/16] allows games to be initialized from values and be not normalized --- shapiq/game_theory/exact.py | 4 +--- shapiq/games/base.py | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/shapiq/game_theory/exact.py b/shapiq/game_theory/exact.py index 999cb0a8..d376c2ab 100644 --- a/shapiq/game_theory/exact.py +++ b/shapiq/game_theory/exact.py @@ -125,6 +125,7 @@ def __call__(self, index: str, order: int = None) -> InteractionValues: elif index in self.available_indices: computation_function = self._index_mapping[index] computed_index: InteractionValues = computation_function(index=index, order=order) + computed_index.baseline_value = self.baseline_value self._computed[(index, order)] = computed_index return copy.deepcopy(computed_index) else: @@ -158,9 +159,6 @@ def _evaluate_game(self): def compute_game_values(self) -> tuple[float, np.ndarray[float], dict[tuple[int], int]]: """Evaluates the game on the powerset of all coalitions. 
- Args: - game_fun: A callable game - Returns: baseline value (empty prediction), all game values, and the lookup dictionary """ diff --git a/shapiq/games/base.py b/shapiq/games/base.py index d274b64c..f5bda0cf 100644 --- a/shapiq/games/base.py +++ b/shapiq/games/base.py @@ -131,6 +131,9 @@ def __init__( if path_to_values is not None: self.load_values(path_to_values, precomputed=True) self.game_id = path_to_values.split(os.path.sep)[-1].split(".")[0] + # if game should not be normalized, reset normalization value to 0 + if not normalize and self.normalization_value != 0: + self.normalization_value = 0.0 # define some handy coalition variables self.empty_coalition = np.zeros(self.n_players, dtype=bool) From 13adbc62efa4bba5b882775e63aa7b5611a0de35 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 12:59:39 +0100 Subject: [PATCH 11/16] updated tabpfn notebook --- CHANGELOG.md | 1 + .../tabular_notebooks/explaining_tabpfn.ipynb | 464 ++++++++---------- ...bpfn_values.npz => tabpfn_values_copy.npz} | Bin shapiq/__init__.py | 2 +- 4 files changed, 218 insertions(+), 249 deletions(-) rename docs/source/notebooks/tabular_notebooks/{tabpfn_values.npz => tabpfn_values_copy.npz} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f1bdf3e..820f4f62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## Changelog ### Development +- adds a TabPFN example notebook to the documentation - removes warning when class_index is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298) - adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format - makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281) diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index f69f02c5..a43907c7 100644 --- 
a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -28,8 +28,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-01-09T14:23:43.643915Z", - "start_time": "2025-01-09T14:23:39.475714Z" + "end_time": "2025-01-10T11:47:28.951329Z", + "start_time": "2025-01-10T11:47:24.953799Z" } }, "source": [ @@ -50,12 +50,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "shapiq version: 1.1.1 tabpfn version: 2.0.0\n", + "shapiq version: 1.1.1.dev\n", + "tabpfn version: 2.0.1\n", "Device: cpu\n" ] } ], - "execution_count": 3 + "execution_count": 1 }, { "metadata": {}, @@ -69,8 +70,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:23:43.707421Z", - "start_time": "2025-01-09T14:23:43.645913Z" + "end_time": "2025-01-10T11:47:29.014925Z", + "start_time": "2025-01-10T11:47:28.953368Z" } }, "cell_type": "code", @@ -220,7 +221,7 @@ "output_type": "display_data" } ], - "execution_count": 4 + "execution_count": 2 }, { "metadata": {}, @@ -232,15 +233,15 @@ "\n", "In order to use TabPFN, we need to split the data into a training and testing set.\n", "Note, that TabPFN works best for **small sized datasets** (less than 10k samples).\n", - "So let's select a train set of 10k samples." + "On CPU, we can only use a very small number of training data points to fit the model. If you have a GPU, feel free to increase the number of samples." ], "id": "2d3e6649c1ae8450" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:23:43.723421Z", - "start_time": "2025-01-09T14:23:43.709411Z" + "end_time": "2025-01-10T11:47:29.030865Z", + "start_time": "2025-01-10T11:47:29.016917Z" } }, "cell_type": "code", @@ -265,27 +266,30 @@ ] } ], - "execution_count": 5 + "execution_count": 3 }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Fit TabPFN\n", - "Now that we have the data, we can fit TabPFN to the training data and make it ready for predictions." 
+ "Now that we have the data, we can fit TabPFN to the training data and make it ready for predictions. \n", + "\n", + "**Note** that TabPFN at the end of the day is still quite a big transformer model, which needs a GPU to run efficiently.\n", + "If you are on GPU, feel free to increase the number of samples in the training or testing sets in the following:" ], "id": "8be176b5890b9eaf" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:23:46.264555Z", - "start_time": "2025-01-09T14:23:46.027950Z" + "end_time": "2025-01-10T11:54:31.725938Z", + "start_time": "2025-01-10T11:54:31.533466Z" } }, "cell_type": "code", "source": [ - "model = tabpfn.TabPFNRegressor(n_jobs=8, device=device)\n", + "model = tabpfn.TabPFNRegressor(n_jobs=7, device=device)\n", "model.fit(x_train, y_train)" ], "id": "a1100c73d7b0867e", @@ -293,10 +297,10 @@ { "data": { "text/plain": [ - "TabPFNRegressor(device=device(type='cpu'), n_jobs=8)" + "TabPFNRegressor(device=device(type='cpu'), n_jobs=7)" ], "text/html": [ - "
TabPFNRegressor(device=device(type='cpu'), n_jobs=8)
" + "
TabPFNRegressor(device=device(type='cpu'), n_jobs=7)
" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 6 + "execution_count": 10 }, { "metadata": {}, "cell_type": "markdown", "source": [ "When we have the \"trained\" model, we can use it to predict the house prices.\n", - "These predictions are very competitive with the state-of-the-art models.\n", - "Note that TabPFN at the end of the day is still quite a big transformer model, which needs a GPU to run efficiently." + "These predictions are very competitive with the state-of-the-art models." ], "id": "25603c1d4540f2c5" }, @@ -729,8 +732,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:26:06.514631Z", - "start_time": "2025-01-09T14:23:49.733597Z" + "end_time": "2025-01-10T11:49:48.449670Z", + "start_time": "2025-01-10T11:47:29.286685Z" } }, "cell_type": "code", @@ -741,13 +744,13 @@ ], "id": "d36110af9fa1b058", "outputs": [], - "execution_count": 7 + "execution_count": 5 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:26:20.308666Z", - "start_time": "2025-01-09T14:26:20.288667Z" + "end_time": "2025-01-10T11:49:48.464675Z", + "start_time": "2025-01-10T11:49:48.451681Z" } }, "cell_type": "code", @@ -768,12 +771,38 @@ "name": "stdout", "output_type": "stream", "text": [ - "MSE: 0.313356474514811 R2: 0.7624955351447615\n", - "Average prediction: 2.0460324\n" + "MSE: 0.27140348437031175 R2: 0.7964621203301282\n", + "Average prediction: 2.0861094\n" + ] + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T11:54:38.179564Z", + "start_time": "2025-01-10T11:54:38.008602Z" + } + }, + "cell_type": "code", + "source": [ + "# we will reset the model to less training data because we are on CPU\n", + "if device == torch.device(\"cpu\"):\n", + " print(\"Resetting the model to less training data\")\n", + " model.fit(x_train[:200], y_train[:200])" + ], + "id": "7f6253cf223e9136", + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Resetting the model to less training data\n" ] } ], - "execution_count": 8 + "execution_count": 11 }, { "metadata": {}, @@ -788,14 +817,14 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-09T14:42:00.521476Z", - "start_time": "2025-01-09T14:41:38.073888Z" + "end_time": "2025-01-10T11:54:46.548358Z", + "start_time": "2025-01-10T11:54:40.883018Z" } }, "cell_type": "code", "source": [ - "x_explain = x_test[10]\n", - "y_explain = y_test[10]\n", + "x_explain = x_data.values[0]\n", + "y_explain = y_data.values[0]\n", "\n", "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", "print(\"Prediction: \", prediction)\n", @@ -808,9 +837,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Prediction: 2.774252\n", - "True value: 2.938\n", - "Average prediction: 2.0460324\n" + "Prediction: 4.2348824\n", + "True value: 4.526\n", + "Average prediction: 2.0861094\n" ] } ], @@ -826,28 +855,45 @@ "Removed features are imputed using different strategies, such as the baseline imputation.\n", "Baseline imputation replaces the removed features with the mean/mode of the training data.\n", "\n", - "We can natively use the shapiq.Explainer (specifically shapiq.TabularExplainer) to explain the TabPFN model:" + "We can natively use the ``shapiq.Explainer`` (specifically ``shapiq.TabularExplainer``) to explain the TabPFN model:" ], "id": "b225c897c1181eee" }, { "metadata": { + "jupyter": { + "is_executing": true + }, "ExecuteTime": { - "end_time": "2025-01-09T13:38:22.029136Z", - "start_time": "2025-01-09T13:38:22.015135Z" + "start_time": "2025-01-10T11:54:51.476773Z" } }, "cell_type": "code", "source": [ - "explainer = shapiq.Explainer(model, data=x_test, index=\"SV\", max_order=1, imputer=\"baseline\")\n", + "explainer = shapiq.Explainer(model, data=x_test[:50], index=\"SV\", max_order=1, imputer=\"baseline\")\n", "explainer._imputer.verbose = True # see the explanation progress\n", "\n", - "sv = 
explainer.explain(x_explain)\n", - "sv.plot_force(feature_names=feature_names)" + "shapley_values = explainer.explain(x_explain)\n", + "shapley_values.plot_force(feature_names=feature_names)" ], "id": "41314e231db2e986", - "outputs": [], - "execution_count": 34 + "outputs": [ + { + "data": { + "text/plain": [ + "Evaluating game: 0%| | 0/256 [00:00" - ], - "image/png": "" - }, - "metadata": {}, - "output_type": "display_data" - } + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "From the Shapley values, we can see that longitude has a strong positive impact on the prediction and increases the predicted house price.\n", + "When we compute second-order Shapley interactions (``index=FSII``, ``order=2``), we can see that the interaction between latitude and longitude has a positive impact.\n", + "This suggests that the model learns an interaction between the latitude and longitude features.\n", + "\n", + "Interestingly, longitude also has a couple of negative interactions with other features, such as the median income, which decrease the predicted house price.\n" ], - "execution_count": 17 + "id": "fceea72f0e13feb1" }, { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-09T15:49:10.671694Z", - "start_time": "2025-01-09T15:49:10.265674Z" - } - }, + "metadata": {}, "cell_type": "code", "source": "fsii.plot_force(feature_names=feature_names)", "id": "7df6eae3201659ab", - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ], - "image/png": "" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 18 + "outputs": [], + "execution_count": null }, { "metadata": {}, "cell_type": "markdown", - "source": "We can also approximate the game using KernelSHAP:", - "id": "baf13f27f8b50652" + "source": [ + "### Explain TabPFN with Approximation Methods for Shapley Values and Interactions\n", + "When we have a large number of features, the exact computation of Shapley values and interactions can be computationally expensive.\n", + "For this reason, we can use approximation methods like KernelSHAP, KernelSHAP-IQ or Faithful Regression to approximate the Shapley values and interactions with a computational budget.\n", + "\n", + "To illustrate the approximation methods, we use the same TabPFN game (which has only 8 features) but reduce the computational budget to 50 model evaluations.\n", + "First, we approximate the Shapley values with KernelSHAP:" + ], + "id": "e9b187c6a678a8a8" }, { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-09T15:49:12.460836Z", - "start_time": "2025-01-09T15:49:12.441828Z" - } - }, + "metadata": {}, "cell_type": "code", "source": [ "approximator = shapiq.KernelSHAP(n=tabpfn_game.n_players, random_state=42)\n", - "sv = approximator.approximate(budget=100, game=tabpfn_game)\n", - "sv.baseline_value = average_prediction" + "sv = approximator.approximate(budget=50, game=tabpfn_game)\n", + "sv.plot_force(feature_names=feature_names)" ], "id": "7203ae35139cc10a", "outputs": [], - "execution_count": 19 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-09T15:49:14.514229Z", - "start_time": "2025-01-09T15:49:14.499231Z" - } - }, - "cell_type": "code", - "source": "sv.baseline_value", - "id": "5258964a22c66031", - "outputs": [ - { - "data": { - "text/plain": [ - "2.0460324" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 20 + "execution_count": null }, { - 
"metadata": { - "ExecuteTime": { - "end_time": "2025-01-09T15:49:17.275190Z", - "start_time": "2025-01-09T15:49:16.992815Z" - } - }, + "metadata": {}, "cell_type": "code", "source": [ - "sv.plot_force(feature_names=feature_names)\n", - "sv.dict_values" + "approximator = shapiq.RegressionFSII(n=tabpfn_game.n_players, random_state=42, max_order=2)\n", + "fsii = approximator.approximate(budget=50, game=tabpfn_game)\n", + "fsii.plot_force(feature_names=feature_names)" ], "id": "c0baa11868d4769e", - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ], - "image/png": "" - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{(): 0.0,\n", - " (0,): -0.11543297276783156,\n", - " (1,): -0.01829630308819613,\n", - " (2,): -0.02245987510734033,\n", - " (3,): -0.019211207803047345,\n", - " (4,): -0.16020457476047298,\n", - " (5,): 0.6133794761867309,\n", - " (6,): -0.06019559477215868,\n", - " (7,): 0.5106405646027524}" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 22 + "outputs": [], + "execution_count": null } ], "metadata": { diff --git a/docs/source/notebooks/tabular_notebooks/tabpfn_values.npz b/docs/source/notebooks/tabular_notebooks/tabpfn_values_copy.npz similarity index 100% rename from docs/source/notebooks/tabular_notebooks/tabpfn_values.npz rename to docs/source/notebooks/tabular_notebooks/tabpfn_values_copy.npz diff --git a/shapiq/__init__.py b/shapiq/__init__.py index 649cac94..a7087143 100644 --- a/shapiq/__init__.py +++ b/shapiq/__init__.py @@ -2,7 +2,7 @@ the well established Shapley value and its generalization to interaction. 
""" -__version__ = "1.1.1" +__version__ = "1.1.1.dev" # approximator classes from .approximator import ( From f0f249a82f8b3a559e50de65ede5b5030cc5e00a Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 13:09:48 +0100 Subject: [PATCH 12/16] ran tree notebooks --- .../treeshapiq_custom_tree.ipynb | 22 +++---- .../tree_notebooks/treeshapiq_lightgbm.ipynb | 58 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/docs/source/notebooks/tree_notebooks/treeshapiq_custom_tree.ipynb b/docs/source/notebooks/tree_notebooks/treeshapiq_custom_tree.ipynb index 594ec78c..bd6a7519 100644 --- a/docs/source/notebooks/tree_notebooks/treeshapiq_custom_tree.ipynb +++ b/docs/source/notebooks/tree_notebooks/treeshapiq_custom_tree.ipynb @@ -21,8 +21,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:14.147010Z", - "start_time": "2024-11-07T15:17:11.982766Z" + "end_time": "2025-01-10T12:07:42.287579Z", + "start_time": "2025-01-10T12:07:40.761213Z" } }, "source": [ @@ -36,7 +36,7 @@ { "data": { "text/plain": [ - "{'shapiq': '1.1.0'}" + "{'shapiq': '1.1.1.dev'}" ] }, "execution_count": 1, @@ -89,8 +89,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:14.162008Z", - "start_time": "2024-11-07T15:17:14.149012Z" + "end_time": "2025-01-10T12:07:42.302792Z", + "start_time": "2025-01-10T12:07:42.289571Z" } }, "cell_type": "code", @@ -175,8 +175,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:14.177007Z", - "start_time": "2024-11-07T15:17:14.166010Z" + "end_time": "2025-01-10T12:07:42.317798Z", + "start_time": "2025-01-10T12:07:42.304789Z" } }, "cell_type": "code", @@ -202,8 +202,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:14.192520Z", - "start_time": "2024-11-07T15:17:14.179012Z" + "end_time": "2025-01-10T12:07:42.333789Z", + "start_time": "2025-01-10T12:07:42.319796Z" } }, "cell_type": "code", @@ -235,8 +235,8 @@ { "metadata": { "ExecuteTime": { - 
"end_time": "2024-11-07T15:17:15.220588Z", - "start_time": "2024-11-07T15:17:14.194528Z" + "end_time": "2025-01-10T12:07:43.040086Z", + "start_time": "2025-01-10T12:07:42.336792Z" } }, "cell_type": "code", diff --git a/docs/source/notebooks/tree_notebooks/treeshapiq_lightgbm.ipynb b/docs/source/notebooks/tree_notebooks/treeshapiq_lightgbm.ipynb index 9728d507..1b9d434c 100644 --- a/docs/source/notebooks/tree_notebooks/treeshapiq_lightgbm.ipynb +++ b/docs/source/notebooks/tree_notebooks/treeshapiq_lightgbm.ipynb @@ -34,8 +34,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:33.551377Z", - "start_time": "2024-11-07T15:17:31.201891Z" + "end_time": "2025-01-10T12:08:12.112659Z", + "start_time": "2025-01-10T12:08:10.580868Z" } }, "source": [ @@ -51,7 +51,7 @@ { "data": { "text/plain": [ - "{'shapiq': '1.1.0', 'lightgbm': '4.5.0'}" + "{'shapiq': '1.1.1.dev', 'lightgbm': '4.5.0'}" ] }, "execution_count": 1, @@ -75,8 +75,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:33.660896Z", - "start_time": "2024-11-07T15:17:33.554378Z" + "end_time": "2025-01-10T12:08:12.191968Z", + "start_time": "2025-01-10T12:08:12.115647Z" } }, "source": [ @@ -114,8 +114,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:34.006158Z", - "start_time": "2024-11-07T15:17:33.662903Z" + "end_time": "2025-01-10T12:08:12.429777Z", + "start_time": "2025-01-10T12:08:12.192969Z" } }, "source": [ @@ -159,8 +159,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:37.346354Z", - "start_time": "2024-11-07T15:17:34.010162Z" + "end_time": "2025-01-10T12:08:14.785417Z", + "start_time": "2025-01-10T12:08:12.431779Z" } }, "source": [ @@ -180,8 +180,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:37.361347Z", - "start_time": "2024-11-07T15:17:37.349345Z" + "end_time": "2025-01-10T12:08:14.801430Z", + "start_time": 
"2025-01-10T12:08:14.789416Z" } }, "source": [ @@ -201,8 +201,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:40.017159Z", - "start_time": "2024-11-07T15:17:37.366486Z" + "end_time": "2025-01-10T12:08:16.270090Z", + "start_time": "2025-01-10T12:08:14.802414Z" } }, "source": [ @@ -247,8 +247,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:40.033165Z", - "start_time": "2024-11-07T15:17:40.019159Z" + "end_time": "2025-01-10T12:08:16.285632Z", + "start_time": "2025-01-10T12:08:16.272082Z" } }, "source": [ @@ -283,8 +283,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:40.049161Z", - "start_time": "2024-11-07T15:17:40.036162Z" + "end_time": "2025-01-10T12:08:16.301624Z", + "start_time": "2025-01-10T12:08:16.287627Z" } }, "source": [ @@ -333,8 +333,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:40.628294Z", - "start_time": "2024-11-07T15:17:40.051157Z" + "end_time": "2025-01-10T12:08:16.666231Z", + "start_time": "2025-01-10T12:08:16.303624Z" } }, "source": [ @@ -381,8 +381,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:40.835344Z", - "start_time": "2024-11-07T15:17:40.630292Z" + "end_time": "2025-01-10T12:08:16.825286Z", + "start_time": "2025-01-10T12:08:16.668230Z" } }, "source": [ @@ -409,8 +409,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:41.105318Z", - "start_time": "2024-11-07T15:17:40.838344Z" + "end_time": "2025-01-10T12:08:17.014318Z", + "start_time": "2025-01-10T12:08:16.828274Z" } }, "cell_type": "code", @@ -446,8 +446,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:17:43.154967Z", - "start_time": "2024-11-07T15:17:41.107319Z" + "end_time": "2025-01-10T12:08:18.289099Z", + "start_time": "2025-01-10T12:08:17.016318Z" } }, "source": [ @@ -482,8 +482,8 @@ "cell_type": "code", "metadata": { 
"ExecuteTime": { - "end_time": "2024-11-07T15:19:26.329690Z", - "start_time": "2024-11-07T15:17:43.156968Z" + "end_time": "2025-01-10T12:09:23.407436Z", + "start_time": "2025-01-10T12:08:18.292099Z" } }, "source": [ @@ -508,8 +508,8 @@ "cell_type": "code", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:19:26.817303Z", - "start_time": "2024-11-07T15:19:26.331688Z" + "end_time": "2025-01-10T12:09:23.782400Z", + "start_time": "2025-01-10T12:09:23.410942Z" } }, "source": "shapiq.plot.bar_plot(list_of_interaction_values, feature_names=X.columns, max_display=20)", From 525d13bbd8f475257c98fd732d5fcdcae62dc382 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 13:10:14 +0100 Subject: [PATCH 13/16] adds lgbm to tabular notebooks --- docs/source/notebooks/tabular.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/notebooks/tabular.rst b/docs/source/notebooks/tabular.rst index 362b7527..38ef02bc 100644 --- a/docs/source/notebooks/tabular.rst +++ b/docs/source/notebooks/tabular.rst @@ -8,3 +8,4 @@ The following notebooks provide basic examples of how to use the ``shapiq`` pack :maxdepth: 1 tabular_notebooks/* + tree_notebooks/treeshapiq_lightgbm.ipynb From cf98eeab447e160e2a4c4dc8db6abb7c5221dd09 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 14:13:42 +0100 Subject: [PATCH 14/16] renames game_fun to game in ExactComputer and closes #297 --- CHANGELOG.md | 1 + shapiq/game_theory/exact.py | 6 ++--- shapiq/games/base.py | 2 +- .../tests_game_theory/test_exact_computer.py | 24 +++++++++---------- tests/tests_games/test_treeshapiq_xai.py | 8 +++---- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 820f4f62..607d30c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## Changelog ### Development +- renames ``game_fun`` parameter in ``shapiq.ExactComputer`` to ``game`` [#297](https://github.com/mmschlk/shapiq/issues/297) - adds a TabPFN example notebook to the 
documentation - removes warning when class_index is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298) - adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format diff --git a/shapiq/game_theory/exact.py b/shapiq/game_theory/exact.py index d376c2ab..85dca61d 100644 --- a/shapiq/game_theory/exact.py +++ b/shapiq/game_theory/exact.py @@ -26,7 +26,7 @@ class ExactComputer: Args: n_players: The number of players in the game. - game_fun: A callable game that takes a binary matrix of shape ``(n_coalitions, n_players)`` + game: A callable game that takes a binary matrix of shape ``(n_coalitions, n_players)`` and returns a numpy array of shape ``(n_coalitions,)`` containing the game values. evaluate_game: whether to compute the values at init (if True) or first call (False) @@ -41,12 +41,12 @@ class ExactComputer: def __init__( self, n_players: int, - game_fun: Callable[[np.ndarray], np.ndarray[float]], + game: Callable[[np.ndarray], np.ndarray[float]], evaluate_game: bool = False, ) -> None: # set parameter attributes self.n: int = n_players - self.game_fun = game_fun + self.game_fun = game # set object attributes self._grand_coalition_tuple: tuple[int] = tuple(range(self.n)) diff --git a/shapiq/games/base.py b/shapiq/games/base.py index f5bda0cf..f0178243 100644 --- a/shapiq/games/base.py +++ b/shapiq/games/base.py @@ -490,7 +490,7 @@ def exact_values(self, index: str, order: int) -> InteractionValues: "Computing the exact interaction values via brute force may take a long time." 
) - exact_computer = ExactComputer(self.n_players, game_fun=self) + exact_computer = ExactComputer(self.n_players, game=self) return exact_computer(index=index, order=order) @property diff --git a/tests/tests_game_theory/test_exact_computer.py b/tests/tests_game_theory/test_exact_computer.py index ee79760b..04c04b30 100644 --- a/tests/tests_game_theory/test_exact_computer.py +++ b/tests/tests_game_theory/test_exact_computer.py @@ -19,7 +19,7 @@ def test_exact_computer_on_soum(): predicted_value = soum(np.ones(n))[0] # Compute via exactComputer - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) # Compute via sparse Möbius representation moebius_converter = MoebiusConverter(soum.moebius_coefficients) @@ -89,7 +89,7 @@ def test_exact_elc_computer_call(index, order): """Tests the call function for the ExactComputer.""" n = 5 soum = SOUM(n, n_basis_games=10, normalize=True) - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) interaction_values = exact_computer(index=index, order=order) if order is None: order = n @@ -130,7 +130,7 @@ def test_exact_computer_call(index, order): """Tests the call function for the ExactComputer.""" n = 5 soum = SOUM(n, n_basis_games=10) - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) interaction_values = exact_computer(index=index, order=order) if order is None: order = n @@ -145,7 +145,7 @@ def test_basic_functions(): """Tests the basic functions of the ExactComputer.""" n = 5 soum = SOUM(n, n_basis_games=10) - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) isinstance(repr(exact_computer), str) isinstance(str(exact_computer), str) @@ -154,7 +154,7 @@ def test_lazy_computation(): """Tests if the lazy computation (calling without params) works.""" n = 5 soum = SOUM(n, 
n_basis_games=10) - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) isinstance(repr(exact_computer), str) isinstance(str(exact_computer), str) sv = exact_computer("SV", 1) @@ -227,10 +227,10 @@ def test_permutation_symmetry(index, order, original_game): def permutation_game(X: np.ndarray): return original_game(X[:, permutation]) - exact_computer = ExactComputer(n_players=n, game_fun=original_game) + exact_computer = ExactComputer(n_players=n, game=original_game) interaction_values = exact_computer(index=index, order=order) - perm_exact_computer = ExactComputer(n_players=n, game_fun=permutation_game) + perm_exact_computer = ExactComputer(n_players=n, game=permutation_game) perm_interaction_values = perm_exact_computer(index=index, order=order) # permutation does not matter @@ -243,7 +243,7 @@ def test_warning_cii(): """Checks weather a warning is raised for the CHII index and min_order = 0.""" n = 5 soum = SOUM(n, n_basis_games=10) - exact_computer = ExactComputer(n_players=n, game_fun=soum) + exact_computer = ExactComputer(n_players=n, game=soum) with pytest.warns(UserWarning): exact_computer("CHII", 0) @@ -304,7 +304,7 @@ def _interaction(arr: np.ndarray): # dtype bool interaction_addition = np.apply_along_axis(_interaction, axis=1, arr=X) return value + interaction_addition - exact_computer = ExactComputer(n_players=n, game_fun=_game_fun) + exact_computer = ExactComputer(n_players=n, game=_game_fun) interaction_values = exact_computer(index=index, order=order) # symmetry of players with same attribution @@ -365,7 +365,7 @@ def _interaction(arr: np.ndarray): # dtype bool interaction_addition = np.apply_along_axis(_interaction, axis=1, arr=X) return value + interaction_addition - exact_computer = ExactComputer(n_players=n, game_fun=_game_fun) + exact_computer = ExactComputer(n_players=n, game=_game_fun) interaction_values = exact_computer(index=index, order=order) # no attribution for coalitions 
which include the null players. @@ -407,7 +407,7 @@ def _game_fun(X: np.ndarray): fist_order_coefficients = [0, 0.2, -0.1, -0.9, 0] return np.sum(fist_order_coefficients * x_as_float, axis=1) - exact_computer = ExactComputer(n_players=n, game_fun=_game_fun) + exact_computer = ExactComputer(n_players=n, game=_game_fun) interaction_values = exact_computer(index=index, order=order) for coalition, value in interaction_values.dict_values.items(): @@ -458,7 +458,7 @@ def _interaction(arr: np.ndarray): # dtype bool interaction_addition = np.apply_along_axis(_interaction, axis=1, arr=X) return value + interaction_addition - exact_computer = ExactComputer(n_players=n, game_fun=_game_fun) + exact_computer = ExactComputer(n_players=n, game=_game_fun) interaction_values = exact_computer(index=index, order=order) # no attribution for coalitions consisting of the null players. diff --git a/tests/tests_games/test_treeshapiq_xai.py b/tests/tests_games/test_treeshapiq_xai.py index fd1986e0..f9da03cc 100644 --- a/tests/tests_games/test_treeshapiq_xai.py +++ b/tests/tests_games/test_treeshapiq_xai.py @@ -75,7 +75,7 @@ def test_random_forest_selection( assert estimates.index == index # test against the exact computation - exact = ExactComputer(n_players=n_players, game_fun=game) + exact = ExactComputer(n_players=n_players, game=game) exact_values = exact(index=index, order=max_order) for interaction in powerset(range(n_players), min_size=min_order, max_size=max_order): @@ -100,7 +100,7 @@ def test_adult(): assert game.game_name == "AdultCensus_TreeSHAPIQXAI_Game" # test against the exact computation - exact = ExactComputer(n_players=game.n_players, game_fun=game) + exact = ExactComputer(n_players=game.n_players, game=game) exact_values = exact(index=index, order=max_order) for interaction in powerset(range(game.n_players), min_size=min_order, max_size=max_order): @@ -126,7 +126,7 @@ def test_california(index_order): assert game.game_name == "CaliforniaHousing_TreeSHAPIQXAI_Game" # 
test against the exact computation - exact = ExactComputer(n_players=game.n_players, game_fun=game) + exact = ExactComputer(n_players=game.n_players, game=game) exact_values = exact(index=index, order=max_order) for interaction in powerset(range(game.n_players), min_size=min_order, max_size=max_order): @@ -150,7 +150,7 @@ def test_bike(): assert game.game_name == "BikeSharing_TreeSHAPIQXAI_Game" # test against the exact computation - exact = ExactComputer(n_players=game.n_players, game_fun=game) + exact = ExactComputer(n_players=game.n_players, game=game) exact_values = exact(index=index, order=max_order) for interaction in powerset(range(game.n_players), min_size=min_order, max_size=max_order): From 721812c2bcdfdaea9a218d8266b2b274f59d693b Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 14:19:23 +0100 Subject: [PATCH 15/16] ran and updated notebooks --- .../basics_notebooks/custom_games.ipynb | 105 ++++-- .../basics_notebooks/data_valuation.ipynb | 352 +++++++++--------- .../basics_notebooks/sv_calculation.ipynb | 64 ++-- .../game_theory_notebooks/core.ipynb | 6 +- .../language_model_game.ipynb | 4 +- .../shapiq_scikit_learn.ipynb | 275 +++++++------- .../vision_notebooks/vision_transformer.ipynb | 6 +- 7 files changed, 428 insertions(+), 384 deletions(-) diff --git a/docs/source/notebooks/basics_notebooks/custom_games.ipynb b/docs/source/notebooks/basics_notebooks/custom_games.ipynb index a0ce982b..14f66d49 100644 --- a/docs/source/notebooks/basics_notebooks/custom_games.ipynb +++ b/docs/source/notebooks/basics_notebooks/custom_games.ipynb @@ -12,14 +12,15 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.696179Z", - "start_time": "2024-12-17T14:23:18.268301Z" + "end_time": "2025-01-10T12:14:05.982266Z", + "start_time": "2025-01-10T12:14:04.426262Z" } }, "cell_type": "code", "source": [ "import shapiq\n", "import numpy as np\n", + "import os\n", "\n", "shapiq.__version__" ], @@ -27,7 +28,7 @@ { "data": { "text/plain": [ - 
"'1.1.1'" + "'1.1.1.dev'" ] }, "execution_count": 1, @@ -86,8 +87,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.711170Z", - "start_time": "2024-12-17T14:23:19.698170Z" + "end_time": "2025-01-10T12:14:05.997215Z", + "start_time": "2025-01-10T12:14:05.985240Z" } }, "cell_type": "code", @@ -147,8 +148,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.727173Z", - "start_time": "2024-12-17T14:23:19.713181Z" + "end_time": "2025-01-10T12:14:06.013212Z", + "start_time": "2025-01-10T12:14:06.000205Z" } }, "cell_type": "code", @@ -174,8 +175,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.742179Z", - "start_time": "2024-12-17T14:23:19.730173Z" + "end_time": "2025-01-10T12:14:06.029218Z", + "start_time": "2025-01-10T12:14:06.014204Z" } }, "cell_type": "code", @@ -207,8 +208,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.758170Z", - "start_time": "2024-12-17T14:23:19.745174Z" + "end_time": "2025-01-10T12:14:06.045214Z", + "start_time": "2025-01-10T12:14:06.033206Z" } }, "cell_type": "code", @@ -245,8 +246,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:19.789577Z", - "start_time": "2024-12-17T14:23:19.760172Z" + "end_time": "2025-01-10T12:14:06.061209Z", + "start_time": "2025-01-10T12:14:06.046217Z" } }, "cell_type": "code", @@ -280,7 +281,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "e6e0bc19180b4969bae2cbcabef70fdf" + "model_id": "218e4aac6918408d8a38f1c9646509fb" } }, "metadata": {}, @@ -308,19 +309,20 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-12-17T14:23:20.357939Z", - "start_time": "2024-12-17T14:23:19.792499Z" + "end_time": "2025-01-10T12:14:06.076763Z", + "start_time": "2025-01-10T12:14:06.063214Z" } }, "cell_type": "code", "source": [ "# save the precomputed values to a file\n", - "cooking_game.save_values(\"data/cooking_game_values.npz\")\n", + "save_path = 
os.path.join(\"..\", \"data\", \"cooking_game_values.npz\")\n", + "cooking_game.save_values(save_path)\n", "\n", "# load the precomputed values from the file\n", "empty_cooking_game = CookingGame()\n", "print(\"Values stored before loading: \", empty_cooking_game.value_storage)\n", - "empty_cooking_game.load_values(\"cooking_game_values.npz\")\n", + "empty_cooking_game.load_values(save_path)\n", "print(\"Values stored after loading: \", empty_cooking_game.value_storage)" ], "outputs": [ @@ -328,20 +330,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Values stored before loading: []\n" - ] - }, - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory: 'cooking_game_values.npz'", - "output_type": "error", - "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[7], line 7\u001B[0m\n\u001B[0;32m 5\u001B[0m empty_cooking_game \u001B[38;5;241m=\u001B[39m CookingGame()\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored before loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n\u001B[1;32m----> 7\u001B[0m \u001B[43mempty_cooking_game\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_values\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcooking_game_values.npz\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored after loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n", - "File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\shapiq\\games\\base.py:426\u001B[0m, in \u001B[0;36mGame.load_values\u001B[1;34m(self, path, 
precomputed)\u001B[0m\n\u001B[0;32m 423\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[0;32m 424\u001B[0m path \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m--> 426\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 427\u001B[0m n_players \u001B[38;5;241m=\u001B[39m data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mn_players\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n\u001B[0;32m 428\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m n_players \u001B[38;5;241m!=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players:\n", - "File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\venv\\lib\\site-packages\\numpy\\lib\\npyio.py:427\u001B[0m, in \u001B[0;36mload\u001B[1;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001B[0m\n\u001B[0;32m 425\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m 426\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 427\u001B[0m fid \u001B[38;5;241m=\u001B[39m stack\u001B[38;5;241m.\u001B[39menter_context(\u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mos_fspath\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mrb\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m)\n\u001B[0;32m 428\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 430\u001B[0m \u001B[38;5;66;03m# Code to distinguish from NumPy binary files and pickles.\u001B[39;00m\n", - "\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'cooking_game_values.npz'" + "Values stored before loading: []\n", + "Values stored after loading: [ 0. 4. 3. 2. 9. 8. 7. 15.]\n" ] } ], @@ -356,19 +346,42 @@ ] }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T12:14:06.092763Z", + "start_time": "2025-01-10T12:14:06.077767Z" + } + }, "cell_type": "code", "source": [ "# initialize a game object directly from precomputed values\n", - "game = shapiq.Game(path_to_values=\"data/cooking_game_values.npz\")\n", + "game = shapiq.Game(path_to_values=save_path)\n", "print(game)\n", "\n", "# query the value function of the game for the same coalitions as before\n", "coals = np.array([[0, 0, 0], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]])\n", "game(coals)" ], - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Game(3 players, normalize=False, normalization_value=0.0, precomputed=True)\n" + ] + }, + { + "data": { + "text/plain": [ + "array([ 0., 9., 8., 7., 15.])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 }, { "metadata": {}, @@ -379,7 +392,12 @@ ] }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T12:14:06.108755Z", + "start_time": "2025-01-10T12:14:06.095753Z" + } + }, "cell_type": "code", "source": [ "print(cooking_game.characteristic_function)\n", @@ -388,8 +406,17 @@ "except AttributeError as e:\n", " print(\"AttributeError:\", e)" ], - "outputs": [], - "execution_count": null + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{(): 0, (0,): 4, (1,): 3, (2,): 2, (0, 1): 9, (0, 2): 8, (1, 2): 7, (0, 1, 2): 15}\n", + "AttributeError: 'Game' object has no attribute 'characteristic_function'\n" + ] + } + ], + "execution_count": 9 } ], "metadata": { diff --git a/docs/source/notebooks/basics_notebooks/data_valuation.ipynb b/docs/source/notebooks/basics_notebooks/data_valuation.ipynb index 6f4fa23a..cc9b75c1 100644 --- a/docs/source/notebooks/basics_notebooks/data_valuation.ipynb +++ b/docs/source/notebooks/basics_notebooks/data_valuation.ipynb @@ -32,8 +32,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-11-07T15:12:05.353388Z", - "start_time": "2024-11-07T15:12:03.635224Z" + "end_time": "2025-01-10T12:10:49.879006Z", + "start_time": "2025-01-10T12:10:48.286843Z" } }, "outputs": [ @@ -41,7 +41,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Shapiq version: 1.1.0\n" + "Shapiq version: 1.1.1.dev\n" ] } ], @@ -62,24 +62,13 @@ }, { "cell_type": "code", - "execution_count": 2, "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2024-10-22T16:04:37.540707Z", - "start_time": "2024-10-22T16:04:37.438398Z" + "end_time": "2025-01-10T12:10:50.054578Z", + "start_time": "2025-01-10T12:10:49.882009Z" } }, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:37.516569\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "def plot_synthetic_data(ax, X_train, y_train, X_test, y_test, title):\n", " ax.set_title(title)\n", @@ -155,7 +144,20 @@ "fig, ax = 
plt.subplots()\n", "\n", "plot_synthetic_data(ax, X_train, y_train, X_test, y_test, \"Synthetic Classification Data\")" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "[SVG figure markup lost in extraction; surviving metadata: created 2025-01-10T13:10:50.012571, format image/svg+xml, Matplotlib v3.9.2, https://matplotlib.org/]" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 2 }, { "cell_type": "markdown", @@ -169,8 +171,6 @@ }, { "cell_type": "code", - "execution_count": 3, - "outputs": [], 
"source": [ "class SyntheticDataValuation(shapiq.Game):\n", " \"\"\"The synthetic data valuation tasked modeled as a cooperative game.\n", @@ -224,10 +224,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:37.551682Z", - "start_time": "2024-10-22T16:04:37.549258Z" + "end_time": "2025-01-10T12:10:50.070091Z", + "start_time": "2025-01-10T12:10:50.055577Z" } - } + }, + "outputs": [], + "execution_count": 3 }, { "cell_type": "markdown", @@ -241,17 +243,6 @@ }, { "cell_type": "code", - "execution_count": 4, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Full coalition value: 1.0\n", - "Empty coalition value: 0.0\n" - ] - } - ], "source": [ "from sklearn.svm import LinearSVC\n", "\n", @@ -274,10 +265,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:37.577065Z", - "start_time": "2024-10-22T16:04:37.554955Z" + "end_time": "2025-01-10T12:10:50.085090Z", + "start_time": "2025-01-10T12:10:50.072088Z" } - } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Full coalition value: 1.0\n", + "Empty coalition value: 0.0\n" + ] + } + ], + "execution_count": 4 }, { "cell_type": "markdown", @@ -292,17 +294,6 @@ }, { "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:37.637652\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "fig, ax = plt.subplots()\n", "\n", @@ -327,10 +318,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:37.657117Z", - 
"start_time": "2024-10-22T16:04:37.567430Z" + "end_time": "2025-01-10T12:10:50.241642Z", + "start_time": "2025-01-10T12:10:50.087089Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-10T13:10:50.186638\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 5 }, { "cell_type": "markdown", @@ -345,20 +349,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "outputs": [ - { - "data": { - "text/plain": 
"
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:38.986904\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "# 
Compute Shapley values with the ShapIQ approximator for the game function\n", - "exactComputer = shapiq.ExactComputer(n_players=n_players, game_fun=data_valuation_game)\n", + "exactComputer = shapiq.ExactComputer(n_players=n_players, game=data_valuation_game)\n", "sv_values = exactComputer(\"SV\")\n", "sv_values.plot_stacked_bar(\n", " title=\"Shapley Values for Synthetic (Training) Data\",\n", @@ -370,10 +363,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:39.005886Z", - "start_time": "2024-10-22T16:04:37.657523Z" + "end_time": "2025-01-10T12:10:53.578653Z", + "start_time": "2025-01-10T12:10:50.242637Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "[SVG figure markup lost in extraction; surviving metadata: created 2025-01-10T13:10:53.525128, format image/svg+xml, Matplotlib v3.9.2, https://matplotlib.org/]" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + 
"execution_count": 6 }, { "cell_type": "markdown", @@ -395,17 +401,6 @@ }, { "cell_type": "code", - "execution_count": 7, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:39.045802\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - 
"output_type": "display_data" - } - ], "source": [ "fig, ax = plt.subplots()\n", "\n", @@ -424,10 +419,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:39.066692Z", - "start_time": "2024-10-22T16:04:39.013827Z" + "end_time": "2025-01-10T12:10:53.829762Z", + "start_time": "2025-01-10T12:10:53.580649Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-10T13:10:53.779762\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": 
{}, + "output_type": "display_data" + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -443,17 +451,6 @@ }, { "cell_type": "code", - "execution_count": 8, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:39.183419\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "from matplotlib import patches\n", "\n", @@ -478,10 +475,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:39.216073Z", - "start_time": "2024-10-22T16:04:39.071637Z" + "end_time": "2025-01-10T12:10:54.146354Z", + "start_time": "2025-01-10T12:10:53.831753Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-10T13:10:54.046830\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 8 }, { "cell_type": "markdown", @@ -495,17 +505,6 @@ }, { "cell_type": "code", - "execution_count": 9, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:40.617300\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "data_valuation_game = SyntheticDataValuation(\n", " 
classifier=classifier,\n", @@ -517,7 +516,7 @@ ")\n", "\n", "# Compute Shapley values with the shapiq ExactComputer for the game function\n", - "exactComputer = shapiq.ExactComputer(n_players=n_players, game_fun=data_valuation_game)\n", + "exactComputer = shapiq.ExactComputer(n_players=n_players, game=data_valuation_game)\n", "sv_values = exactComputer(\"SV\")\n", "sv_values.plot_stacked_bar()\n", "plt.show()" @@ -525,10 +524,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:40.635663Z", - "start_time": "2024-10-22T16:04:39.216847Z" + "end_time": "2025-01-10T12:10:57.601216Z", + "start_time": "2025-01-10T12:10:54.148388Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "[SVG figure markup lost in extraction; surviving metadata: created 2025-01-10T13:10:57.549612, format image/svg+xml, Matplotlib v3.9.2, https://matplotlib.org/]" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 9 }, { "cell_type": "markdown", @@ -541,17 +553,6 @@ }, { "cell_type": "code", - "execution_count": 10, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"Accuracy on test data before removing corrupted samples: 0.5\n", - "Accuracy on test data after removing corrupted samples: 1.0\n" - ] - } - ], "source": [ "classifier.fit(corrupted_X_train, corruped_y_train)\n", "print(\"Accuracy on test data before removing corrupted samples: \", classifier.score(X_test, y_test))\n", @@ -564,10 +565,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:40.640022Z", - "start_time": "2024-10-22T16:04:40.636571Z" + "end_time": "2025-01-10T12:10:57.617287Z", + "start_time": "2025-01-10T12:10:57.604217Z" } - } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy on test data before removing corrupted samples: 0.5\n", + "Accuracy on test data after removing corrupted samples: 1.0\n" + ] + } + ], + "execution_count": 10 }, { "cell_type": "markdown", @@ -580,17 +592,6 @@ }, { "cell_type": "code", - "execution_count": 11, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:40.745038\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "def plot_decision_boundary(ax, classifier, X_train, y_train, X_test, y_test):\n", " classifier.fit(X_train, y_train)\n", @@ -621,10 +622,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:40.881221Z", - "start_time": "2024-10-22T16:04:40.641575Z" + "end_time": "2025-01-10T12:10:57.851328Z", + "start_time": "2025-01-10T12:10:57.621215Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-10T13:10:57.765286\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 11 }, { "cell_type": "markdown", @@ -640,16 +654,6 @@ }, { "cell_type": "code", - "execution_count": 12, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Players: 160\n" - ] - } - ], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.model_selection import train_test_split\n", @@ -665,15 +669,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:41.055176Z", - "start_time": "2024-10-22T16:04:40.879435Z" + "end_time": "2025-01-10T12:10:58.247552Z", + "start_time": "2025-01-10T12:10:57.854324Z" } - } + }, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Players: 160\n" + ] + } + ], + "execution_count": 12 }, { "cell_type": "code", - "execution_count": 13, - "outputs": [], "source": [ "data_valuation_game = SyntheticDataValuation(\n", " classifier=classifier,\n", @@ -687,10 +699,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:04:41.056814Z", - "start_time": "2024-10-22T16:04:41.055761Z" + "end_time": "2025-01-10T12:10:58.263501Z", + "start_time": "2025-01-10T12:10:58.249487Z" } - } + }, + "outputs": [], + "execution_count": 13 }, { "cell_type": "markdown", @@ -703,17 +717,6 @@ }, { "cell_type": "code", - "execution_count": 14, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:05:15.036777\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "budgets = [10, 100, 1000, 
5000]\n", "erg = {}\n", @@ -761,10 +764,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-10-22T16:05:15.069797Z", - "start_time": "2024-10-22T16:04:41.062860Z" + "end_time": "2025-01-10T12:12:43.099030Z", + "start_time": "2025-01-10T12:10:58.266031Z" } - } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-10T13:12:43.060031\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 14 }, { 
"cell_type": "markdown", diff --git a/docs/source/notebooks/basics_notebooks/sv_calculation.ipynb b/docs/source/notebooks/basics_notebooks/sv_calculation.ipynb index 1b692e39..8c011492 100644 --- a/docs/source/notebooks/basics_notebooks/sv_calculation.ipynb +++ b/docs/source/notebooks/basics_notebooks/sv_calculation.ipynb @@ -19,8 +19,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.096233Z", - "start_time": "2024-11-07T15:16:43.959504Z" + "end_time": "2025-01-10T12:14:20.955825Z", + "start_time": "2025-01-10T12:14:19.361009Z" } }, "cell_type": "code", @@ -33,7 +33,7 @@ { "data": { "text/plain": [ - "'1.1.0'" + "'1.1.1.dev'" ] }, "execution_count": 1, @@ -92,8 +92,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.126230Z", - "start_time": "2024-11-07T15:16:46.099232Z" + "end_time": "2025-01-10T12:14:20.971829Z", + "start_time": "2025-01-10T12:14:20.956833Z" } }, "cell_type": "code", @@ -152,8 +152,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.142231Z", - "start_time": "2024-11-07T15:16:46.128230Z" + "end_time": "2025-01-10T12:14:20.987376Z", + "start_time": "2025-01-10T12:14:20.972822Z" } }, "cell_type": "code", @@ -179,8 +179,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.158233Z", - "start_time": "2024-11-07T15:16:46.147238Z" + "end_time": "2025-01-10T12:14:21.003371Z", + "start_time": "2025-01-10T12:14:20.988363Z" } }, "cell_type": "code", @@ -212,8 +212,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.174244Z", - "start_time": "2024-11-07T15:16:46.159234Z" + "end_time": "2025-01-10T12:14:21.018371Z", + "start_time": "2025-01-10T12:14:21.005377Z" } }, "cell_type": "code", @@ -281,7 +281,7 @@ "from shapiq import ExactComputer\n", "\n", "# create an ExactComputer object for the cooking game\n", - "exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game)\n", + "exact_computer = 
ExactComputer(n_players=cooking_game.n_players, game=cooking_game)\n", "\n", "# compute the Shapley Values for the game\n", "sv_exact = exact_computer(index=\"SV\")\n", @@ -290,8 +290,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.189782Z", - "start_time": "2024-11-07T15:16:46.180241Z" + "end_time": "2025-01-10T12:14:21.033370Z", + "start_time": "2025-01-10T12:14:21.019366Z" } }, "outputs": [ @@ -328,8 +328,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.365289Z", - "start_time": "2024-11-07T15:16:46.191764Z" + "end_time": "2025-01-10T12:14:21.142986Z", + "start_time": "2025-01-10T12:14:21.034373Z" } }, "cell_type": "code", @@ -382,8 +382,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.381306Z", - "start_time": "2024-11-07T15:16:46.367293Z" + "end_time": "2025-01-10T12:14:21.157908Z", + "start_time": "2025-01-10T12:14:21.144932Z" } }, "cell_type": "code", @@ -442,8 +442,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:46.600919Z", - "start_time": "2024-11-07T15:16:46.384829Z" + "end_time": "2025-01-10T12:14:21.314210Z", + "start_time": "2025-01-10T12:14:21.158901Z" } }, "cell_type": "code", @@ -547,8 +547,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:51.516931Z", - "start_time": "2024-11-07T15:16:46.605924Z" + "end_time": "2025-01-10T12:14:25.120874Z", + "start_time": "2025-01-10T12:14:21.317206Z" } }, "cell_type": "code", @@ -598,8 +598,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:51.595448Z", - "start_time": "2024-11-07T15:16:51.518933Z" + "end_time": "2025-01-10T12:14:25.167909Z", + "start_time": "2025-01-10T12:14:25.122873Z" } }, "cell_type": "code", @@ -638,8 +638,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:52.080746Z", - "start_time": "2024-11-07T15:16:51.597451Z" + "end_time": "2025-01-10T12:14:25.563052Z", + "start_time": "2025-01-10T12:14:25.170918Z" } }, "cell_type": 
"code", @@ -700,8 +700,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:53.051974Z", - "start_time": "2024-11-07T15:16:52.082742Z" + "end_time": "2025-01-10T12:14:26.360311Z", + "start_time": "2025-01-10T12:14:25.565980Z" } }, "cell_type": "code", @@ -739,8 +739,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:18:25.286249Z", - "start_time": "2024-11-07T15:16:53.053975Z" + "end_time": "2025-01-10T12:15:58.308779Z", + "start_time": "2025-01-10T12:14:26.364241Z" } }, "cell_type": "code", @@ -782,8 +782,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:18:25.771060Z", - "start_time": "2024-11-07T15:18:25.288247Z" + "end_time": "2025-01-10T12:15:58.668948Z", + "start_time": "2025-01-10T12:15:58.311776Z" } }, "cell_type": "code", diff --git a/docs/source/notebooks/game_theory_notebooks/core.ipynb b/docs/source/notebooks/game_theory_notebooks/core.ipynb index 53b0b6c5..f2ab83d5 100644 --- a/docs/source/notebooks/game_theory_notebooks/core.ipynb +++ b/docs/source/notebooks/game_theory_notebooks/core.ipynb @@ -155,8 +155,8 @@ "outputs": [], "source": [ "import numpy as np\n", - "from shapiq.exact import ExactComputer\n", - "from shapiq.games.base import Game\n", + "from shapiq import ExactComputer\n", + "from shapiq import Game\n", "\n", "\n", "# Define the PaperGame as described above\n", @@ -185,7 +185,7 @@ "paper_game = PaperGame()\n", "\n", "# Initialize the ExactComputer with the PaperGame\n", - "exact_computer = ExactComputer(n_players=3, game_fun=paper_game)\n", + "exact_computer = ExactComputer(n_players=3, game=paper_game)\n", "# Compute the egalitarian least core abbreviated to \"ELC\"\n", "egalitarian_least_core = exact_computer(\"ELC\")" ], diff --git a/docs/source/notebooks/language_notebooks/language_model_game.ipynb b/docs/source/notebooks/language_notebooks/language_model_game.ipynb index 77b46c0e..9349fbcf 100644 --- a/docs/source/notebooks/language_notebooks/language_model_game.ipynb +++ 
b/docs/source/notebooks/language_notebooks/language_model_game.ipynb @@ -642,7 +642,7 @@ }, "source": [ "# Compute Shapley interactions with the ShapIQ approximator for the game function\n", - "approximator = shapiq.SHAPIQ(n=n_players, max_order=2, index=\"k-SII\")\n", + "approximator = shapiq.KernelSHAPIQ(n=n_players, max_order=2, index=\"k-SII\")\n", "sii_values = approximator.approximate(budget=2**n_players, game=game_fun)\n", "sii_values.dict_values" ], @@ -687,7 +687,7 @@ }, "source": [ "# Compute Shapley interactions with the ShapIQ approximator for the game object\n", - "approximator = shapiq.SHAPIQ(n=game_class.n_players, max_order=2, index=\"k-SII\")\n", + "approximator = shapiq.KernelSHAPIQ(n=game_class.n_players, max_order=2, index=\"k-SII\")\n", "sii_values = approximator.approximate(budget=2**game_class.n_players, game=game_class)\n", "sii_values.dict_values" ], diff --git a/docs/source/notebooks/tabular_notebooks/shapiq_scikit_learn.ipynb b/docs/source/notebooks/tabular_notebooks/shapiq_scikit_learn.ipynb index a81c5d9f..f994c5ee 100644 --- a/docs/source/notebooks/tabular_notebooks/shapiq_scikit_learn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/shapiq_scikit_learn.ipynb @@ -16,9 +16,7 @@ "cell_type": "markdown", "id": "080de90c", "metadata": {}, - "source": [ - "### import packages" - ] + "source": "### Import Packages" }, { "cell_type": "code", @@ -26,8 +24,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2024-11-07T15:15:48.519797Z", - "start_time": "2024-11-07T15:15:48.503808Z" + "end_time": "2025-01-10T13:18:13.686187Z", + "start_time": "2025-01-10T13:18:11.584918Z" } }, "source": [ @@ -43,22 +41,23 @@ { "data": { "text/plain": [ - "'1.1.0'" + "'1.1.1.dev'" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 2 + "execution_count": 1 }, { "cell_type": "markdown", "id": "9eb96897", "metadata": {}, "source": [ - "### load data" + "### Load Data\n", 
+ "Let's load the California housing dataset and split it into training and test sets." ] }, { @@ -66,8 +65,8 @@ "id": "7fca3f5a", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:15:48.676852Z", - "start_time": "2024-11-07T15:15:48.618317Z" + "end_time": "2025-01-10T13:18:13.733714Z", + "start_time": "2025-01-10T13:18:13.688188Z" } }, "source": [ @@ -78,7 +77,7 @@ "n_features = X_train.shape[1]" ], "outputs": [], - "execution_count": 3 + "execution_count": 2 }, { "cell_type": "markdown", @@ -87,7 +86,9 @@ "collapsed": false }, "source": [ - "### train a model" + "### Train a Model with Scikit-learn\n", + "Here we train a random forest regressor with 500 trees.\n", + "The model achieves a relatively high R2 score on the test set." ] }, { @@ -96,8 +97,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-11-07T15:16:05.998479Z", - "start_time": "2024-11-07T15:15:48.680848Z" + "end_time": "2025-01-10T13:18:28.078021Z", + "start_time": "2025-01-10T13:18:13.735711Z" } }, "source": [ @@ -118,7 +119,7 @@ ] } ], - "execution_count": 4 + "execution_count": 3 }, { "cell_type": "markdown", @@ -127,7 +128,7 @@ "collapsed": false }, "source": [ - "### model-agnostic explainer\n", + "### Model-Agnostic Explainer\n", "\n", "We use `shapiq.TabularExplainer` to explain any machine learning model for tabular data. 
\n", "\n", @@ -143,15 +144,15 @@ "id": "e6435098", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:07.310454Z", - "start_time": "2024-11-07T15:16:06.001479Z" + "end_time": "2025-01-10T13:18:29.026747Z", + "start_time": "2025-01-10T13:18:28.079935Z" } }, "source": [ "explainer_tabular = shapiq.TabularExplainer(model=model, data=X_train, index=\"SII\", max_order=2)" ], "outputs": [], - "execution_count": 5 + "execution_count": 4 }, { "cell_type": "markdown", @@ -166,15 +167,15 @@ "id": "9764e3c2", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:07.325455Z", - "start_time": "2024-11-07T15:16:07.311456Z" + "end_time": "2025-01-10T13:18:29.042755Z", + "start_time": "2025-01-10T13:18:29.028747Z" } }, "source": [ "x = X_test[24]" ], "outputs": [], - "execution_count": 6 + "execution_count": 5 }, { "cell_type": "markdown", @@ -190,8 +191,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-11-07T15:16:09.562414Z", - "start_time": "2024-11-07T15:16:07.328455Z" + "end_time": "2025-01-10T13:18:30.843907Z", + "start_time": "2025-01-10T13:18:29.044750Z" } }, "source": [ @@ -208,12 +209,12 @@ ")" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 7 + "execution_count": 6 }, { "cell_type": "markdown", @@ -228,8 +229,8 @@ "id": "79e54c1e", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:09.577928Z", - "start_time": "2024-11-07T15:16:09.564414Z" + "end_time": "2025-01-10T13:18:30.859898Z", + "start_time": "2025-01-10T13:18:30.845897Z" } }, "source": [ @@ -240,50 +241,50 @@ "data": { "text/plain": [ "{(): 0.0,\n", - " (0,): -0.01221294691582856,\n", - " (1,): -0.06805549701001842,\n", - " (2,): -0.04995963418603176,\n", - " (3,): 0.005856228106492294,\n", - " (4,): 0.006152613961076363,\n", - " (5,): -0.08883989239374751,\n", - " (6,): 0.1771379675795001,\n", - " (7,): -0.2776100355484444,\n", - " (0, 1): -0.030430304440182833,\n", - " (0, 
2): 0.0409733947743084,\n", - " (0, 3): -0.006735285040975396,\n", - " (0, 4): -0.00584265471942099,\n", - " (0, 5): -0.057042709711482786,\n", - " (0, 6): -0.060954483612182475,\n", - " (0, 7): 0.03558046110939789,\n", - " (1, 2): -0.006521517081325771,\n", - " (1, 3): -0.004154456983576514,\n", - " (1, 4): -0.00560700633546335,\n", - " (1, 5): 0.07471407479283865,\n", - " (1, 6): -0.0071986920204653365,\n", - " (1, 7): -0.005214393115368101,\n", - " (2, 3): -0.008588302199393822,\n", - " (2, 4): -0.0037641409599387054,\n", - " (2, 5): -0.0035235279682149586,\n", - " (2, 6): 0.0027151081649867473,\n", - " (2, 7): -0.012570764436453927,\n", - " (3, 4): -0.004291162361399799,\n", - " (3, 5): -0.003961461841604401,\n", - " (3, 6): -0.005450982713619352,\n", - " (3, 7): -0.005364070146454759,\n", - " (4, 5): -0.012215475119607945,\n", - " (4, 6): -0.004613863863220258,\n", - " (4, 7): -0.003418052765388207,\n", - " (5, 6): -0.01840858915487052,\n", - " (5, 7): -0.00030334625171240555,\n", - " (6, 7): -0.07016564318093256}" + " (0,): 0.039693784765076436,\n", + " (1,): -0.08787130505402384,\n", + " (2,): -0.030182556659407715,\n", + " (3,): 0.010314497962081752,\n", + " (4,): 0.016404012689986223,\n", + " (5,): -0.16357903857975523,\n", + " (6,): 0.17346380234936085,\n", + " (7,): -0.26577439369516503,\n", + " (0, 1): -0.042585290353095114,\n", + " (0, 2): 0.024107340036971913,\n", + " (0, 3): -0.014564433306669166,\n", + " (0, 4): -0.017044014029018048,\n", + " (0, 5): -0.09701443947586665,\n", + " (0, 6): -0.05864803944795568,\n", + " (0, 7): 0.03137724668768478,\n", + " (1, 2): -0.011967872732625255,\n", + " (1, 3): -0.011074616354327904,\n", + " (1, 4): -0.012418359902248854,\n", + " (1, 5): 0.11602598619389336,\n", + " (1, 6): -0.014121881491565378,\n", + " (1, 7): -0.011773778503153455,\n", + " (2, 3): -0.013172619834983895,\n", + " (2, 4): -0.011354684202826508,\n", + " (2, 5): -0.016448764531913802,\n", + " (2, 6): -0.003504932321068288,\n", + " (2, 7): 
-0.01694232037048102,\n", + " (3, 4): -0.011503262190555282,\n", + " (3, 5): -0.009255317162506105,\n", + " (3, 6): -0.012895688622052844,\n", + " (3, 7): -0.012188341974186959,\n", + " (4, 5): -0.020107243385105573,\n", + " (4, 6): -0.011409873190119286,\n", + " (4, 7): -0.011490783339750722,\n", + " (5, 6): -0.02718532205687375,\n", + " (5, 7): 0.006640562391589585,\n", + " (6, 7): -0.05674171687511046}" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 8 + "execution_count": 7 }, { "cell_type": "markdown", @@ -298,8 +299,8 @@ "id": "d7b29c92", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:09.593930Z", - "start_time": "2024-11-07T15:16:09.579929Z" + "end_time": "2025-01-10T13:18:30.875899Z", + "start_time": "2025-01-10T13:18:30.860899Z" } }, "source": [ @@ -309,50 +310,50 @@ { "data": { "text/plain": [ - "{(0, 1): -0.030430304440182833,\n", - " (0, 2): 0.0409733947743084,\n", - " (0, 3): -0.006735285040975396,\n", - " (0, 4): -0.00584265471942099,\n", - " (0, 5): -0.057042709711482786,\n", - " (0, 6): -0.060954483612182475,\n", - " (0, 7): 0.03558046110939789,\n", - " (1, 2): -0.006521517081325771,\n", - " (1, 3): -0.004154456983576514,\n", - " (1, 4): -0.00560700633546335,\n", - " (1, 5): 0.07471407479283865,\n", - " (1, 6): -0.0071986920204653365,\n", - " (1, 7): -0.005214393115368101,\n", - " (2, 3): -0.008588302199393822,\n", - " (2, 4): -0.0037641409599387054,\n", - " (2, 5): -0.0035235279682149586,\n", - " (2, 6): 0.0027151081649867473,\n", - " (2, 7): -0.012570764436453927,\n", - " (3, 4): -0.004291162361399799,\n", - " (3, 5): -0.003961461841604401,\n", - " (3, 6): -0.005450982713619352,\n", - " (3, 7): -0.005364070146454759,\n", - " (4, 5): -0.012215475119607945,\n", - " (4, 6): -0.004613863863220258,\n", - " (4, 7): -0.003418052765388207,\n", - " (5, 6): -0.01840858915487052,\n", - " (5, 7): -0.00030334625171240555,\n", - " (6, 7): -0.07016564318093256}" + 
"{(0, 1): -0.042585290353095114,\n", + " (0, 2): 0.024107340036971913,\n", + " (0, 3): -0.014564433306669166,\n", + " (0, 4): -0.017044014029018048,\n", + " (0, 5): -0.09701443947586665,\n", + " (0, 6): -0.05864803944795568,\n", + " (0, 7): 0.03137724668768478,\n", + " (1, 2): -0.011967872732625255,\n", + " (1, 3): -0.011074616354327904,\n", + " (1, 4): -0.012418359902248854,\n", + " (1, 5): 0.11602598619389336,\n", + " (1, 6): -0.014121881491565378,\n", + " (1, 7): -0.011773778503153455,\n", + " (2, 3): -0.013172619834983895,\n", + " (2, 4): -0.011354684202826508,\n", + " (2, 5): -0.016448764531913802,\n", + " (2, 6): -0.003504932321068288,\n", + " (2, 7): -0.01694232037048102,\n", + " (3, 4): -0.011503262190555282,\n", + " (3, 5): -0.009255317162506105,\n", + " (3, 6): -0.012895688622052844,\n", + " (3, 7): -0.012188341974186959,\n", + " (4, 5): -0.020107243385105573,\n", + " (4, 6): -0.011409873190119286,\n", + " (4, 7): -0.011490783339750722,\n", + " (5, 6): -0.02718532205687375,\n", + " (5, 7): 0.006640562391589585,\n", + " (6, 7): -0.05674171687511046}" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 9 + "execution_count": 8 }, { "cell_type": "code", "id": "f0eb589b", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:09.609934Z", - "start_time": "2024-11-07T15:16:09.595932Z" + "end_time": "2025-01-10T13:18:30.891898Z", + "start_time": "2025-01-10T13:18:30.876904Z" } }, "source": [ @@ -362,30 +363,30 @@ { "data": { "text/plain": [ - "array([[-0.01221295, -0.0304303 , 0.04097339, -0.00673529, -0.00584265,\n", - " -0.05704271, -0.06095448, 0.03558046],\n", - " [-0.0304303 , -0.0680555 , -0.00652152, -0.00415446, -0.00560701,\n", - " 0.07471407, -0.00719869, -0.00521439],\n", - " [ 0.04097339, -0.00652152, -0.04995963, -0.0085883 , -0.00376414,\n", - " -0.00352353, 0.00271511, -0.01257076],\n", - " [-0.00673529, -0.00415446, -0.0085883 , 0.00585623, -0.00429116,\n", - " 
-0.00396146, -0.00545098, -0.00536407],\n", - " [-0.00584265, -0.00560701, -0.00376414, -0.00429116, 0.00615261,\n", - " -0.01221548, -0.00461386, -0.00341805],\n", - " [-0.05704271, 0.07471407, -0.00352353, -0.00396146, -0.01221548,\n", - " -0.08883989, -0.01840859, -0.00030335],\n", - " [-0.06095448, -0.00719869, 0.00271511, -0.00545098, -0.00461386,\n", - " -0.01840859, 0.17713797, -0.07016564],\n", - " [ 0.03558046, -0.00521439, -0.01257076, -0.00536407, -0.00341805,\n", - " -0.00030335, -0.07016564, -0.27761004]])" + "array([[ 0.03969378, -0.04258529, 0.02410734, -0.01456443, -0.01704401,\n", + " -0.09701444, -0.05864804, 0.03137725],\n", + " [-0.04258529, -0.08787131, -0.01196787, -0.01107462, -0.01241836,\n", + " 0.11602599, -0.01412188, -0.01177378],\n", + " [ 0.02410734, -0.01196787, -0.03018256, -0.01317262, -0.01135468,\n", + " -0.01644876, -0.00350493, -0.01694232],\n", + " [-0.01456443, -0.01107462, -0.01317262, 0.0103145 , -0.01150326,\n", + " -0.00925532, -0.01289569, -0.01218834],\n", + " [-0.01704401, -0.01241836, -0.01135468, -0.01150326, 0.01640401,\n", + " -0.02010724, -0.01140987, -0.01149078],\n", + " [-0.09701444, 0.11602599, -0.01644876, -0.00925532, -0.02010724,\n", + " -0.16357904, -0.02718532, 0.00664056],\n", + " [-0.05864804, -0.01412188, -0.00350493, -0.01289569, -0.01140987,\n", + " -0.02718532, 0.1734638 , -0.05674172],\n", + " [ 0.03137725, -0.01177378, -0.01694232, -0.01218834, -0.01149078,\n", + " 0.00664056, -0.05674172, -0.26577439]])" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 10 + "execution_count": 9 }, { "cell_type": "markdown", @@ -394,7 +395,7 @@ "collapsed": false }, "source": [ - "### visualization of Shapley interactions\n", + "### Visualization of Shapley interactions\n", "\n", "`shapiq` includes the following plotting functions:\n", "\n", @@ -415,8 +416,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": 
"2024-11-07T15:16:10.034177Z", - "start_time": "2024-11-07T15:16:09.612930Z" + "end_time": "2025-01-10T13:18:31.218180Z", + "start_time": "2025-01-10T13:18:30.892897Z" } }, "source": [ @@ -433,7 +434,7 @@ "(
, )" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, @@ -442,21 +443,21 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 11 + "execution_count": 10 }, { "cell_type": "code", "id": "49395db0", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:26.443Z", - "start_time": "2024-11-07T15:16:26.239629Z" + "end_time": "2025-01-10T13:18:31.450321Z", + "start_time": "2025-01-10T13:18:31.220178Z" } }, "source": [ @@ -471,21 +472,21 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 13 + "execution_count": 11 }, { "cell_type": "code", "id": "208c5241", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:29.323496Z", - "start_time": "2024-11-07T15:16:29.068931Z" + "end_time": "2025-01-10T13:18:31.607927Z", + "start_time": "2025-01-10T13:18:31.452241Z" } }, "source": [ @@ -500,13 +501,13 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 14 + "execution_count": 12 }, { "cell_type": "markdown", @@ -521,8 +522,8 @@ "id": "34db0c8f", "metadata": { "ExecuteTime": { - "end_time": "2024-11-07T15:16:33.627704Z", - "start_time": "2024-11-07T15:16:31.950592Z" + "end_time": "2025-01-10T13:18:32.424641Z", + "start_time": "2025-01-10T13:18:31.609925Z" } }, "source": [ @@ -534,13 +535,13 @@ "text/plain": [ "
" ], - "image/png": "" + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 15 + "execution_count": 13 } ], "metadata": { diff --git a/docs/source/notebooks/vision_notebooks/vision_transformer.ipynb b/docs/source/notebooks/vision_notebooks/vision_transformer.ipynb index e5c5bc12..320f980f 100644 --- a/docs/source/notebooks/vision_notebooks/vision_transformer.ipynb +++ b/docs/source/notebooks/vision_notebooks/vision_transformer.ipynb @@ -188,9 +188,9 @@ ], "source": [ "# get the exact SII values explanation\n", - "from shapiq.exact import ExactComputer\n", + "from shapiq import ExactComputer\n", "\n", - "exact = ExactComputer(n_players=game_loaded.n_players, game_fun=game_loaded)\n", + "exact = ExactComputer(n_players=game_loaded.n_players, game=game_loaded)\n", "sii = exact(index=\"k-SII\", order=2)\n", "sii" ] @@ -268,7 +268,7 @@ "source": [ "# load the 16 player values and explain\n", "game_loaded = Game(path_to_values=\"pre_computed_image_16.npz\", normalize=True)\n", - "exact = ExactComputer(n_players=game_loaded.n_players, game_fun=game_loaded)\n", + "exact = ExactComputer(n_players=game_loaded.n_players, game=game_loaded)\n", "sii = exact(index=\"k-SII\", order=2)\n", "sii" ] From cf4fbbd3a3c280919fec7a20897e4534f6d069ee Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 15:07:15 +0100 Subject: [PATCH 16/16] ran TabPFN notebook --- .../tabular_notebooks/explaining_tabpfn.ipynb | 338 ++++++++++++------ .../tabular_notebooks/tabpfn_values.npz | Bin 0 -> 2151 bytes .../tabular_notebooks/tabpfn_values_copy.npz | Bin 2202 -> 0 bytes 3 files changed, 225 insertions(+), 113 deletions(-) create mode 100644 docs/source/notebooks/tabular_notebooks/tabpfn_values.npz delete mode 100644 docs/source/notebooks/tabular_notebooks/tabpfn_values_copy.npz diff --git a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb index a43907c7..11a4343c 
100644 --- a/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb +++ b/docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb @@ -28,8 +28,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-01-10T11:47:28.951329Z", - "start_time": "2025-01-10T11:47:24.953799Z" + "end_time": "2025-01-10T13:55:35.932354Z", + "start_time": "2025-01-10T13:55:31.928667Z" } }, "source": [ @@ -70,8 +70,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:47:29.014925Z", - "start_time": "2025-01-10T11:47:28.953368Z" + "end_time": "2025-01-10T13:55:35.978513Z", + "start_time": "2025-01-10T13:55:35.933357Z" } }, "cell_type": "code", @@ -240,8 +240,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:47:29.030865Z", - "start_time": "2025-01-10T11:47:29.016917Z" + "end_time": "2025-01-10T13:55:35.994521Z", + "start_time": "2025-01-10T13:55:35.979512Z" } }, "cell_type": "code", @@ -283,8 +283,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:54:31.725938Z", - "start_time": "2025-01-10T11:54:31.533466Z" + "end_time": "2025-01-10T13:55:36.326775Z", + "start_time": "2025-01-10T13:55:35.995512Z" } }, "cell_type": "code", @@ -300,7 +300,7 @@ "TabPFNRegressor(device=device(type='cpu'), n_jobs=7)" ], "text/html": [ - "
TabPFNRegressor(device=device(type='cpu'), n_jobs=7)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + "
TabPFNRegressor(device=device(type='cpu'), n_jobs=7)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, - "execution_count": 10, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 10 + "execution_count": 4 }, { "metadata": {}, @@ -732,8 +732,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:49:48.449670Z", - "start_time": "2025-01-10T11:47:29.286685Z" + "end_time": "2025-01-10T13:57:53.128517Z", + "start_time": "2025-01-10T13:55:36.333769Z" } }, "cell_type": "code", @@ -749,8 +749,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:49:48.464675Z", - "start_time": "2025-01-10T11:49:48.451681Z" + "end_time": "2025-01-10T13:57:53.144447Z", + "start_time": "2025-01-10T13:57:53.129439Z" } }, "cell_type": "code", @@ -771,8 +771,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "MSE: 0.27140348437031175 R2: 0.7964621203301282\n", - "Average prediction: 2.0861094\n" + "MSE: 0.27149947144257525 R2: 0.796390135236755\n", + "Average prediction: 2.0852828\n" ] } ], @@ -781,16 +781,17 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:54:38.179564Z", - "start_time": "2025-01-10T11:54:38.008602Z" + "end_time": "2025-01-10T13:57:53.331718Z", + "start_time": "2025-01-10T13:57:53.145436Z" } }, "cell_type": "code", "source": [ "# we will reset the model to less training data because we are on CPU\n", "if device == torch.device(\"cpu\"):\n", - " print(\"Resetting the model to less training data\")\n", - " model.fit(x_train[:200], y_train[:200])" + " print(\"Resetting the model to less training data:\", x_train.shape[0])\n", + " x_train, y_train = x_train[:50], y_train[:50]\n", + " model.fit(x_train, y_train)" ], "id": "7f6253cf223e9136", "outputs": [ @@ -798,11 +799,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Resetting the model to less training data\n" + "Resetting the model to less training data: 500\n" ] } ], - "execution_count": 11 + "execution_count": 7 }, { "metadata": {}, @@ -817,14 +818,14 @@ { "metadata": { "ExecuteTime": { - "end_time": 
"2025-01-10T11:54:46.548358Z", - "start_time": "2025-01-10T11:54:40.883018Z" + "end_time": "2025-01-10T13:57:54.328409Z", + "start_time": "2025-01-10T13:57:53.334623Z" } }, "cell_type": "code", "source": [ - "x_explain = x_data.values[0]\n", - "y_explain = y_data.values[0]\n", + "x_explain = x_data.values[1000]\n", + "y_explain = y_data.values[1000]\n", "\n", "prediction = model.predict(x_explain.reshape(1, -1))[0]\n", "print(\"Prediction: \", prediction)\n", @@ -837,13 +838,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Prediction: 4.2348824\n", - "True value: 4.526\n", - "Average prediction: 2.0861094\n" + "Prediction: 1.8186865\n", + "True value: 1.844\n", + "Average prediction: 2.0852828\n" ] } ], - "execution_count": 12 + "execution_count": 8 }, { "metadata": {}, @@ -861,11 +862,9 @@ }, { "metadata": { - "jupyter": { - "is_executing": true - }, "ExecuteTime": { - "start_time": "2025-01-10T11:54:51.476773Z" + "end_time": "2025-01-10T14:02:39.961668Z", + "start_time": "2025-01-10T13:57:54.329359Z" } }, "cell_type": "code", @@ -886,14 +885,24 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "0a217f4df0d74db8a8d6b59cb29e2291" + "model_id": "58adb18d135f41429ff10942996b0c2a" } }, "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" } ], - "execution_count": null + "execution_count": 9 }, { "metadata": {}, @@ -914,8 +923,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:54:19.815427Z", - "start_time": "2025-01-10T11:54:19.815427Z" + "end_time": "2025-01-10T14:02:39.977683Z", + "start_time": "2025-01-10T14:02:39.963673Z" } }, "cell_type": "code", @@ -960,7 +969,7 @@ ], "id": "37a977c5f4a88aee", "outputs": [], - "execution_count": null + "execution_count": 10 }, { "metadata": {}, @@ -974,25 +983,25 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:54:19.816427Z", - "start_time": "2025-01-10T11:54:19.816427Z" + "end_time": "2025-01-10T14:02:39.993671Z", + "start_time": "2025-01-10T14:02:39.980669Z" } }, "cell_type": "code", "source": [ "import os\n", "\n", - "if not os.path.exists(\"tabpfn_values_copy.npz\"):\n", + "if not os.path.exists(\"tabpfn_values.npz\"):\n", " tabpfn_game = TabPFNGame(model, x_train, y_train, x_explain, average_prediction)\n", " tabpfn_game.verbose = True # see the pre-computation progress\n", " tabpfn_game.precompute()\n", " tabpfn_game.save_values(\"tabpfn_values.npz\")\n", "\n", - "tabpfn_game = shapiq.Game(path_to_values=\"tabpfn_values_copy.npz\", normalize=False)" + "tabpfn_game = shapiq.Game(path_to_values=\"tabpfn_values.npz\", normalize=False)" ], "id": "7b2606969b5bab0", "outputs": [], - "execution_count": null + "execution_count": 11 }, { "metadata": {}, @@ -1008,8 +1017,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-01-10T11:54:19.817430Z", - "start_time": "2025-01-10T11:54:19.817430Z" + "end_time": "2025-01-10T14:02:40.009674Z", + "start_time": "2025-01-10T14:02:39.994665Z" } }, "cell_type": "code", @@ -1019,60 +1028,122 @@ "print(\"Latitude and Longitude: \", tabpfn_game[(6, 7)]) # lat. and lon. 
are at index 6 and 7" ], "id": "a96e3795ea1df8a0", - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No features: 2.0861093997955322\n", + "All features: 1.8420544862747192\n", + "Latitude and Longitude: 1.6669323444366455\n" + ] + } + ], + "execution_count": 12 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "With only latitude and longitude, we can see that the model predicts a higher price than with all or no features together.\n", - "Let's compute some explanation values for the TabPFN model with ``shapiq.ExactComputer``:" + "Only providing the latitude and longitude features results in a prediction of 1.66, which is less than the average prediction of around 2.0 and the prediction with all features, which would be 1.84. \n", + "This suggests that the latitude and longitude may reduce the house price.\n", + "Let's compute some explanation values for the TabPFN model with ``shapiq.ExactComputer`` and check this out:" ], "id": "704e9c58dd3273d" }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T14:03:00.196100Z", + "start_time": "2025-01-10T14:03:00.168899Z" + } + }, "cell_type": "code", "source": [ - "exact_computer = shapiq.ExactComputer(n_players=tabpfn_game.n_players, game_fun=tabpfn_game)\n", + "exact_computer = shapiq.ExactComputer(n_players=tabpfn_game.n_players, game=tabpfn_game)\n", "sv = exact_computer(index=\"SV\", order=1) # compute the Shapley values\n", "fsii = exact_computer(index=\"FSII\", order=2) # compute Faithful Shapley Interaction values" ], "id": "1887b05e6bd7cda8", "outputs": [], - "execution_count": null + "execution_count": 14 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T14:03:02.672221Z", + "start_time": "2025-01-10T14:03:02.238547Z" + } + }, "cell_type": "code", "source": [ "display(sv.dict_values)\n", "sv.plot_force(feature_names=feature_names)" ], "id": "7bfdd3a9e1ff6b1d", - 
"outputs": [], - "execution_count": null + "outputs": [ + { + "data": { + "text/plain": [ + "{(): 2.0861093997955322,\n", + " (0,): -0.16847709885665363,\n", + " (1,): 0.030854925797099225,\n", + " (2,): -0.04534098236333772,\n", + " (3,): 0.06734204618703749,\n", + " (4,): 0.010694948690278039,\n", + " (5,): 0.023461689409755293,\n", + " (6,): -0.09278867258912055,\n", + " (7,): -0.06980176979587177}" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 15 }, { "metadata": {}, "cell_type": "markdown", "source": [ - "From the Shapley values, we can see that longitude has a very high positive impact on the prediction and increases the house price.\n", - "When we compute second order Shapley interactions (``index=FSII``, ``order=2``) we can see that the interaction between latitude and longitude together has a positive impact.\n", - "This suggests that the model learns the interactions between latitude and longitude features.\n", - "\n", - "Interestingly, longitude also has a couple of negative interactions with other features such as the median income, which decreases the house price.\n" + "From the Shapley values, we can see that both latitude and longitude have a negative impact on the house price.\n", + "When we compute second order Shapley interactions (``index=FSII``, ``order=2``) we can see that the interaction between latitude and longitude together actually has a very negative impact on the house price." ], "id": "fceea72f0e13feb1" }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T14:03:05.650284Z", + "start_time": "2025-01-10T14:03:05.111943Z" + } + }, "cell_type": "code", "source": "fsii.plot_force(feature_names=feature_names)", "id": "7df6eae3201659ab", - "outputs": [], - "execution_count": null + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 16 }, { "metadata": {}, @@ -1088,7 +1159,12 @@ "id": "e9b187c6a678a8a8" }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T14:03:19.687073Z", + "start_time": "2025-01-10T14:03:19.327787Z" + } + }, "cell_type": "code", "source": [ "approximator = shapiq.KernelSHAP(n=tabpfn_game.n_players, random_state=42)\n", @@ -1096,11 +1172,27 @@ "sv.plot_force(feature_names=feature_names)" ], "id": "7203ae35139cc10a", - "outputs": [], - "execution_count": null + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 17 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T14:03:21.717660Z", + "start_time": "2025-01-10T14:03:21.298376Z" + } + }, "cell_type": "code", "source": [ "approximator = shapiq.RegressionFSII(n=tabpfn_game.n_players, random_state=42, max_order=2)\n", @@ -1108,8 +1200,28 @@ "fsii.plot_force(feature_names=feature_names)" ], "id": "c0baa11868d4769e", - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\1_Workspaces\\1_Phd_Projects\\shapiq\\shapiq\\approximator\\regression\\_base.py:342: UserWarning: Linear regression equation is singular, a least squares solutions is used instead.\n", + "\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 18 } ], "metadata": { diff --git a/docs/source/notebooks/tabular_notebooks/tabpfn_values.npz b/docs/source/notebooks/tabular_notebooks/tabpfn_values.npz new file mode 100644 index 0000000000000000000000000000000000000000..184e0ee67c3636934af20b2037788f4fec43b8d7 GIT binary patch literal 2151 zcma)7dpMM78z0AzVV21uIejU%Fb$PM&cuh+4yv)HayHo+Gc;t-5KFe0tV0{pki$+Y zv^C1cfG&&{@u@g!(7G0!2ke2 zN-(tnmNJi6!QueGRl!yU00vz4Lq`VT3^8HRA^_0#TY(!u=pnSbw-!9JgH+mmR8qi- zX`QVOP>7d3NTYDjC9@91N=oz=Nn|Lb{#+W{A9 zw?{WZCsV3Sy~XQsu8Wfmh}alEYTXz2Y$lF5jXS@-jArN6`%f#3XIpk9P>*^t8q5-0 z-%y7Je57UF#a>6fU+jWnCVt z6~K@_!$G7e&9Lqq2nJv%OeK$79K5e1>qmY%J>j8upFkn9Uyiigx|_&3YRq>$ z!G@40X=+q!lbo0a`19rYfpA-z$>hsAl3BV{&$Bfi5W+a)kGL5a-c(9_o?cxh8kaXc z8B_=1@>CqS9eSUa6+X{Go@);4Qx5S@PQjJ8+N{PFbs+jV8)@*5TeH!<*1XzxYxjeD zp(g`0Wu4bw6y)jQR4Fq?wxty+T}+Q);N?ZLPWyrK?SvIaEXejLD%*+9*Y0A~Ki+65rj4MJ_5H z#P5k8J>9EDbw7?Wq}Q#a^dQ=vnF za)i>GF`r_qkZPZq_9ldjvN6ZaCb6>e~*TXKmDk)O9-@Q)$jEIQ^-n^UZzz}p z=ctPYg@&r*+_MQOmZnHYb&6_npYgqG6ck=J8-QJ@R1Zj}5ngUi>QVm~d&N^92221U z9_zz?OFtQQScR;XC0^q~%aBAw+rAIQb5K*!csFyif@9q5H#PVdp3Fvb{HG6Ah*K3a zn9}d1bw_g)0OKB<#ksO&L#AA`7{q~F`QYeyef#=H zcJv>ZRr!|rgDt!TUu^j%xh6#en}wavz}vfN<6c-xHV<{wFx~F?Ab0cQ+n*J%n^qbc zZ*1ZkQk#@hH>{=P?4w$NYqMS;j==}sB_)lZIO4HX}P;AGkx1I|Q z@~Bbm>qDF-%t}LwM;>!7lx!B_#$N=kblIvWozg%qQTi{w;FWYdI8qjacS%EKES;t&| z+Wu-TlaRpse#Y_%w_Y;?M|CmYG{%ZqoL;%N0NPSq8;FV6-SZo!&5re0bnDpAuza(j zE$WSIuIp;v7%g4|wxdFIIxQbYiu|WS$q6b{0M-v35)p#M{9lz4dI_!6EBJt3gFd4&cM?7gdS5__+-rqi*yN+wBz*b@ypyPgs zN(#NI67qmEO@E4czzG>|UjMYfqxeJy@Tv)_l=B2UW`Z)#XT4u}RMwlGH(xV?Q%`!z zqAA-P`bF{IDHi(Gs4ML0$*BbY;!$5Y$Johlkj{2zYQ;)ibR2Z3B^Go86u6#!v5)TA>D^-S`y`ZgP0#ZzR7cQ$-Tw5j^(s0vee;QOz_urA0HdChNWF2CE}_ivip)BiIS#Nv+F7Dqqem)@@w-WQZ^R}3B!0sevIfT zs_o{I3|nsUBR}THl!x+bE~T{hZhLoYy(${XXY&KA-pbywN9s zAQ%7u*q~TO0K<&mqw+uiK&;{v41fV5@j(%Ua0_C{RV4s)?WzI>SVgQx;TgI6nTV|` zSY~~nMLH;pf_J1RrYYOMYv`V+-iUws&?jH_LU2NE&#Cm-lwDi= 
zA07K(o<=B%l`6YMl)=k$(PwTv#5-3-8CiPnoTtnvTJR z2qSKZ8i(_}k|*DWcesES+S`|AAtI&Up&GnRlZ`*lsp{li8(qdiWv#rHQr2aG^{d%> zsg0N@G%~Ow1T=e9dX`XE@UW0pwcb#ss(B*ow1Fztr?u-T!N21`jyt(iC^$!rY1$^m zlopBREn7bn61tBEzuz&#?n|b-I>((@`X@E?3aT-s9z8zLkZk*kt1_aS`)SGnRft5q zcdzzfjA7pyYio-z>eyJ}%LnSMUVBGw|5(G0Uk1uAwP;y!yML2fr^hE%Sp_P`ywTv( z3gepWGLyphhHaWE&^|KkD!EhCgHfmOX*du1jlx#?R=U;pP>%)V^01Td{7;Q`)>3U! zluVQ-W5(Vg$wF>M^o&xsCo0W@|rtZozv#F zT1h)2(5iMLb@bZTjlXv+nI}@rQN!)ljA91IPLDG;X%Zo>$@uw|iYRBK`=^oK-=vVH z$JpCBv-N|ToB;O?OJ|wIOpmI#MK_a~Q-8O6sBK&4?zoL#GMwYUGqZ)qROjBB>-QYL z{o$7VOR@QI%@aEL{8A?$AGlT0ej2zutv+*oq1|jkFK2r=Pm6=%sL{ssKkXlvSRBP% z=wKTuc<=NUsrHDLb+W_sflwz%J8WseW%dn4 zMO?KYY_r^*N1-#Xr^}i)&2%0ylYWdTu&w;AS-y`tUr|)TeXfq1b&BU)3x`U`@9)*O ze5OY+dIv=qE{A5n;Mh@{4G}dQt1uGI#LNUf+&<1v4HZ8?+^^ZWxtnL5lAEP1`_YPfInxvpi3FLg1Vw-+kosMTSXJF=wA z%jrE_op>2iRM%ly+tN@Gt+M#y(IIjDC4r+0A)L{&mF%SB3!QqREVhki$S1WPcU5zr z;V_6snRJCbWdJJXlm$&XnjpgrTJ}IKB~1K?%L*%*zw@7_#NDfQuFmdR>c|{rn$lt~&_7bWvi6sgeP5qGYnhNLJI4S~t){VF&n; zeN{g2?}in%?07Tr1%O^R)Y@EU_xLIOXQ*i^hMErvALM`8pG5q>!M2K8jk_KPP)+-w z^MQ6p*|FKGCeU%fWy39CGuVJ3xakfk_}37yMF%pfNRS87yyex?EXb3DbC@I+aR#~{ zf0o4_d(5(|kO{_~5-Vi>Z@H3BALMLYgpa(EqQ}NiqKG5xw=1HFxrgOl)Z5d)4u)?_ zcq-z-51D0bK9=NQUKvQ13P#7{E|dI^1v=0) z?x~zZBL;6c-%`flk?5))4I|YxG6y078c+cmxLXHnXsDCmZP;f80>@2YMl@lGiQ7S2LKNFoYW$BjN_TXwTOvyr1RaNyHt#o6=$$umPUrmA(7OYUq zHTha3LIbbSAEI;HL>H5KV7DZ>;m{&PoO`|h<}?3z^n=g;5r){FjYUcbV< y*Z%_Fa@Tj_Yc5$8u)Z(fQrFk-YpRl(g8JvxM4wPmJQn}}DYiMqF+%|b0R9b7Rk