Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TabPFN Example and Alignment with Paper #299

Merged
merged 17 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
## Changelog

### Development
- renames ``game_fun`` parameter in ``shapiq.ExactComputer`` to ``game`` [#297](https://github.com/mmschlk/shapiq/issues/297)
- adds a TabPFN example notebook to the documentation
- removes warning when class_index is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298)
- adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format
- makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281)
- adds the `upset_plot` function to the `plot` module to visualize the interactions of higher-order [#290](https://github.com/mmschlk/shapiq/issues/290)
Expand Down
105 changes: 66 additions & 39 deletions docs/source/notebooks/basics_notebooks/custom_games.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,23 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.696179Z",
"start_time": "2024-12-17T14:23:18.268301Z"
"end_time": "2025-01-10T12:14:05.982266Z",
"start_time": "2025-01-10T12:14:04.426262Z"
}
},
"cell_type": "code",
"source": [
"import shapiq\n",
"import numpy as np\n",
"import os\n",
"\n",
"shapiq.__version__"
],
"outputs": [
{
"data": {
"text/plain": [
"'1.1.1'"
"'1.1.1.dev'"
]
},
"execution_count": 1,
Expand Down Expand Up @@ -86,8 +87,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.711170Z",
"start_time": "2024-12-17T14:23:19.698170Z"
"end_time": "2025-01-10T12:14:05.997215Z",
"start_time": "2025-01-10T12:14:05.985240Z"
}
},
"cell_type": "code",
Expand Down Expand Up @@ -147,8 +148,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.727173Z",
"start_time": "2024-12-17T14:23:19.713181Z"
"end_time": "2025-01-10T12:14:06.013212Z",
"start_time": "2025-01-10T12:14:06.000205Z"
}
},
"cell_type": "code",
Expand All @@ -174,8 +175,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.742179Z",
"start_time": "2024-12-17T14:23:19.730173Z"
"end_time": "2025-01-10T12:14:06.029218Z",
"start_time": "2025-01-10T12:14:06.014204Z"
}
},
"cell_type": "code",
Expand Down Expand Up @@ -207,8 +208,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.758170Z",
"start_time": "2024-12-17T14:23:19.745174Z"
"end_time": "2025-01-10T12:14:06.045214Z",
"start_time": "2025-01-10T12:14:06.033206Z"
}
},
"cell_type": "code",
Expand Down Expand Up @@ -245,8 +246,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:19.789577Z",
"start_time": "2024-12-17T14:23:19.760172Z"
"end_time": "2025-01-10T12:14:06.061209Z",
"start_time": "2025-01-10T12:14:06.046217Z"
}
},
"cell_type": "code",
Expand Down Expand Up @@ -280,7 +281,7 @@
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "e6e0bc19180b4969bae2cbcabef70fdf"
"model_id": "218e4aac6918408d8a38f1c9646509fb"
}
},
"metadata": {},
Expand Down Expand Up @@ -308,40 +309,29 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-17T14:23:20.357939Z",
"start_time": "2024-12-17T14:23:19.792499Z"
"end_time": "2025-01-10T12:14:06.076763Z",
"start_time": "2025-01-10T12:14:06.063214Z"
}
},
"cell_type": "code",
"source": [
"# save the precomputed values to a file\n",
"cooking_game.save_values(\"data/cooking_game_values.npz\")\n",
"save_path = os.path.join(\"..\", \"data\", \"cooking_game_values.npz\")\n",
"cooking_game.save_values(save_path)\n",
"\n",
"# load the precomputed values from the file\n",
"empty_cooking_game = CookingGame()\n",
"print(\"Values stored before loading: \", empty_cooking_game.value_storage)\n",
"empty_cooking_game.load_values(\"cooking_game_values.npz\")\n",
"empty_cooking_game.load_values(save_path)\n",
"print(\"Values stored after loading: \", empty_cooking_game.value_storage)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Values stored before loading: []\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'cooking_game_values.npz'",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[7], line 7\u001B[0m\n\u001B[0;32m 5\u001B[0m empty_cooking_game \u001B[38;5;241m=\u001B[39m CookingGame()\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored before loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n\u001B[1;32m----> 7\u001B[0m \u001B[43mempty_cooking_game\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_values\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcooking_game_values.npz\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored after loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n",
"File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\shapiq\\games\\base.py:426\u001B[0m, in \u001B[0;36mGame.load_values\u001B[1;34m(self, path, precomputed)\u001B[0m\n\u001B[0;32m 423\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[0;32m 424\u001B[0m path \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m--> 426\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 427\u001B[0m n_players \u001B[38;5;241m=\u001B[39m data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mn_players\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n\u001B[0;32m 428\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m n_players \u001B[38;5;241m!=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players:\n",
"File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\venv\\lib\\site-packages\\numpy\\lib\\npyio.py:427\u001B[0m, in \u001B[0;36mload\u001B[1;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001B[0m\n\u001B[0;32m 425\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m 426\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 427\u001B[0m fid \u001B[38;5;241m=\u001B[39m stack\u001B[38;5;241m.\u001B[39menter_context(\u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mos_fspath\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mrb\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m)\n\u001B[0;32m 428\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 430\u001B[0m \u001B[38;5;66;03m# Code to distinguish from NumPy binary files and pickles.\u001B[39;00m\n",
"\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'cooking_game_values.npz'"
"Values stored before loading: []\n",
"Values stored after loading: [ 0. 4. 3. 2. 9. 8. 7. 15.]\n"
]
}
],
Expand All @@ -356,19 +346,42 @@
]
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-10T12:14:06.092763Z",
"start_time": "2025-01-10T12:14:06.077767Z"
}
},
"cell_type": "code",
"source": [
"# initialize a game object directly from precomputed values\n",
"game = shapiq.Game(path_to_values=\"data/cooking_game_values.npz\")\n",
"game = shapiq.Game(path_to_values=save_path)\n",
"print(game)\n",
"\n",
"# query the value function of the game for the same coalitions as before\n",
"coals = np.array([[0, 0, 0], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]])\n",
"game(coals)"
],
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Game(3 players, normalize=False, normalization_value=0.0, precomputed=True)\n"
]
},
{
"data": {
"text/plain": [
"array([ 0., 9., 8., 7., 15.])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 8
},
{
"metadata": {},
Expand All @@ -379,7 +392,12 @@
]
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-10T12:14:06.108755Z",
"start_time": "2025-01-10T12:14:06.095753Z"
}
},
"cell_type": "code",
"source": [
"print(cooking_game.characteristic_function)\n",
Expand All @@ -388,8 +406,17 @@
"except AttributeError as e:\n",
" print(\"AttributeError:\", e)"
],
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{(): 0, (0,): 4, (1,): 3, (2,): 2, (0, 1): 9, (0, 2): 8, (1, 2): 7, (0, 1, 2): 15}\n",
"AttributeError: 'Game' object has no attribute 'characteristic_function'\n"
]
}
],
"execution_count": 9
}
],
"metadata": {
Expand Down
352 changes: 184 additions & 168 deletions docs/source/notebooks/basics_notebooks/data_valuation.ipynb

Large diffs are not rendered by default.

Loading
Loading