Skip to content

Commit

Permalink
Add save and
Browse files Browse the repository at this point in the history
load_from_path methods to DiverseCF
  • Loading branch information
BirkhoffG committed Nov 8, 2023
1 parent 63db944 commit c79e07e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 13 deletions.
55 changes: 42 additions & 13 deletions nbs/methods/02_dice.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
Expand Down Expand Up @@ -139,7 +139,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_4795/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
"/tmp/ipykernel_5806/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
" return torch.from_numpy(x.__array__())\n"
]
}
Expand Down Expand Up @@ -169,7 +169,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"511 ms ± 61.5 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)\n"
"332 ms ± 3.7 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)\n"
]
}
],
Expand All @@ -187,7 +187,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.2 ms ± 536 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)\n"
"608 µs ± 101 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)\n"
]
}
],
Expand Down Expand Up @@ -304,6 +304,14 @@
" config = validate_configs(config, DiverseCFConfig)\n",
" name = \"DiverseCF\" if name is None else name\n",
" super().__init__(config, name=name)\n",
" \n",
" def save(self, path: str):\n",
" self.config.save(Path(path) / 'config.json')\n",
" \n",
" @classmethod\n",
" def load_from_path(cls, path: str):\n",
" config = DiverseCFConfig.load_from_json(Path(path) / 'config.json')\n",
" return cls(config=config)\n",
"\n",
" @auto_reshaping('x', reshape_output=False)\n",
" def generate_cf(\n",
Expand Down Expand Up @@ -364,7 +372,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3eb72c256634815963126154d1c4c61",
"model_id": "5c8f4cde3e854d1dbee759ff11102211",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -378,7 +386,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7954e440e039461d8e1098e917b99dc6",
"model_id": "4077af1ed0504f08b2301aae100776f8",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -398,7 +406,7 @@
}
],
"source": [
"dcf = DiverseCF()\n",
"dcf = DiverseCF({'lambda_2': 4.0})\n",
"dcf.set_apply_constraints_fn(dm.apply_constraints)\n",
"dcf.set_compute_reg_loss_fn(dm.compute_reg_loss)\n",
"cf = dcf.generate_cf(xs_test[0], model.pred_fn, rng_key=jrand.PRNGKey(0))\n",
Expand All @@ -423,7 +431,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "87a9ab5c201b4469b4a562aac443a053",
"model_id": "f046536532a5420aa92b63eb9a8330ae",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -436,16 +444,34 @@
}
],
"source": [
"exp = relax.generate_cf_explanations(\n",
" dcf, dm, model.pred_fn\n",
")"
"dcf.save('tmp/dice/')\n",
"dcf_1 = DiverseCF.load_from_path('tmp/dice/')\n",
"dcf_1.set_apply_constraints_fn(dm.apply_constraints)\n",
"partial_gen_1 = ft.partial(dcf_1.generate_cf, pred_fn=model.pred_fn)\n",
"cfs_1 = jax.vmap(partial_gen_1)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))\n",
"\n",
"assert jnp.allclose(cfs, cfs_1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db1e254a77a64c89be50e4905e2ae949",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
Expand Down Expand Up @@ -479,15 +505,15 @@
" <th>DiverseCF</th>\n",
" <td>0.983</td>\n",
" <td>1.0</td>\n",
" <td>1.803543</td>\n",
" <td>1.264458</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" acc validity proximity\n",
"dummy DiverseCF 0.983 1.0 1.803543"
"dummy DiverseCF 0.983 1.0 1.264458"
]
},
"execution_count": null,
Expand All @@ -496,6 +522,9 @@
}
],
"source": [
"exp = relax.generate_cf_explanations(\n",
" dcf, dm, model.pred_fn\n",
")\n",
"relax.benchmark_cfs([exp])"
]
}
Expand Down
3 changes: 3 additions & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@
'relax/methods/dice.py'),
'relax.methods.dice.DiverseCF.generate_cf': ( 'methods/dice.html#diversecf.generate_cf',
'relax/methods/dice.py'),
'relax.methods.dice.DiverseCF.load_from_path': ( 'methods/dice.html#diversecf.load_from_path',
'relax/methods/dice.py'),
'relax.methods.dice.DiverseCF.save': ('methods/dice.html#diversecf.save', 'relax/methods/dice.py'),
'relax.methods.dice.DiverseCFConfig': ('methods/dice.html#diversecfconfig', 'relax/methods/dice.py'),
'relax.methods.dice._diverse_cf': ('methods/dice.html#_diverse_cf', 'relax/methods/dice.py'),
'relax.methods.dice.dpp_style_vmap': ('methods/dice.html#dpp_style_vmap', 'relax/methods/dice.py')},
Expand Down
8 changes: 8 additions & 0 deletions relax/methods/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ def __init__(self, config: dict | DiverseCF = None, *, name: str = None):
config = validate_configs(config, DiverseCFConfig)
name = "DiverseCF" if name is None else name
super().__init__(config, name=name)

def save(self, path: str):
self.config.save(Path(path) / 'config.json')

@classmethod
def load_from_path(cls, path: str):
config = DiverseCFConfig.load_from_json(Path(path) / 'config.json')
return cls(config=config)

@auto_reshaping('x', reshape_output=False)
def generate_cf(
Expand Down

0 comments on commit c79e07e

Please sign in to comment.