Skip to content

Commit

Permalink
Fix bug in indexing arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Nov 9, 2023
1 parent 51d55df commit e37ba40
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
12 changes: 11 additions & 1 deletion nbs/03_explain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
Expand Down Expand Up @@ -98,6 +98,9 @@
" else:\n",
" raise ValueError(f\"Unknown data name: {name}. Should be one of ['train', 'val', 'test']\")\n",
"\n",
" if isinstance(indices, list):\n",
" indices = jnp.array(indices)\n",
" \n",
" return {\n",
" 'xs': self.xs[indices],\n",
" 'ys': self.ys[indices],\n",
Expand Down Expand Up @@ -174,6 +177,13 @@
"exp = fake_explanation(n_cfs=1)\n",
"xs_shape = exp.xs.shape\n",
"assert exp.cfs.shape == (xs_shape[0], 1, xs_shape[-1])\n",
"train_exp = exp['train']\n",
"val_exp = exp['val']\n",
"test_exp = exp['test']\n",
"assert jnp.concatenate(\n",
" [train_exp['cfs'], val_exp['cfs']], axis=0\n",
").shape == exp.cfs.shape\n",
"assert test_exp['cfs'].shape == val_exp['cfs'].shape\n",
"\n",
"exp = fake_explanation(n_cfs=5)\n",
"assert exp.cfs.shape == (xs_shape[0], 5, xs_shape[-1])"
Expand Down
3 changes: 3 additions & 0 deletions relax/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __getitem__(self, name: Literal['train', 'val', 'test']) -> Dict[str, Array]
else:
raise ValueError(f"Unknown data name: {name}. Should be one of ['train', 'val', 'test']")

if isinstance(indices, list):
indices = jnp.array(indices)

return {
'xs': self.xs[indices],
'ys': self.ys[indices],
Expand Down

0 comments on commit e37ba40

Please sign in to comment.