From c79e07e6087630b67be53592d10b004e8ebcd716 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Wed, 8 Nov 2023 10:42:22 -0500 Subject: [PATCH] Add save and load_from_path methods to DiverseCF --- nbs/methods/02_dice.ipynb | 55 ++++++++++++++++++++++++++++++--------- relax/_modidx.py | 3 +++ relax/methods/dice.py | 8 ++++++ 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/nbs/methods/02_dice.ipynb b/nbs/methods/02_dice.ipynb index 50ebedd..de60475 100644 --- a/nbs/methods/02_dice.ipynb +++ b/nbs/methods/02_dice.ipynb @@ -44,7 +44,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" ] } ], @@ -139,7 +139,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_4795/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n", + "/tmp/ipykernel_5806/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n", " return torch.from_numpy(x.__array__())\n" ] } @@ -169,7 +169,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "511 ms ± 61.5 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)\n" + "332 ms ± 3.7 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)\n" ] } ], @@ -187,7 +187,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.2 ms ± 536 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)\n" + "608 µs ± 101 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)\n" ] } ], @@ -304,6 +304,14 @@ " config = validate_configs(config, DiverseCFConfig)\n", " name = \"DiverseCF\" if name is None else name\n", " super().__init__(config, name=name)\n", + " \n", + " def save(self, path: str):\n", + " self.config.save(Path(path) / 'config.json')\n", + " \n", + " @classmethod\n", + " def load_from_path(cls, path: str):\n", + " config = DiverseCFConfig.load_from_json(Path(path) / 'config.json')\n", + " return cls(config=config)\n", "\n", " @auto_reshaping('x', reshape_output=False)\n", " def generate_cf(\n", @@ -364,7 +372,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f3eb72c256634815963126154d1c4c61", + "model_id": "5c8f4cde3e854d1dbee759ff11102211", "version_major": 2, "version_minor": 0 }, @@ -378,7 +386,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7954e440e039461d8e1098e917b99dc6", + "model_id": "4077af1ed0504f08b2301aae100776f8", "version_major": 2, "version_minor": 0 }, @@ -398,7 +406,7 @@ } ], "source": [ - "dcf = DiverseCF()\n", + "dcf = DiverseCF({'lambda_2': 4.0})\n", "dcf.set_apply_constraints_fn(dm.apply_constraints)\n", "dcf.set_compute_reg_loss_fn(dm.compute_reg_loss)\n", "cf = dcf.generate_cf(xs_test[0], model.pred_fn, rng_key=jrand.PRNGKey(0))\n", @@ -423,7 +431,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "87a9ab5c201b4469b4a562aac443a053", + "model_id": "f046536532a5420aa92b63eb9a8330ae", "version_major": 2, "version_minor": 0 }, @@ -436,9 +444,13 @@ } ], "source": [ - "exp = relax.generate_cf_explanations(\n", - " dcf, dm, model.pred_fn\n", - ")" + "dcf.save('tmp/dice/')\n", + "dcf_1 = DiverseCF.load_from_path('tmp/dice/')\n", + "dcf_1.set_apply_constraints_fn(dm.apply_constraints)\n", + "partial_gen_1 = ft.partial(dcf_1.generate_cf, pred_fn=model.pred_fn)\n", + "cfs_1 = jax.vmap(partial_gen_1)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))\n", + "\n", + "assert jnp.allclose(cfs, cfs_1)" ] }, { @@ -446,6 +458,20 @@ "execution_count": null, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "db1e254a77a64c89be50e4905e2ae949", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -479,7 +505,7 @@ "